diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 442e6e4009f6..06db092d6fc8 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -37,4 +37,4 @@ If there are user-facing changes then we may require documentation to be updated \ No newline at end of file +--> diff --git a/.github/workflows/dev.yml b/.github/workflows/dev.yml index 15fbbfca0f65..19af21ec910b 100644 --- a/.github/workflows/dev.yml +++ b/.github/workflows/dev.yml @@ -30,7 +30,7 @@ jobs: - name: Checkout uses: actions/checkout@v4 - name: Setup Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.10" - name: Audit licenses @@ -41,9 +41,9 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: actions/setup-node@v3 + - uses: actions/setup-node@v4 with: - node-version: "14" + node-version: "20" - name: Prettier check run: | # if you encounter error, rerun the command below and commit the changes diff --git a/.github/workflows/dev_pr.yml b/.github/workflows/dev_pr.yml index 85aabc188934..77b257743331 100644 --- a/.github/workflows/dev_pr.yml +++ b/.github/workflows/dev_pr.yml @@ -46,7 +46,7 @@ jobs: github.event_name == 'pull_request_target' && (github.event.action == 'opened' || github.event.action == 'synchronize') - uses: actions/labeler@v4.3.0 + uses: actions/labeler@v5.0.0 with: repo-token: ${{ secrets.GITHUB_TOKEN }} configuration-path: .github/workflows/dev_pr/labeler.yml diff --git a/.github/workflows/dev_pr/labeler.yml b/.github/workflows/dev_pr/labeler.yml index e84cf5efb1d8..34a37948785b 100644 --- a/.github/workflows/dev_pr/labeler.yml +++ b/.github/workflows/dev_pr/labeler.yml @@ -16,35 +16,37 @@ # under the License. development-process: - - dev/**.* - - .github/**.* - - ci/**.* - - .asf.yaml +- changed-files: + - any-glob-to-any-file: ['dev/**.*', '.github/**.*', 'ci/**.*', '.asf.yaml'] documentation: - - docs/**.* - - README.md - - ./**/README.md - - DEVELOPERS.md - - datafusion/docs/**.* +- changed-files: + - any-glob-to-any-file: ['docs/**.*', 'README.md', './**/README.md', 'DEVELOPERS.md', 'datafusion/docs/**.*'] sql: - - datafusion/sql/**/* +- changed-files: + - any-glob-to-any-file: ['datafusion/sql/**/*'] logical-expr: - - datafusion/expr/**/* +- changed-files: + - any-glob-to-any-file: ['datafusion/expr/**/*'] physical-expr: - - datafusion/physical-expr/**/* +- changed-files: + - any-glob-to-any-file: ['datafusion/physical-expr/**/*'] optimizer: - - datafusion/optimizer/**/* +- changed-files: + - any-glob-to-any-file: ['datafusion/optimizer/**/*'] core: - - datafusion/core/**/* +- changed-files: + - any-glob-to-any-file: ['datafusion/core/**/*'] substrait: - - datafusion/substrait/**/* +- changed-files: + - any-glob-to-any-file: ['datafusion/substrait/**/*'] sqllogictest: - - datafusion/sqllogictest/**/* +- changed-files: + - any-glob-to-any-file: ['datafusion/sqllogictest/**/*'] diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 14b2038e8794..ab6a615ab60b 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -24,7 +24,7 @@ jobs: path: asf-site - name: Setup Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.10" diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index fa5c56b43e03..ae6c1ee56129 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -47,21 +47,30 @@ jobs: image: amd64/rust steps: - uses: actions/checkout@v4 - - name: Cache Cargo - uses: actions/cache@v3 - with: - # these represent dependencies downloaded by cargo - # and thus do not depend on the OS, arch nor rust version. - path: /github/home/.cargo - key: cargo-cache- - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: rust-version: stable + - name: Cache Cargo + uses: actions/cache@v3 + with: + path: | + ~/.cargo/bin/ + ~/.cargo/registry/index/ + ~/.cargo/registry/cache/ + ~/.cargo/git/db/ + ./target/ + ./datafusion-cli/target/ + # this key equals the ones on `linux-build-lib` for re-use + key: cargo-cache-benchmark-${{ hashFiles('datafusion/**/Cargo.toml', 'benchmarks/Cargo.toml', 'datafusion-cli/Cargo.toml') }} + - name: Check workspace without default features run: cargo check --no-default-features -p datafusion + - name: Check datafusion-common without default features + run: cargo check --tests --no-default-features -p datafusion-common + - name: Check workspace in debug mode run: cargo check @@ -84,18 +93,20 @@ jobs: - uses: actions/checkout@v4 with: submodules: true - - name: Cache Cargo - uses: actions/cache@v3 - with: - path: /github/home/.cargo - # this key equals the ones on `linux-build-lib` for re-use - key: cargo-cache- - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: rust-version: stable - name: Run tests (excluding doctests) run: cargo test --lib --tests --bins --features avro,json,backtrace + env: + # do not produce debug symbols to keep memory usage down + # hardcoding other profile params to avoid profile override values + # More on Cargo profiles https://doc.rust-lang.org/cargo/reference/profiles.html?profile-settings#profile-settings + RUSTFLAGS: "-C debuginfo=0 -C opt-level=0 -C incremental=false -C codegen-units=256" + RUST_BACKTRACE: "1" + # avoid rust stack overflows on tpc-ds tests + RUST_MIN_STACK: "3000000" - name: Verify Working Directory Clean run: git diff --exit-code @@ -109,12 +120,6 @@ jobs: - uses: actions/checkout@v4 with: submodules: true - - name: Cache Cargo - uses: actions/cache@v3 - with: - path: /github/home/.cargo - # this key equals the ones on `linux-build-lib` for re-use - key: cargo-cache- - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -145,19 +150,7 @@ jobs: # test datafusion-sql examples cargo run --example sql # test datafusion-examples - cargo run --example avro_sql --features=datafusion/avro - cargo run --example csv_sql - cargo run --example custom_datasource - cargo run --example dataframe - cargo run --example dataframe_in_memory - cargo run --example deserialize_to_struct - cargo run --example expr_api - cargo run --example parquet_sql - cargo run --example parquet_sql_multiple_files - cargo run --example memtable - cargo run --example rewrite_expr - cargo run --example simple_udf - cargo run --example simple_udaf + ci/scripts/rust_example.sh - name: Verify Working Directory Clean run: git diff --exit-code @@ -211,12 +204,6 @@ jobs: image: amd64/rust steps: - uses: actions/checkout@v4 - - name: Cache Cargo - uses: actions/cache@v3 - with: - path: /github/home/.cargo - # this key equals the ones on `linux-build-lib` for re-use - key: cargo-cache- - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -238,12 +225,6 @@ jobs: - uses: actions/checkout@v4 with: submodules: true - - name: Cache Cargo - uses: actions/cache@v3 - with: - path: /github/home/.cargo - # this key equals the ones on `linux-build-lib` for re-use - key: cargo-cache- - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -259,7 +240,8 @@ jobs: - name: Verify that benchmark queries return expected results run: | export TPCH_DATA=`realpath datafusion/sqllogictest/test_files/tpch/data` - cargo test serde_q --profile release-nonlto --features=ci -- --test-threads=1 + # use release build for plan verificaton because debug build causes stack overflow + cargo test plan_q --package datafusion-benchmarks --profile release-nonlto --features=ci -- --test-threads=1 INCLUDE_TPCH=true cargo test --test sqllogictests - name: Verify Working Directory Clean run: git diff --exit-code @@ -316,6 +298,7 @@ jobs: # with a OS-dependent path. - name: Setup Rust toolchain run: | + rustup update stable rustup toolchain install stable rustup default stable rustup component add rustfmt @@ -327,10 +310,13 @@ jobs: cd datafusion-cli cargo test --lib --tests --bins --all-features env: - # do not produce debug symbols to keep memory usage down - RUSTFLAGS: "-C debuginfo=0" + # Minimize producing debug symbols to keep memory usage down + # Set debuginfo=line-tables-only as debuginfo=0 causes immensely slow build + # See for more details: https://github.com/rust-lang/rust/issues/119560 + RUSTFLAGS: "-C debuginfo=line-tables-only" RUST_BACKTRACE: "1" - + # avoid rust stack overflows on tpc-ds tests + RUST_MIN_STACK: "3000000" macos: name: cargo test (mac) runs-on: macos-latest @@ -353,6 +339,7 @@ jobs: # with a OS-dependent path. - name: Setup Rust toolchain run: | + rustup update stable rustup toolchain install stable rustup default stable rustup component add rustfmt @@ -364,8 +351,12 @@ jobs: cargo test --lib --tests --bins --all-features env: # do not produce debug symbols to keep memory usage down - RUSTFLAGS: "-C debuginfo=0" + # hardcoding other profile params to avoid profile override values + # More on Cargo profiles https://doc.rust-lang.org/cargo/reference/profiles.html?profile-settings#profile-settings + RUSTFLAGS: "-C debuginfo=0 -C opt-level=0 -C incremental=false -C codegen-units=256" RUST_BACKTRACE: "1" + # avoid rust stack overflows on tpc-ds tests + RUST_MIN_STACK: "3000000" test-datafusion-pyarrow: name: cargo test pyarrow (amd64) @@ -377,13 +368,7 @@ jobs: - uses: actions/checkout@v4 with: submodules: true - - name: Cache Cargo - uses: actions/cache@v3 - with: - path: /github/home/.cargo - # this key equals the ones on `linux-build-lib` for re-use - key: cargo-cache- - - uses: actions/setup-python@v4 + - uses: actions/setup-python@v5 with: python-version: "3.8" - name: Install PyArrow @@ -480,12 +465,6 @@ jobs: - uses: actions/checkout@v4 with: submodules: true - - name: Cache Cargo - uses: actions/cache@v3 - with: - path: /github/home/.cargo - # this key equals the ones on `linux-build-lib` for re-use - key: cargo-cache- - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -506,12 +485,6 @@ jobs: - uses: actions/checkout@v4 with: submodules: true - - name: Cache Cargo - uses: actions/cache@v3 - with: - path: /github/home/.cargo - # this key equals the ones on `linux-build-lib` for re-use - key: cargo-cache- - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -531,12 +504,6 @@ jobs: - uses: actions/checkout@v4 with: submodules: true - - name: Cache Cargo - uses: actions/cache@v3 - with: - path: /github/home/.cargo - # this key equals the ones on `linux-build-lib` for re-use - key: cargo-cache- - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -546,12 +513,11 @@ jobs: - name: Check Cargo.toml formatting run: | - # if you encounter error, try rerun the command below, finally run 'git diff' to - # check which Cargo.toml introduces formatting violation + # if you encounter an error, try running 'cargo tomlfmt -p path/to/Cargo.toml' to fix the formatting automatically. + # If the error still persists, you need to manually edit the Cargo.toml file, which introduces formatting violation. # # ignore ./Cargo.toml because putting workspaces in multi-line lists make it easy to read ci/scripts/rust_toml_fmt.sh - git diff --exit-code config-docs-check: name: check configs.md is up-to-date @@ -563,19 +529,13 @@ jobs: - uses: actions/checkout@v4 with: submodules: true - - name: Cache Cargo - uses: actions/cache@v3 - with: - path: /github/home/.cargo - # this key equals the ones on `linux-build-lib` for re-use - key: cargo-cache- - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: rust-version: stable - - uses: actions/setup-node@v3 + - uses: actions/setup-node@v4 with: - node-version: "14" + node-version: "20" - name: Check if configs.md has been modified run: | # If you encounter an error, run './dev/update_config_docs.sh' and commit diff --git a/Cargo.toml b/Cargo.toml index 60ff770d0d13..a87923b6a1a0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,23 +17,7 @@ [workspace] exclude = ["datafusion-cli"] -members = [ - "datafusion/common", - "datafusion/core", - "datafusion/expr", - "datafusion/execution", - "datafusion/optimizer", - "datafusion/physical-expr", - "datafusion/physical-plan", - "datafusion/proto", - "datafusion/proto/gen", - "datafusion/sql", - "datafusion/sqllogictest", - "datafusion/substrait", - "datafusion/wasmtest", - "datafusion-examples", - "test-utils", - "benchmarks", +members = ["datafusion/common", "datafusion/core", "datafusion/expr", "datafusion/execution", "datafusion/optimizer", "datafusion/physical-expr", "datafusion/physical-plan", "datafusion/proto", "datafusion/proto/gen", "datafusion/sql", "datafusion/sqllogictest", "datafusion/substrait", "datafusion/wasmtest", "datafusion-examples", "docs", "test-utils", "benchmarks", ] resolver = "2" @@ -45,17 +29,51 @@ license = "Apache-2.0" readme = "README.md" repository = "https://github.com/apache/arrow-datafusion" rust-version = "1.70" -version = "31.0.0" +version = "34.0.0" [workspace.dependencies] -arrow = { version = "47.0.0", features = ["prettyprint"] } -arrow-array = { version = "47.0.0", default-features = false, features = ["chrono-tz"] } -arrow-buffer = { version = "47.0.0", default-features = false } -arrow-flight = { version = "47.0.0", features = ["flight-sql-experimental"] } -arrow-schema = { version = "47.0.0", default-features = false } -parquet = { version = "47.0.0", features = ["arrow", "async", "object_store"] } -sqlparser = { version = "0.38.0", features = ["visitor"] } +arrow = { version = "49.0.0", features = ["prettyprint"] } +arrow-array = { version = "49.0.0", default-features = false, features = ["chrono-tz"] } +arrow-buffer = { version = "49.0.0", default-features = false } +arrow-flight = { version = "49.0.0", features = ["flight-sql-experimental"] } +arrow-ipc = { version = "49.0.0", default-features = false, features = ["lz4"] } +arrow-ord = { version = "49.0.0", default-features = false } +arrow-schema = { version = "49.0.0", default-features = false } +async-trait = "0.1.73" +bigdecimal = "0.4.1" +bytes = "1.4" chrono = { version = "0.4.31", default-features = false } +ctor = "0.2.0" +dashmap = "5.4.0" +datafusion = { path = "datafusion/core", version = "34.0.0" } +datafusion-common = { path = "datafusion/common", version = "34.0.0" } +datafusion-execution = { path = "datafusion/execution", version = "34.0.0" } +datafusion-expr = { path = "datafusion/expr", version = "34.0.0" } +datafusion-optimizer = { path = "datafusion/optimizer", version = "34.0.0" } +datafusion-physical-expr = { path = "datafusion/physical-expr", version = "34.0.0" } +datafusion-physical-plan = { path = "datafusion/physical-plan", version = "34.0.0" } +datafusion-proto = { path = "datafusion/proto", version = "34.0.0" } +datafusion-sql = { path = "datafusion/sql", version = "34.0.0" } +datafusion-sqllogictest = { path = "datafusion/sqllogictest", version = "34.0.0" } +datafusion-substrait = { path = "datafusion/substrait", version = "34.0.0" } +doc-comment = "0.3" +env_logger = "0.10" +futures = "0.3" +half = "2.2.1" +indexmap = "2.0.0" +itertools = "0.12" +log = "^0.4" +num_cpus = "1.13.0" +object_store = { version = "0.8.0", default-features = false } +parking_lot = "0.12" +parquet = { version = "49.0.0", default-features = false, features = ["arrow", "async", "object_store"] } +rand = "0.8" +rstest = "0.18.0" +serde_json = "1" +sqlparser = { version = "0.41.0", features = ["visitor"] } +tempfile = "3" +thiserror = "1.0.44" +url = "2.2" [profile.release] codegen-units = 1 diff --git a/README.md b/README.md index ccb527a1f977..883700a39355 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,7 @@ in-memory format. [Python Bindings](https://github.com/apache/arrow-datafusion-p Here are links to some important information - [Project Site](https://arrow.apache.org/datafusion) +- [Installation](https://arrow.apache.org/datafusion/user-guide/cli.html#installation) - [Rust Getting Started](https://arrow.apache.org/datafusion/user-guide/example-usage.html) - [Rust DataFrame API](https://arrow.apache.org/datafusion/user-guide/dataframe.html) - [Rust API docs](https://docs.rs/datafusion/latest/datafusion) @@ -35,12 +36,44 @@ Here are links to some important information - [Python DataFrame API](https://arrow.apache.org/datafusion-python/) - [Architecture](https://docs.rs/datafusion/latest/datafusion/index.html#architecture) -## Building your project with DataFusion +## What can you do with this crate? -DataFusion is great for building projects and products like SQL interfaces, time series platforms, and domain specific query engines. [Click Here](https://arrow.apache.org/datafusion/user-guide/introduction.html#known-users) to see a list known users. +DataFusion is great for building projects such as domain specific query engines, new database platforms and data pipelines, query languages and more. +It lets you start quickly from a fully working engine, and then customize those features specific to your use. [Click Here](https://arrow.apache.org/datafusion/user-guide/introduction.html#known-users) to see a list known users. ## Contributing to DataFusion -The [developer’s guide] contains information on how to contribute. +Please see the [developer’s guide] for contributing and [communication] for getting in touch with us. [developer’s guide]: https://arrow.apache.org/datafusion/contributor-guide/index.html#developer-s-guide +[communication]: https://arrow.apache.org/datafusion/contributor-guide/communication.html + +## Crate features + +This crate has several [features] which can be specified in your `Cargo.toml`. + +[features]: https://doc.rust-lang.org/cargo/reference/features.html + +Default features: + +- `compression`: reading files compressed with `xz2`, `bzip2`, `flate2`, and `zstd` +- `crypto_expressions`: cryptographic functions such as `md5` and `sha256` +- `encoding_expressions`: `encode` and `decode` functions +- `parquet`: support for reading the [Apache Parquet] format +- `regex_expressions`: regular expression functions, such as `regexp_match` +- `unicode_expressions`: Include unicode aware functions such as `character_length` + +Optional features: + +- `avro`: support for reading the [Apache Avro] format +- `backtrace`: include backtrace information in error messages +- `pyarrow`: conversions between PyArrow and DataFusion types +- `serde`: enable arrow-schema's `serde` feature +- `simd`: enable arrow-rs's manual `SIMD` kernels (requires Rust `nightly`) + +[apache avro]: https://avro.apache.org/ +[apache parquet]: https://parquet.apache.org/ + +## Rust Version Compatibility + +This crate is tested with the latest stable version of Rust. We do not currently test against other, older versions of the Rust compiler. diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml index 1531b857bcef..4ce46968e1f4 100644 --- a/benchmarks/Cargo.toml +++ b/benchmarks/Cargo.toml @@ -18,7 +18,7 @@ [package] name = "datafusion-benchmarks" description = "DataFusion Benchmarks" -version = "31.0.0" +version = "34.0.0" edition = { workspace = true } authors = ["Apache Arrow "] homepage = "https://github.com/apache/arrow-datafusion" @@ -34,20 +34,20 @@ snmalloc = ["snmalloc-rs"] [dependencies] arrow = { workspace = true } -datafusion = { path = "../datafusion/core", version = "31.0.0" } -datafusion-common = { path = "../datafusion/common", version = "31.0.0" } -env_logger = "0.10" -futures = "0.3" -log = "^0.4" +datafusion = { path = "../datafusion/core", version = "34.0.0" } +datafusion-common = { path = "../datafusion/common", version = "34.0.0" } +env_logger = { workspace = true } +futures = { workspace = true } +log = { workspace = true } mimalloc = { version = "0.1", optional = true, default-features = false } -num_cpus = "1.13.0" -parquet = { workspace = true } +num_cpus = { workspace = true } +parquet = { workspace = true, default-features = true } serde = { version = "1.0.136", features = ["derive"] } -serde_json = "1.0.78" +serde_json = { workspace = true } snmalloc-rs = { version = "0.3", optional = true } structopt = { version = "0.3", default-features = false } test-utils = { path = "../test-utils/", version = "0.1.0" } tokio = { version = "^1.0", features = ["macros", "rt", "rt-multi-thread", "parking_lot"] } [dev-dependencies] -datafusion-proto = { path = "../datafusion/proto", version = "31.0.0" } +datafusion-proto = { path = "../datafusion/proto", version = "34.0.0" } diff --git a/benchmarks/compare.py b/benchmarks/compare.py index 80aa3c76b754..ec2b28fa0556 100755 --- a/benchmarks/compare.py +++ b/benchmarks/compare.py @@ -109,7 +109,6 @@ def compare( noise_threshold: float, ) -> None: baseline = BenchmarkRun.load_from_file(baseline_path) - comparison = BenchmarkRun.load_from_file(comparison_path) console = Console() @@ -124,27 +123,57 @@ def compare( table.add_column(comparison_header, justify="right", style="dim") table.add_column("Change", justify="right", style="dim") + faster_count = 0 + slower_count = 0 + no_change_count = 0 + total_baseline_time = 0 + total_comparison_time = 0 + for baseline_result, comparison_result in zip(baseline.queries, comparison.queries): assert baseline_result.query == comparison_result.query + total_baseline_time += baseline_result.execution_time + total_comparison_time += comparison_result.execution_time + change = comparison_result.execution_time / baseline_result.execution_time if (1.0 - noise_threshold) <= change <= (1.0 + noise_threshold): - change = "no change" + change_text = "no change" + no_change_count += 1 elif change < 1.0: - change = f"+{(1 / change):.2f}x faster" + change_text = f"+{(1 / change):.2f}x faster" + faster_count += 1 else: - change = f"{change:.2f}x slower" + change_text = f"{change:.2f}x slower" + slower_count += 1 table.add_row( f"Q{baseline_result.query}", f"{baseline_result.execution_time:.2f}ms", f"{comparison_result.execution_time:.2f}ms", - change, + change_text, ) console.print(table) + # Calculate averages + avg_baseline_time = total_baseline_time / len(baseline.queries) + avg_comparison_time = total_comparison_time / len(comparison.queries) + + # Summary table + summary_table = Table(show_header=True, header_style="bold magenta") + summary_table.add_column("Benchmark Summary", justify="left", style="dim") + summary_table.add_column("", justify="right", style="dim") + + summary_table.add_row(f"Total Time ({baseline_header})", f"{total_baseline_time:.2f}ms") + summary_table.add_row(f"Total Time ({comparison_header})", f"{total_comparison_time:.2f}ms") + summary_table.add_row(f"Average Time ({baseline_header})", f"{avg_baseline_time:.2f}ms") + summary_table.add_row(f"Average Time ({comparison_header})", f"{avg_comparison_time:.2f}ms") + summary_table.add_row("Queries Faster", str(faster_count)) + summary_table.add_row("Queries Slower", str(slower_count)) + summary_table.add_row("Queries with No Change", str(no_change_count)) + + console.print(summary_table) def main() -> None: parser = ArgumentParser() diff --git a/benchmarks/src/bin/h2o.rs b/benchmarks/src/bin/h2o.rs index d75f9a30b4e9..1bb8cb9d43e4 100644 --- a/benchmarks/src/bin/h2o.rs +++ b/benchmarks/src/bin/h2o.rs @@ -72,7 +72,7 @@ async fn group_by(opt: &GroupBy) -> Result<()> { let mut config = ConfigOptions::from_env()?; config.execution.batch_size = 65535; - let ctx = SessionContext::with_config(config.into()); + let ctx = SessionContext::new_with_config(config.into()); let schema = Schema::new(vec![ Field::new("id1", DataType::Utf8, false), diff --git a/benchmarks/src/clickbench.rs b/benchmarks/src/clickbench.rs index 98ef6dd805b0..a6d32eb39f31 100644 --- a/benchmarks/src/clickbench.rs +++ b/benchmarks/src/clickbench.rs @@ -81,7 +81,7 @@ impl RunOpt { }; let config = self.common.config(); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); self.register_hits(&ctx).await?; let iterations = self.common.iterations; diff --git a/benchmarks/src/parquet_filter.rs b/benchmarks/src/parquet_filter.rs index ceea12de9238..1d816908e2b0 100644 --- a/benchmarks/src/parquet_filter.rs +++ b/benchmarks/src/parquet_filter.rs @@ -19,8 +19,8 @@ use crate::AccessLogOpt; use crate::{BenchmarkRun, CommonOpt}; use arrow::util::pretty; use datafusion::common::Result; +use datafusion::logical_expr::utils::disjunction; use datafusion::logical_expr::{lit, or, Expr}; -use datafusion::optimizer::utils::disjunction; use datafusion::physical_plan::collect; use datafusion::prelude::{col, SessionContext}; use datafusion::test_util::parquet::{ParquetScanOptions, TestParquetFile}; @@ -144,7 +144,7 @@ impl RunOpt { )); for i in 0..self.common.iterations { let config = self.common.update_config(scan_options.config()); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); let (rows, elapsed) = exec_scan( &ctx, diff --git a/benchmarks/src/sort.rs b/benchmarks/src/sort.rs index d1baae868780..224f2b19c72e 100644 --- a/benchmarks/src/sort.rs +++ b/benchmarks/src/sort.rs @@ -148,9 +148,10 @@ impl RunOpt { println!("Executing '{title}' (sorting by: {expr:?})"); rundata.start_new_case(title); for i in 0..self.common.iterations { - let config = - SessionConfig::new().with_target_partitions(self.common.partitions); - let ctx = SessionContext::with_config(config); + let config = SessionConfig::new().with_target_partitions( + self.common.partitions.unwrap_or(num_cpus::get()), + ); + let ctx = SessionContext::new_with_config(config); let (rows, elapsed) = exec_sort(&ctx, &expr, &test_file, self.common.debug).await?; let ms = elapsed.as_secs_f64() * 1000.0; diff --git a/benchmarks/src/tpch/convert.rs b/benchmarks/src/tpch/convert.rs index f1ed081c43f5..2fc74ce38888 100644 --- a/benchmarks/src/tpch/convert.rs +++ b/benchmarks/src/tpch/convert.rs @@ -78,7 +78,7 @@ impl ConvertOpt { .file_extension(".tbl"); let config = SessionConfig::new().with_batch_size(self.batch_size); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); // build plan to read the TBL file let mut csv = ctx.read_csv(&input_path, options).await?; diff --git a/benchmarks/src/tpch/run.rs b/benchmarks/src/tpch/run.rs index cf5c7b9f67e3..5193d578fb48 100644 --- a/benchmarks/src/tpch/run.rs +++ b/benchmarks/src/tpch/run.rs @@ -110,7 +110,7 @@ impl RunOpt { .common .config() .with_collect_statistics(!self.disable_statistics); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); // register tables self.register_tables(&ctx).await?; @@ -285,7 +285,7 @@ impl RunOpt { } fn partitions(&self) -> usize { - self.common.partitions + self.common.partitions.unwrap_or(num_cpus::get()) } } @@ -325,7 +325,7 @@ mod tests { let path = get_tpch_data_path()?; let common = CommonOpt { iterations: 1, - partitions: 2, + partitions: Some(2), batch_size: 8192, debug: false, }; @@ -357,7 +357,7 @@ mod tests { let path = get_tpch_data_path()?; let common = CommonOpt { iterations: 1, - partitions: 2, + partitions: Some(2), batch_size: 8192, debug: false, }; diff --git a/benchmarks/src/util/options.rs b/benchmarks/src/util/options.rs index 1d86d10fb88c..b9398e5b522f 100644 --- a/benchmarks/src/util/options.rs +++ b/benchmarks/src/util/options.rs @@ -26,9 +26,9 @@ pub struct CommonOpt { #[structopt(short = "i", long = "iterations", default_value = "3")] pub iterations: usize, - /// Number of partitions to process in parallel - #[structopt(short = "n", long = "partitions", default_value = "2")] - pub partitions: usize, + /// Number of partitions to process in parallel. Defaults to number of available cores. + #[structopt(short = "n", long = "partitions")] + pub partitions: Option, /// Batch size when reading CSV or Parquet files #[structopt(short = "s", long = "batch-size", default_value = "8192")] @@ -48,7 +48,7 @@ impl CommonOpt { /// Modify the existing config appropriately pub fn update_config(&self, config: SessionConfig) -> SessionConfig { config - .with_target_partitions(self.partitions) + .with_target_partitions(self.partitions.unwrap_or(num_cpus::get())) .with_batch_size(self.batch_size) } } diff --git a/ci/scripts/rust_example.sh b/ci/scripts/rust_example.sh new file mode 100755 index 000000000000..fe3696f20865 --- /dev/null +++ b/ci/scripts/rust_example.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -ex +cd datafusion-examples/examples/ +cargo fmt --all -- --check + +files=$(ls .) +for filename in $files +do + example_name=`basename $filename ".rs"` + # Skip tests that rely on external storage and flight + # todo: Currently, catalog.rs is placed in the external-dependence directory because there is a problem parsing + # the parquet file of the external parquet-test that it currently relies on. + # We will wait for this issue[https://github.com/apache/arrow-datafusion/issues/8041] to be resolved. + if [ ! -d $filename ]; then + cargo run --example $example_name + fi +done diff --git a/ci/scripts/rust_toml_fmt.sh b/ci/scripts/rust_toml_fmt.sh index e297ef001594..0a8cc346a37d 100755 --- a/ci/scripts/rust_toml_fmt.sh +++ b/ci/scripts/rust_toml_fmt.sh @@ -17,5 +17,11 @@ # specific language governing permissions and limitations # under the License. +# Run cargo-tomlfmt with flag `-d` in dry run to check formatting +# without overwritng the file. If any error occur, you may want to +# rerun 'cargo tomlfmt -p path/to/Cargo.toml' without '-d' to fix +# the formatting automatically. set -ex -find . -mindepth 2 -name 'Cargo.toml' -exec cargo tomlfmt -p {} \; +for toml in $(find . -mindepth 2 -name 'Cargo.toml'); do + cargo tomlfmt -d -p $toml +done diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 0ca83452bd02..252b00ca0adc 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -17,24 +17,31 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "adler32" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aae1277d39aeec15cb388266ecc24b11c80469deae6067e17a1a7aa9e5c1f234" + [[package]] name = "ahash" -version = "0.8.3" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f" +checksum = "91429305e9f0a25f6205c5b8e0d2db09e0708a7a6df0f42212bb56c32c8ac97a" dependencies = [ "cfg-if", "const-random", "getrandom", "once_cell", "version_check", + "zerocopy", ] [[package]] name = "aho-corasick" -version = "1.1.1" +version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea5d730647d4fadd988536d06fecce94b7b4f2a7efdae548f1cf4b63205518ab" +checksum = "b2969dcb958b36655471fc61f7e416fa76033bdd4bfed0678d8fee1e2d07a1f0" dependencies = [ "memchr", ] @@ -77,9 +84,37 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.3" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7079075b41f533b8c61d2a4d073c4676e1f8b249ff94a393b0595db304e0dd87" + +[[package]] +name = "apache-avro" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b84bf0a05bbb2a83e5eb6fa36bb6e87baa08193c35ff52bbf6b38d8af2890e46" +checksum = "ceb7c683b2f8f40970b70e39ff8be514c95b96fcb9c4af87e1ed2cb2e10801a0" +dependencies = [ + "bzip2", + "crc32fast", + "digest", + "lazy_static", + "libflate", + "log", + "num-bigint", + "quad-rand", + "rand", + "regex-lite", + "serde", + "serde_json", + "snap", + "strum", + "strum_macros", + "thiserror", + "typed-builder", + "uuid", + "xz2", + "zstd 0.12.4", +] [[package]] name = "arrayref" @@ -95,9 +130,9 @@ checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" [[package]] name = "arrow" -version = "47.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fab9e93ba8ce88a37d5a30dce4b9913b75413dc1ac56cb5d72e5a840543f829" +checksum = "5bc25126d18a012146a888a0298f2c22e1150327bd2765fc76d710a556b2d614" dependencies = [ "ahash", "arrow-arith", @@ -117,9 +152,9 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "47.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc1d4e368e87ad9ee64f28b9577a3834ce10fe2703a26b28417d485bbbdff956" +checksum = "34ccd45e217ffa6e53bbb0080990e77113bdd4e91ddb84e97b77649810bcf1a7" dependencies = [ "arrow-array", "arrow-buffer", @@ -132,9 +167,9 @@ dependencies = [ [[package]] name = "arrow-array" -version = "47.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d02efa7253ede102d45a4e802a129e83bcc3f49884cab795b1ac223918e4318d" +checksum = "6bda9acea48b25123c08340f3a8ac361aa0f74469bb36f5ee9acf923fce23e9d" dependencies = [ "ahash", "arrow-buffer", @@ -143,15 +178,15 @@ dependencies = [ "chrono", "chrono-tz", "half", - "hashbrown 0.14.0", + "hashbrown 0.14.3", "num", ] [[package]] name = "arrow-buffer" -version = "47.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fda119225204141138cb0541c692fbfef0e875ba01bfdeaed09e9d354f9d6195" +checksum = "01a0fc21915b00fc6c2667b069c1b64bdd920982f426079bc4a7cab86822886c" dependencies = [ "bytes", "half", @@ -160,15 +195,16 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "47.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d825d51b9968868d50bc5af92388754056796dbc62a4e25307d588a1fc84dee" +checksum = "5dc0368ed618d509636c1e3cc20db1281148190a78f43519487b2daf07b63b4a" dependencies = [ "arrow-array", "arrow-buffer", "arrow-data", "arrow-schema", "arrow-select", + "base64", "chrono", "comfy-table", "half", @@ -178,9 +214,9 @@ dependencies = [ [[package]] name = "arrow-csv" -version = "47.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43ef855dc6b126dc197f43e061d4de46b9d4c033aa51c2587657f7508242cef1" +checksum = "2e09aa6246a1d6459b3f14baeaa49606cfdbca34435c46320e14054d244987ca" dependencies = [ "arrow-array", "arrow-buffer", @@ -197,9 +233,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "47.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "475a4c3699c8b4095ca61cecf15da6f67841847a5f5aac983ccb9a377d02f73a" +checksum = "907fafe280a3874474678c1858b9ca4cb7fd83fb8034ff5b6d6376205a08c634" dependencies = [ "arrow-buffer", "arrow-schema", @@ -209,9 +245,9 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "47.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1248005c8ac549f869b7a840859d942bf62471479c1a2d82659d453eebcd166a" +checksum = "79a43d6808411886b8c7d4f6f7dd477029c1e77ffffffb7923555cc6579639cd" dependencies = [ "arrow-array", "arrow-buffer", @@ -219,13 +255,14 @@ dependencies = [ "arrow-data", "arrow-schema", "flatbuffers", + "lz4_flex", ] [[package]] name = "arrow-json" -version = "47.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f03d7e3b04dd688ccec354fe449aed56b831679f03e44ee2c1cfc4045067b69c" +checksum = "d82565c91fd627922ebfe2810ee4e8346841b6f9361b87505a9acea38b614fee" dependencies = [ "arrow-array", "arrow-buffer", @@ -234,7 +271,7 @@ dependencies = [ "arrow-schema", "chrono", "half", - "indexmap 2.0.0", + "indexmap 2.1.0", "lexical-core", "num", "serde", @@ -243,9 +280,9 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "47.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03b87aa408ea6a6300e49eb2eba0c032c88ed9dc19e0a9948489c55efdca71f4" +checksum = "9b23b0e53c0db57c6749997fd343d4c0354c994be7eca67152dd2bdb9a3e1bb4" dependencies = [ "arrow-array", "arrow-buffer", @@ -258,9 +295,9 @@ dependencies = [ [[package]] name = "arrow-row" -version = "47.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "114a348ab581e7c9b6908fcab23cb39ff9f060eb19e72b13f8fb8eaa37f65d22" +checksum = "361249898d2d6d4a6eeb7484be6ac74977e48da12a4dd81a708d620cc558117a" dependencies = [ "ahash", "arrow-array", @@ -268,20 +305,20 @@ dependencies = [ "arrow-data", "arrow-schema", "half", - "hashbrown 0.14.0", + "hashbrown 0.14.3", ] [[package]] name = "arrow-schema" -version = "47.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d1d179c117b158853e0101bfbed5615e86fe97ee356b4af901f1c5001e1ce4b" +checksum = "09e28a5e781bf1b0f981333684ad13f5901f4cd2f20589eab7cf1797da8fc167" [[package]] name = "arrow-select" -version = "47.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5c71e003202e67e9db139e5278c79f5520bb79922261dfe140e4637ee8b6108" +checksum = "4f6208466590960efc1d2a7172bc4ff18a67d6e25c529381d7f96ddaf0dc4036" dependencies = [ "ahash", "arrow-array", @@ -293,9 +330,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "47.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4cebbb282d6b9244895f4a9a912e55e57bce112554c7fa91fcec5459cb421ab" +checksum = "a4a48149c63c11c9ff571e50ab8f017d2a7cb71037a882b42f6354ed2da9acc7" dependencies = [ "arrow-array", "arrow-buffer", @@ -324,9 +361,9 @@ dependencies = [ [[package]] name = "async-compression" -version = "0.4.3" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb42b2197bf15ccb092b62c74515dbd8b86d0effd934795f6687c93b6e679a2c" +checksum = "bc2d0cfb2a7388d34f590e76686704c494ed7aaceed62ee1ba35cbf363abc2a5" dependencies = [ "bzip2", "flate2", @@ -336,19 +373,19 @@ dependencies = [ "pin-project-lite", "tokio", "xz2", - "zstd", - "zstd-safe", + "zstd 0.13.0", + "zstd-safe 7.0.0", ] [[package]] name = "async-trait" -version = "0.1.73" +version = "0.1.75" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc00ceb34980c03614e35a3a4e218276a0a824e911d07651cd0d858a51e8c0f0" +checksum = "fdf6721fb0140e4f897002dd086c06f6c27775df19cfe1fccb21181a48fd2c98" dependencies = [ "proc-macro2", "quote", - "syn 2.0.37", + "syn 2.0.43", ] [[package]] @@ -390,7 +427,7 @@ dependencies = [ "hex", "http", "hyper", - "ring", + "ring 0.16.20", "time", "tokio", "tower", @@ -675,9 +712,9 @@ dependencies = [ [[package]] name = "base64" -version = "0.21.4" +version = "0.21.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ba43ea6f343b788c8764558649e08df62f86c6ef251fdaeb1ffd010a9ae50a2" +checksum = "35636a1494ede3b646cc98f74f8e62c773a38a659ebc777a2cf26b9b74171df9" [[package]] name = "base64-simd" @@ -697,9 +734,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.4.0" +version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4682ae6287fcf752ecaabbfcc7b6f9b72aa33933dc23a554d853aea8eea8635" +checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" [[package]] name = "blake2" @@ -734,9 +771,9 @@ dependencies = [ [[package]] name = "brotli" -version = "3.3.4" +version = "3.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1a0b1dbcc8ae29329621f8d4f0d835787c1c38bb1401979b49d13b0b305ff68" +checksum = "516074a47ef4bce09577a3b379392300159ce5b1ba2e501ff1c819950066100f" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -745,9 +782,9 @@ dependencies = [ [[package]] name = "brotli-decompressor" -version = "2.3.4" +version = "2.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b6561fd3f895a11e8f72af2cb7d22e08366bebc2b6b57f7744c4bda27034744" +checksum = "4e2e4afe60d7dd600fdd3de8d0f08c2b7ec039712e3b6137ff98b7004e82de4f" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -755,9 +792,9 @@ dependencies = [ [[package]] name = "bstr" -version = "1.6.2" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c2f7349907b712260e64b0afe2f84692af14a454be26187d9df565c7f69266a" +checksum = "542f33a8835a0884b006a0c3df3dadd99c0c3f296ed26c2fdc8028e01ad6230c" dependencies = [ "memchr", "regex-automata", @@ -772,9 +809,9 @@ checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec" [[package]] name = "byteorder" -version = "1.4.3" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" @@ -784,9 +821,9 @@ checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" [[package]] name = "bytes-utils" -version = "0.1.3" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e47d3a8076e283f3acd27400535992edb3ba4b5bb72f8891ad8fbe7932a7d4b9" +checksum = "7dafe3a8757b027e2be6e4e5601ed563c55989fcf1546e933c66c8eb3a058d35" dependencies = [ "bytes", "either", @@ -839,14 +876,14 @@ dependencies = [ "iana-time-zone", "num-traits", "serde", - "windows-targets", + "windows-targets 0.48.5", ] [[package]] name = "chrono-tz" -version = "0.8.3" +version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1369bc6b9e9a7dfdae2055f6ec151fe9c554a9d23d357c0237cee2e25eaabb7" +checksum = "e23185c0e21df6ed832a12e2bda87c7d1def6842881fb634a8511ced741b0d76" dependencies = [ "chrono", "chrono-tz-build", @@ -855,9 +892,9 @@ dependencies = [ [[package]] name = "chrono-tz-build" -version = "0.2.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2f5ebdc942f57ed96d560a6d1a459bae5851102a25d5bf89dc04ae453e31ecf" +checksum = "433e39f13c9a060046954e0592a8d0a4bcb1040125cbf91cb8ee58964cfb350f" dependencies = [ "parse-zoneinfo", "phf", @@ -916,34 +953,32 @@ dependencies = [ [[package]] name = "comfy-table" -version = "7.0.1" +version = "7.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ab77dbd8adecaf3f0db40581631b995f312a8a5ae3aa9993188bb8f23d83a5b" +checksum = "7c64043d6c7b7a4c58e39e7efccfdea7b93d885a795d0c054a69dbbf4dd52686" dependencies = [ - "strum 0.24.1", - "strum_macros 0.24.3", + "strum", + "strum_macros", "unicode-width", ] [[package]] name = "const-random" -version = "0.1.15" +version = "0.1.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "368a7a772ead6ce7e1de82bfb04c485f3db8ec744f72925af5735e29a22cc18e" +checksum = "5aaf16c9c2c612020bcfd042e170f6e32de9b9d75adb5277cdbbd2e2c8c8299a" dependencies = [ "const-random-macro", - "proc-macro-hack", ] [[package]] name = "const-random-macro" -version = "0.1.15" +version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d7d6ab3c3a2282db210df5f02c4dab6e0a7057af0fb7ebd4070f30fe05c0ddb" +checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" dependencies = [ "getrandom", "once_cell", - "proc-macro-hack", "tiny-keccak", ] @@ -955,9 +990,9 @@ checksum = "f7144d30dcf0fafbce74250a3963025d8d52177934239851c917d29f1df280c2" [[package]] name = "core-foundation" -version = "0.9.3" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "194a7a9e6de53fa55116934067c844d9d749312f75c6f6d0980e8c252f8c2146" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" dependencies = [ "core-foundation-sys", "libc", @@ -965,15 +1000,24 @@ dependencies = [ [[package]] name = "core-foundation-sys" -version = "0.8.4" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" + +[[package]] +name = "core2" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" +checksum = "b49ba7ef1ad6107f8824dbe97de947cbaac53c44e7f9756a1fba0d37c1eec505" +dependencies = [ + "memchr", +] [[package]] name = "cpufeatures" -version = "0.2.9" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a17b76ff3a4162b0b27f354a0c87015ddad39d35f9c0c36607a3bdd175dde1f1" +checksum = "ce420fe07aecd3e67c5f910618fe65e94158f6dcc0adf44e00d69ce2bdfe0fd0" dependencies = [ "libc", ] @@ -1005,9 +1049,9 @@ dependencies = [ [[package]] name = "csv" -version = "1.2.2" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "626ae34994d3d8d668f4269922248239db4ae42d538b14c398b74a52208e8086" +checksum = "ac574ff4d437a7b5ad237ef331c17ccca63c46479e5b5453eb8e10bb99a759fe" dependencies = [ "csv-core", "itoa", @@ -1017,23 +1061,29 @@ dependencies = [ [[package]] name = "csv-core" -version = "0.1.10" +version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b2466559f260f48ad25fe6317b3c8dac77b5bdb5763ac7d9d6103530663bc90" +checksum = "5efa2b3d7902f4b634a20cae3c9c4e6209dc4779feb6863329607560143efa70" dependencies = [ "memchr", ] [[package]] name = "ctor" -version = "0.2.4" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f34ba9a9bcb8645379e9de8cb3ecfcf4d1c85ba66d90deb3259206fa5aa193b" +checksum = "30d2b3721e861707777e3195b0158f950ae6dc4a27e4d02ff9f67e3eb3de199e" dependencies = [ "quote", - "syn 2.0.37", + "syn 2.0.43", ] +[[package]] +name = "dary_heap" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7762d17f1241643615821a8455a0b2c3e803784b058693d990b11f2dce25a0ca" + [[package]] name = "dashmap" version = "5.5.3" @@ -1041,7 +1091,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" dependencies = [ "cfg-if", - "hashbrown 0.14.0", + "hashbrown 0.14.3", "lock_api", "once_cell", "parking_lot_core", @@ -1049,11 +1099,13 @@ dependencies = [ [[package]] name = "datafusion" -version = "31.0.0" +version = "34.0.0" dependencies = [ "ahash", + "apache-avro", "arrow", "arrow-array", + "arrow-ipc", "arrow-schema", "async-compression", "async-trait", @@ -1072,15 +1124,15 @@ dependencies = [ "futures", "glob", "half", - "hashbrown 0.14.0", - "indexmap 2.0.0", - "itertools 0.11.0", + "hashbrown 0.14.3", + "indexmap 2.1.0", + "itertools 0.12.0", "log", + "num-traits", "num_cpus", "object_store", "parking_lot", "parquet", - "percent-encoding", "pin-project-lite", "rand", "sqlparser", @@ -1090,12 +1142,12 @@ dependencies = [ "url", "uuid", "xz2", - "zstd", + "zstd 0.13.0", ] [[package]] name = "datafusion-cli" -version = "31.0.0" +version = "34.0.0" dependencies = [ "arrow", "assert_cmd", @@ -1105,11 +1157,14 @@ dependencies = [ "clap", "ctor", "datafusion", + "datafusion-common", "dirs", "env_logger", + "futures", "mimalloc", "object_store", "parking_lot", + "parquet", "predicates", "regex", "rstest", @@ -1120,11 +1175,17 @@ dependencies = [ [[package]] name = "datafusion-common" -version = "31.0.0" +version = "34.0.0" dependencies = [ + "ahash", + "apache-avro", "arrow", "arrow-array", + "arrow-buffer", + "arrow-schema", "chrono", + "half", + "libc", "num_cpus", "object_store", "parquet", @@ -1133,7 +1194,7 @@ dependencies = [ [[package]] name = "datafusion-execution" -version = "31.0.0" +version = "34.0.0" dependencies = [ "arrow", "chrono", @@ -1141,7 +1202,7 @@ dependencies = [ "datafusion-common", "datafusion-expr", "futures", - "hashbrown 0.14.0", + "hashbrown 0.14.3", "log", "object_store", "parking_lot", @@ -1152,20 +1213,21 @@ dependencies = [ [[package]] name = "datafusion-expr" -version = "31.0.0" +version = "34.0.0" dependencies = [ "ahash", "arrow", "arrow-array", "datafusion-common", + "paste", "sqlparser", - "strum 0.25.0", - "strum_macros 0.25.2", + "strum", + "strum_macros", ] [[package]] name = "datafusion-optimizer" -version = "31.0.0" +version = "34.0.0" dependencies = [ "arrow", "async-trait", @@ -1173,20 +1235,21 @@ dependencies = [ "datafusion-common", "datafusion-expr", "datafusion-physical-expr", - "hashbrown 0.14.0", - "itertools 0.11.0", + "hashbrown 0.14.3", + "itertools 0.12.0", "log", "regex-syntax", ] [[package]] name = "datafusion-physical-expr" -version = "31.0.0" +version = "34.0.0" dependencies = [ "ahash", "arrow", "arrow-array", "arrow-buffer", + "arrow-ord", "arrow-schema", "base64", "blake2", @@ -1195,11 +1258,10 @@ dependencies = [ "datafusion-common", "datafusion-expr", "half", - "hashbrown 0.14.0", + "hashbrown 0.14.3", "hex", - "indexmap 2.0.0", - "itertools 0.11.0", - "libc", + "indexmap 2.1.0", + "itertools 0.12.0", "log", "md-5", "paste", @@ -1213,7 +1275,7 @@ dependencies = [ [[package]] name = "datafusion-physical-plan" -version = "31.0.0" +version = "34.0.0" dependencies = [ "ahash", "arrow", @@ -1228,9 +1290,9 @@ dependencies = [ "datafusion-physical-expr", "futures", "half", - "hashbrown 0.14.0", - "indexmap 2.0.0", - "itertools 0.11.0", + "hashbrown 0.14.3", + "indexmap 2.1.0", + "itertools 0.12.0", "log", "once_cell", "parking_lot", @@ -1242,7 +1304,7 @@ dependencies = [ [[package]] name = "datafusion-sql" -version = "31.0.0" +version = "34.0.0" dependencies = [ "arrow", "arrow-schema", @@ -1254,9 +1316,12 @@ dependencies = [ [[package]] name = "deranged" -version = "0.3.8" +version = "0.3.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2696e8a945f658fd14dc3b87242e6b80cd0f36ff04ea560fa39082368847946" +checksum = "8eb30d70a07a3b04884d2677f06bec33509dc67ca60d92949e5535352d3191dc" +dependencies = [ + "powerfmt", +] [[package]] name = "difflib" @@ -1364,23 +1429,12 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] name = "errno" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "136526188508e25c6fef639d7927dfb3e0e3084488bf202267829cf7fc23dbdd" -dependencies = [ - "errno-dragonfly", - "libc", - "windows-sys", -] - -[[package]] -name = "errno-dragonfly" -version = "0.1.2" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa68f1b12764fab894d2755d2518754e71b4fd80ecfb822714a1206c2aab39bf" +checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245" dependencies = [ - "cc", "libc", + "windows-sys 0.52.0", ] [[package]] @@ -1404,9 +1458,9 @@ dependencies = [ [[package]] name = "fastrand" -version = "2.0.0" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6999dc1837253364c2ebb0704ba97994bd874e8f195d665c50b7548f6ea92764" +checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5" [[package]] name = "fd-lock" @@ -1416,7 +1470,7 @@ checksum = "ef033ed5e9bad94e55838ca0ca906db0e043f517adda0c8b79c7a8c66c93c1b5" dependencies = [ "cfg-if", "rustix", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -1437,9 +1491,9 @@ dependencies = [ [[package]] name = "flate2" -version = "1.0.27" +version = "1.0.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6c98ee8095e9d1dcbf2fcc6d95acccb90d1c81db1e44725c6a984b1dbdfb010" +checksum = "46303f565772937ffe1d394a4fac6f411c6013172fadde9dcdb1e147a086940e" dependencies = [ "crc32fast", "miniz_oxide", @@ -1462,18 +1516,18 @@ checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" [[package]] name = "form_urlencoded" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a62bc1cf6f830c2ec14a513a9fb124d0a213a629668a4186f329db21fe045652" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" dependencies = [ "percent-encoding", ] [[package]] name = "futures" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23342abe12aba583913b2e62f22225ff9c950774065e4bfb61a19cd9770fec40" +checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" dependencies = [ "futures-channel", "futures-core", @@ -1486,9 +1540,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2" +checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" dependencies = [ "futures-core", "futures-sink", @@ -1496,15 +1550,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" +checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" [[package]] name = "futures-executor" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ccecee823288125bd88b4d7f565c9e58e41858e47ab72e8ea2d64e93624386e0" +checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" dependencies = [ "futures-core", "futures-task", @@ -1513,32 +1567,32 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964" +checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" [[package]] name = "futures-macro" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" +checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.37", + "syn 2.0.43", ] [[package]] name = "futures-sink" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e" +checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" [[package]] name = "futures-task" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65" +checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" [[package]] name = "futures-timer" @@ -1548,9 +1602,9 @@ checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c" [[package]] name = "futures-util" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533" +checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" dependencies = [ "futures-channel", "futures-core", @@ -1576,9 +1630,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.10" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" +checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f" dependencies = [ "cfg-if", "libc", @@ -1587,9 +1641,9 @@ dependencies = [ [[package]] name = "gimli" -version = "0.28.0" +version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0" +checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" [[package]] name = "glob" @@ -1599,9 +1653,9 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" [[package]] name = "h2" -version = "0.3.21" +version = "0.3.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91fc23aa11be92976ef4729127f1a74adf36d8436f7816b185d18df956790833" +checksum = "4d6250322ef6e60f93f9a2162799302cd6f68f79f6e5d85c8c16f14d1d958178" dependencies = [ "bytes", "fnv", @@ -1609,7 +1663,7 @@ dependencies = [ "futures-sink", "futures-util", "http", - "indexmap 1.9.3", + "indexmap 2.1.0", "slab", "tokio", "tokio-util", @@ -1635,9 +1689,18 @@ checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" [[package]] name = "hashbrown" -version = "0.14.0" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e" +dependencies = [ + "ahash", +] + +[[package]] +name = "hashbrown" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" +checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" dependencies = [ "ahash", "allocator-api2", @@ -1681,9 +1744,9 @@ dependencies = [ [[package]] name = "http" -version = "0.2.9" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd6effc99afb63425aff9b05836f029929e345a6148a14b7ecd5ab67af944482" +checksum = "8947b1a6fad4393052c7ba1f4cd97bed3e953a95c79c92ad9b051a04611d9fbb" dependencies = [ "bytes", "fnv", @@ -1692,9 +1755,9 @@ dependencies = [ [[package]] name = "http-body" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5f38f16d184e36f2408a55281cd658ecbd3ca05cce6d6510a176eca393e26d1" +checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" dependencies = [ "bytes", "http", @@ -1721,9 +1784,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "hyper" -version = "0.14.27" +version = "0.14.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffb1cfd654a8219eaef89881fdb3bb3b1cdc5fa75ded05d6933b2b382e395468" +checksum = "bf96e135eb83a2a8ddf766e426a841d8ddd7449d5f00d34ea02b41d2f19eef80" dependencies = [ "bytes", "futures-channel", @@ -1736,7 +1799,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", - "socket2 0.4.9", + "socket2", "tokio", "tower-service", "tracing", @@ -1760,30 +1823,30 @@ dependencies = [ [[package]] name = "hyper-rustls" -version = "0.24.1" +version = "0.24.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d78e1e73ec14cf7375674f74d7dde185c8206fd9dea6fb6295e8a98098aaa97" +checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" dependencies = [ "futures-util", "http", "hyper", - "rustls 0.21.7", + "rustls 0.21.10", "tokio", "tokio-rustls 0.24.1", ] [[package]] name = "iana-time-zone" -version = "0.1.57" +version = "0.1.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2fad5b825842d2b38bd206f3e81d6957625fd7f0a361e345c30e01a0ae2dd613" +checksum = "8326b86b6cff230b97d0d312a6c40a60726df3332e721f72a1b035f451663b20" dependencies = [ "android_system_properties", "core-foundation-sys", "iana-time-zone-haiku", "js-sys", "wasm-bindgen", - "windows", + "windows-core", ] [[package]] @@ -1797,9 +1860,9 @@ dependencies = [ [[package]] name = "idna" -version = "0.4.0" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d20d6b07bfbc108882d88ed8e37d39636dcc260e15e30c45e6ba089610b917c" +checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" dependencies = [ "unicode-bidi", "unicode-normalization", @@ -1817,12 +1880,12 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.0.0" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5477fe2230a79769d8dc68e0eabf5437907c0457a5614a9e8dddb67f65eb65d" +checksum = "d530e1a18b1cb4c484e6e34556a0d948706958449fca0cab753d649f2bce3d1f" dependencies = [ "equivalent", - "hashbrown 0.14.0", + "hashbrown 0.14.3", ] [[package]] @@ -1842,48 +1905,48 @@ checksum = "8bb03732005da905c88227371639bf1ad885cc712789c011c31c5fb3ab3ccf02" [[package]] name = "ipnet" -version = "2.8.0" +version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28b29a3cd74f0f4598934efe3aeba42bae0eb4680554128851ebbecb02af14e6" +checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" [[package]] name = "itertools" -version = "0.10.5" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" dependencies = [ "either", ] [[package]] name = "itertools" -version = "0.11.0" +version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" +checksum = "25db6b064527c5d482d0423354fcd07a89a2dfe07b67892e62411946db7f07b0" dependencies = [ "either", ] [[package]] name = "itoa" -version = "1.0.9" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" +checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" [[package]] name = "jobserver" -version = "0.1.26" +version = "0.1.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "936cfd212a0155903bcbc060e316fb6cc7cbf2e1907329391ebadc1fe0ce77c2" +checksum = "8c37f63953c4c63420ed5fd3d6d398c719489b9f872b9fa683262f8edd363c7d" dependencies = [ "libc", ] [[package]] name = "js-sys" -version = "0.3.64" +version = "0.3.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5f195fe497f702db0f318b07fdd68edb16955aed830df8363d837542f8f935a" +checksum = "cee9c64da59eae3b50095c18d3e74f8b73c0b86d2792824ff01bbce68ba229ca" dependencies = [ "wasm-bindgen", ] @@ -1960,15 +2023,39 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.148" +version = "0.2.151" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "302d7ab3130588088d277783b1e2d2e10c9e9e4a16dd9050e6ec93fb3e7048f4" + +[[package]] +name = "libflate" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f7d5654ae1795afc7ff76f4365c2c8791b0feb18e8996a96adad8ffd7c3b2bf" +dependencies = [ + "adler32", + "core2", + "crc32fast", + "dary_heap", + "libflate_lz77", +] + +[[package]] +name = "libflate_lz77" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cdc71e17332e86d2e1d38c1f99edcb6288ee11b815fb1a4b049eaa2114d369b" +checksum = "be5f52fb8c451576ec6b79d3f4deb327398bc05bbdbd99021a6e77a4c855d524" +dependencies = [ + "core2", + "hashbrown 0.13.2", + "rle-decode-fast", +] [[package]] name = "libm" -version = "0.2.7" +version = "0.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7012b1bbb0719e1097c47611d3898568c546d597c2e74d66f6087edd5233ff4" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" [[package]] name = "libmimalloc-sys" @@ -1980,17 +2067,28 @@ dependencies = [ "libc", ] +[[package]] +name = "libredox" +version = "0.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85c833ca1e66078851dba29046874e38f08b2c883700aa29a03ddd3b23814ee8" +dependencies = [ + "bitflags 2.4.1", + "libc", + "redox_syscall", +] + [[package]] name = "linux-raw-sys" -version = "0.4.7" +version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a9bad9f94746442c783ca431b22403b519cd7fbeed0533fdd6328b2f2212128" +checksum = "c4cd1a83af159aa67994778be9070f0ae1bd732942279cabb14f86f986a21456" [[package]] name = "lock_api" -version = "0.4.10" +version = "0.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1cc9717a20b1bb222f333e6a92fd32f7d8a18ddc5a3191a11af45dcbf4dcd16" +checksum = "3c168f8615b12bc01f9c17e2eb0cc07dcae1940121185446edc3744920e8ef45" dependencies = [ "autocfg", "scopeguard", @@ -2003,23 +2101,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" [[package]] -name = "lz4" -version = "1.24.0" +name = "lz4_flex" +version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e9e2dd86df36ce760a60f6ff6ad526f7ba1f14ba0356f8254fb6905e6494df1" +checksum = "3ea9b256699eda7b0387ffbc776dd625e28bde3918446381781245b7a50349d8" dependencies = [ - "libc", - "lz4-sys", -] - -[[package]] -name = "lz4-sys" -version = "1.9.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57d27b317e207b10f69f5e75494119e391a96f48861ae870d1da6edac98ca900" -dependencies = [ - "cc", - "libc", + "twox-hash", ] [[package]] @@ -2035,18 +2122,19 @@ dependencies = [ [[package]] name = "md-5" -version = "0.10.5" +version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6365506850d44bff6e2fbcb5176cf63650e48bd45ef2fe2665ae1570e0f4b9ca" +checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf" dependencies = [ + "cfg-if", "digest", ] [[package]] name = "memchr" -version = "2.6.3" +version = "2.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f232d6ef707e1956a43342693d2a31e72989554d58299d7a88738cc95b0d35c" +checksum = "f665ee40bc4a3c5590afb1e9677db74a508659dfd71e126420da8274909a0167" [[package]] name = "mimalloc" @@ -2074,13 +2162,13 @@ dependencies = [ [[package]] name = "mio" -version = "0.8.8" +version = "0.8.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "927a765cd3fc26206e66b296465fa9d3e5ab003e651c1b3c060e7956d96b19d2" +checksum = "8f3d0b296e374a4e6f3c7b0a1f5a51d748a0d34c85e7dc48fc3fa9a87657fe09" dependencies = [ "libc", "wasi", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -2178,9 +2266,9 @@ dependencies = [ [[package]] name = "num-traits" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f30b0abd723be7e2ffca1272140fac1a2f084c77ec3e123c192b66af1ee9e6c2" +checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" dependencies = [ "autocfg", "libm", @@ -2198,18 +2286,18 @@ dependencies = [ [[package]] name = "object" -version = "0.32.1" +version = "0.32.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cf5f9dd3933bd50a9e1f149ec995f39ae2c496d31fd772c1fd45ebc27e902b0" +checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441" dependencies = [ "memchr", ] [[package]] name = "object_store" -version = "0.7.0" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d359e231e5451f4f9fa889d56e3ce34f8724f1a61db2107739359717cf2bbf08" +checksum = "2524735495ea1268be33d200e1ee97455096a0846295a21548cd2f3541de7050" dependencies = [ "async-trait", "base64", @@ -2218,13 +2306,13 @@ dependencies = [ "futures", "humantime", "hyper", - "itertools 0.10.5", + "itertools 0.11.0", "parking_lot", "percent-encoding", "quick-xml", "rand", "reqwest", - "ring", + "ring 0.17.7", "rustls-pemfile", "serde", "serde_json", @@ -2237,9 +2325,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.18.0" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] name = "openssl-probe" @@ -2249,18 +2337,18 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "ordered-float" -version = "2.10.0" +version = "2.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7940cf2ca942593318d07fcf2596cdca60a85c9e7fab408a5e21a4f9dcd40d87" +checksum = "68f19d67e5a2795c94e73e0bb1cc1a7edeb2e28efd39e2e1c9b7a40c1108b11c" dependencies = [ "num-traits", ] [[package]] name = "os_str_bytes" -version = "6.5.1" +version = "6.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d5d9eb14b174ee9aa2ef96dc2b94637a2d4b6e7cb873c7e171f0c20c6cf3eac" +checksum = "e2355d85b9a3786f481747ced0e0ff2ba35213a1f9bd406ed906554d7af805a1" [[package]] name = "outref" @@ -2280,22 +2368,22 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.8" +version = "0.9.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93f00c865fe7cabf650081affecd3871070f26767e7b2070a3ffae14c654b447" +checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e" dependencies = [ "cfg-if", "libc", - "redox_syscall 0.3.5", + "redox_syscall", "smallvec", - "windows-targets", + "windows-targets 0.48.5", ] [[package]] name = "parquet" -version = "47.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0463cc3b256d5f50408c49a4be3a16674f4c8ceef60941709620a062b1f6bf4d" +checksum = "af88740a842787da39b3d69ce5fbf6fce97d20211d3b299fee0a0da6430c74d4" dependencies = [ "ahash", "arrow-array", @@ -2311,8 +2399,8 @@ dependencies = [ "chrono", "flate2", "futures", - "hashbrown 0.14.0", - "lz4", + "hashbrown 0.14.3", + "lz4_flex", "num", "num-bigint", "object_store", @@ -2322,7 +2410,7 @@ dependencies = [ "thrift", "tokio", "twox-hash", - "zstd", + "zstd 0.13.0", ] [[package]] @@ -2342,9 +2430,9 @@ checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" [[package]] name = "percent-encoding" -version = "2.3.0" +version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "petgraph" @@ -2353,7 +2441,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e1d3afd2628e69da2be385eb6f2fd57c8ac7977ceeff6dc166ff1657b0e386a9" dependencies = [ "fixedbitset", - "indexmap 2.0.0", + "indexmap 2.1.0", ] [[package]] @@ -2411,7 +2499,7 @@ checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405" dependencies = [ "proc-macro2", "quote", - "syn 2.0.37", + "syn 2.0.43", ] [[package]] @@ -2428,9 +2516,15 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "pkg-config" -version = "0.3.27" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69d3587f8a9e599cc7ec2c00e331f71c4e69a5f9a4b8a6efd5b07466b9736f9a" + +[[package]] +name = "powerfmt" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" [[package]] name = "ppv-lite86" @@ -2493,26 +2587,26 @@ dependencies = [ "version_check", ] -[[package]] -name = "proc-macro-hack" -version = "0.5.20+deprecated" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc375e1527247fe1a97d8b7156678dfe7c1af2fc075c9a4db3690ecd2a148068" - [[package]] name = "proc-macro2" -version = "1.0.67" +version = "1.0.71" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d433d9f1a3e8c1263d9456598b16fec66f4acc9a74dacffd35c7bb09b3a1328" +checksum = "75cb1540fadbd5b8fbccc4dddad2734eba435053f725621c070711a14bb5f4b8" dependencies = [ "unicode-ident", ] +[[package]] +name = "quad-rand" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "658fa1faf7a4cc5f057c9ee5ef560f717ad9d8dc66d975267f709624d6e1ab88" + [[package]] name = "quick-xml" -version = "0.28.2" +version = "0.31.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ce5e73202a820a31f8a0ee32ada5e21029c81fd9e3ebf668a40832e4219d9d1" +checksum = "1004a344b30a54e2ee58d66a71b32d2db2feb0a31f9a2d302bf0536f15de2a33" dependencies = [ "memchr", "serde", @@ -2569,38 +2663,29 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.2.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a" -dependencies = [ - "bitflags 1.3.2", -] - -[[package]] -name = "redox_syscall" -version = "0.3.5" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" +checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" dependencies = [ "bitflags 1.3.2", ] [[package]] name = "redox_users" -version = "0.4.3" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b" +checksum = "a18479200779601e498ada4e8c1e1f50e3ee19deb0259c25825a98b5603b2cb4" dependencies = [ "getrandom", - "redox_syscall 0.2.16", + "libredox", "thiserror", ] [[package]] name = "regex" -version = "1.9.5" +version = "1.10.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "697061221ea1b4a94a624f67d0ae2bfe4e22b8a17b6a192afb11046542cc8c47" +checksum = "380b951a9c5e80ddfd6136919eef32310721aa4aacd4889a8d39124b026ab343" dependencies = [ "aho-corasick", "memchr", @@ -2610,26 +2695,32 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.3.8" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2f401f4955220693b56f8ec66ee9c78abffd8d1c4f23dc41a23839eb88f0795" +checksum = "5f804c7828047e88b2d32e2d7fe5a105da8ee3264f01902f796c8e067dc2483f" dependencies = [ "aho-corasick", "memchr", "regex-syntax", ] +[[package]] +name = "regex-lite" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30b661b2f27137bdbc16f00eda72866a92bb28af1753ffbd56744fb6e2e9cd8e" + [[package]] name = "regex-syntax" -version = "0.7.5" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" +checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" [[package]] name = "reqwest" -version = "0.11.20" +version = "0.11.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e9ad3fe7488d7e34558a2033d45a0c90b72d97b4f80705666fea71472e2e6a1" +checksum = "37b1ae8d9ac08420c66222fb9096fc5de435c3c48542bc5336c51892cffafb41" dependencies = [ "base64", "bytes", @@ -2640,7 +2731,7 @@ dependencies = [ "http", "http-body", "hyper", - "hyper-rustls 0.24.1", + "hyper-rustls 0.24.2", "ipnet", "js-sys", "log", @@ -2648,11 +2739,12 @@ dependencies = [ "once_cell", "percent-encoding", "pin-project-lite", - "rustls 0.21.7", + "rustls 0.21.10", "rustls-pemfile", "serde", "serde_json", "serde_urlencoded", + "system-configuration", "tokio", "tokio-rustls 0.24.1", "tokio-util", @@ -2675,12 +2767,32 @@ dependencies = [ "cc", "libc", "once_cell", - "spin", - "untrusted", + "spin 0.5.2", + "untrusted 0.7.1", "web-sys", "winapi", ] +[[package]] +name = "ring" +version = "0.17.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "688c63d65483050968b2a8937f7995f443e27041a0f7700aa59b0822aedebb74" +dependencies = [ + "cc", + "getrandom", + "libc", + "spin 0.9.8", + "untrusted 0.9.0", + "windows-sys 0.48.0", +] + +[[package]] +name = "rle-decode-fast" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3582f63211428f83597b51b2ddb88e2a91a9d52d12831f9d08f5e624e8977422" + [[package]] name = "rstest" version = "0.17.0" @@ -2724,15 +2836,15 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.14" +version = "0.38.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "747c788e9ce8e92b12cd485c49ddf90723550b654b32508f979b71a7b1ecda4f" +checksum = "72e572a5e8ca657d7366229cdde4bd14c4eb5499a9573d4d366fe1b599daa316" dependencies = [ - "bitflags 2.4.0", + "bitflags 2.4.1", "errno", "libc", "linux-raw-sys", - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -2742,19 +2854,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b80e3dec595989ea8510028f30c408a4630db12c9cbb8de34203b89d6577e99" dependencies = [ "log", - "ring", + "ring 0.16.20", "sct", "webpki", ] [[package]] name = "rustls" -version = "0.21.7" +version = "0.21.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd8d6c9f025a446bc4d18ad9632e69aec8f287aa84499ee335599fabd20c3fd8" +checksum = "f9d5a6813c0759e4609cd494e8e725babae6a2ca7b62a5536a13daaec6fcb7ba" dependencies = [ "log", - "ring", + "ring 0.17.7", "rustls-webpki", "sct", ] @@ -2773,21 +2885,21 @@ dependencies = [ [[package]] name = "rustls-pemfile" -version = "1.0.3" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d3987094b1d07b653b7dfdc3f70ce9a1da9c51ac18c1b06b662e4f9a0e9f4b2" +checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" dependencies = [ "base64", ] [[package]] name = "rustls-webpki" -version = "0.101.5" +version = "0.101.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45a27e3b59326c16e23d30aeb7a36a24cc0d29e71d68ff611cdfb4a01d013bed" +checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" dependencies = [ - "ring", - "untrusted", + "ring 0.17.7", + "untrusted 0.9.0", ] [[package]] @@ -2821,9 +2933,9 @@ dependencies = [ [[package]] name = "ryu" -version = "1.0.15" +version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" +checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c" [[package]] name = "same-file" @@ -2840,7 +2952,7 @@ version = "0.1.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c3733bf4cf7ea0880754e19cb5a462007c4a8c1914bff372ccc95b464f1df88" dependencies = [ - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -2851,12 +2963,12 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "sct" -version = "0.7.0" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d53dcdb7c9f8158937a7981b48accfd39a43af418591a5d008c7b22b5e1b7ca4" +checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" dependencies = [ - "ring", - "untrusted", + "ring 0.17.7", + "untrusted 0.9.0", ] [[package]] @@ -2884,9 +2996,9 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.18" +version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0293b4b29daaf487284529cc2f5675b8e57c61f70167ba415a463651fd6a918" +checksum = "836fa6a3e1e547f9a2c4040802ec865b5d85f4014efe00555d7090a3dcaa1090" [[package]] name = "seq-macro" @@ -2896,29 +3008,29 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.188" +version = "1.0.193" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf9e0fcba69a370eed61bcf2b728575f726b50b55cba78064753d708ddc7549e" +checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.188" +version = "1.0.193" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" +checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.37", + "syn 2.0.43", ] [[package]] name = "serde_json" -version = "1.0.107" +version = "1.0.108" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65" +checksum = "3d1c7e3eac408d115102c4c24ad393e0821bb3a5df4d506a80f85f7a742a526b" dependencies = [ "itoa", "ryu", @@ -2939,15 +3051,24 @@ dependencies = [ [[package]] name = "sha2" -version = "0.10.7" +version = "0.10.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "479fb9d862239e610720565ca91403019f2f00410f1864c5aa7479b950a76ed8" +checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" dependencies = [ "cfg-if", "cpufeatures", "digest", ] +[[package]] +name = "signal-hook-registry" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1" +dependencies = [ + "libc", +] + [[package]] name = "siphasher" version = "0.3.11" @@ -2965,9 +3086,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.11.1" +version = "1.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "942b4a808e05215192e39f4ab80813e599068285906cc91aa64f923db842bd5a" +checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970" [[package]] name = "snafu" @@ -2993,41 +3114,37 @@ dependencies = [ [[package]] name = "snap" -version = "1.1.0" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e9f0ab6ef7eb7353d9119c170a436d1bf248eea575ac42d19d12f4e34130831" +checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" [[package]] name = "socket2" -version = "0.4.9" +version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64a4a911eed85daf18834cfaa86a79b7d266ff93ff5ba14005426219480ed662" +checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9" dependencies = [ "libc", - "winapi", + "windows-sys 0.48.0", ] [[package]] -name = "socket2" -version = "0.5.4" +name = "spin" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4031e820eb552adee9295814c0ced9e5cf38ddf1e8b7d566d6de8e2538ea989e" -dependencies = [ - "libc", - "windows-sys", -] +checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" [[package]] name = "spin" -version = "0.5.2" +version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" [[package]] name = "sqlparser" -version = "0.38.0" +version = "0.41.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0272b7bb0a225320170c99901b4b5fb3a4384e255a7f2cc228f61e2ba3893e75" +checksum = "5cc2c25a6c66789625ef164b4c7d2e548d627902280c13710d33da8222169964" dependencies = [ "log", "sqlparser_derive", @@ -3035,9 +3152,9 @@ dependencies = [ [[package]] name = "sqlparser_derive" -version = "0.1.1" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55fe75cb4a364c7f7ae06c7dbbc8d84bddd85d6cdf9975963c3935bc1991761e" +checksum = "3e9c2e1dde0efa87003e7923d94a90f46e3274ad1649f51de96812be561f041f" dependencies = [ "proc-macro2", "quote", @@ -3062,45 +3179,26 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" -[[package]] -name = "strum" -version = "0.24.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "063e6045c0e62079840579a7e47a355ae92f60eb74daaf156fb1e84ba164e63f" - [[package]] name = "strum" version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" dependencies = [ - "strum_macros 0.25.2", -] - -[[package]] -name = "strum_macros" -version = "0.24.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e385be0d24f186b4ce2f9982191e7101bb737312ad61c1f2f984f34bcf85d59" -dependencies = [ - "heck", - "proc-macro2", - "quote", - "rustversion", - "syn 1.0.109", + "strum_macros", ] [[package]] name = "strum_macros" -version = "0.25.2" +version = "0.25.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad8d03b598d3d0fff69bf533ee3ef19b8eeb342729596df84bcc7e1f96ec4059" +checksum = "23dc1fa9ac9c169a78ba62f0b841814b7abae11bdd047b9c58f893439e309ea0" dependencies = [ "heck", "proc-macro2", "quote", "rustversion", - "syn 2.0.37", + "syn 2.0.43", ] [[package]] @@ -3122,33 +3220,54 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.37" +version = "2.0.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7303ef2c05cd654186cb250d29049a24840ca25d2747c25c0381c8d9e2f582e8" +checksum = "ee659fb5f3d355364e1f3e5bc10fb82068efbf824a1e9d1c9504244a6469ad53" dependencies = [ "proc-macro2", "quote", "unicode-ident", ] +[[package]] +name = "system-configuration" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "tempfile" -version = "3.8.0" +version = "3.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb94d2f3cc536af71caac6b6fcebf65860b347e7ce0cc9ebe8f70d3e521054ef" +checksum = "7ef1adac450ad7f4b3c28589471ade84f25f731a7a0fe30d71dfa9f60fd808e5" dependencies = [ "cfg-if", - "fastrand 2.0.0", - "redox_syscall 0.3.5", + "fastrand 2.0.1", + "redox_syscall", "rustix", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] name = "termcolor" -version = "1.3.0" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6093bad37da69aab9d123a8091e4be0aa4a03e4d601ec641c327398315f62b64" +checksum = "ff1bc3d3f05aff0403e8ac0d92ced918ec05b666a43f83297ccef5bea8a3d449" dependencies = [ "winapi-util", ] @@ -3167,22 +3286,22 @@ checksum = "222a222a5bfe1bba4a77b45ec488a741b3cb8872e5e499451fd7d0129c9c7c3d" [[package]] name = "thiserror" -version = "1.0.48" +version = "1.0.52" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d6d7a740b8a666a7e828dd00da9c0dc290dff53154ea77ac109281de90589b7" +checksum = "83a48fd946b02c0a526b2e9481c8e2a17755e47039164a86c4070446e3a4614d" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.48" +version = "1.0.52" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49922ecae66cc8a249b77e68d1d0623c1b2c514f0060c27cdc68bd62a1219d35" +checksum = "e7fbe9b594d6568a6a1443250a7e67d80b74e1e96f6d1715e1e21cc1888291d3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.37", + "syn 2.0.43", ] [[package]] @@ -3198,11 +3317,12 @@ dependencies = [ [[package]] name = "time" -version = "0.3.28" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17f6bb557fd245c28e6411aa56b6403c689ad95061f50e4be16c274e70a17e48" +checksum = "f657ba42c3f86e7680e53c8cd3af8abbe56b5491790b46e22e19c0d57463583e" dependencies = [ "deranged", + "powerfmt", "serde", "time-core", "time-macros", @@ -3210,15 +3330,15 @@ dependencies = [ [[package]] name = "time-core" -version = "0.1.1" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7300fbefb4dadc1af235a9cef3737cea692a9d97e1b9cbcd4ebdae6f8868e6fb" +checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" [[package]] name = "time-macros" -version = "0.2.14" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a942f44339478ef67935ab2bbaec2fb0322496cf3cbe84b261e06ac3814c572" +checksum = "26197e33420244aeb70c3e8c78376ca46571bc4e701e4791c2cd9f57dcb3a43f" dependencies = [ "time-core", ] @@ -3249,9 +3369,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.32.0" +version = "1.35.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17ed6077ed6cd6c74735e21f37eb16dc3935f96878b1fe961074089cc80893f9" +checksum = "c89b4efa943be685f629b149f53829423f8f5531ea21249408e8e2f8671ec104" dependencies = [ "backtrace", "bytes", @@ -3260,20 +3380,21 @@ dependencies = [ "num_cpus", "parking_lot", "pin-project-lite", - "socket2 0.5.4", + "signal-hook-registry", + "socket2", "tokio-macros", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] name = "tokio-macros" -version = "2.1.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" +checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.37", + "syn 2.0.43", ] [[package]] @@ -3293,7 +3414,7 @@ version = "0.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" dependencies = [ - "rustls 0.21.7", + "rustls 0.21.10", "tokio", ] @@ -3310,9 +3431,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.9" +version = "0.7.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d68074620f57a0b21594d9735eb2e98ab38b17f80d3fcb189fca266771ca60d" +checksum = "5419f34732d9eb6ee4c3578b7989078579b7f039cbbb9ca2c4da015749371e15" dependencies = [ "bytes", "futures-core", @@ -3352,11 +3473,10 @@ checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" [[package]] name = "tracing" -version = "0.1.37" +version = "0.1.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8" +checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" dependencies = [ - "cfg-if", "log", "pin-project-lite", "tracing-attributes", @@ -3365,29 +3485,29 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.26" +version = "0.1.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f4f31f56159e98206da9efd823404b79b6ef3143b4a7ab76e67b1751b25a4ab" +checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.37", + "syn 2.0.43", ] [[package]] name = "tracing-core" -version = "0.1.31" +version = "0.1.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0955b8137a1df6f1a2e9a37d8a6656291ff0297c1a97c24e0d8425fe2312f79a" +checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" dependencies = [ "once_cell", ] [[package]] name = "try-lock" -version = "0.2.4" +version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" [[package]] name = "twox-hash" @@ -3399,6 +3519,26 @@ dependencies = [ "static_assertions", ] +[[package]] +name = "typed-builder" +version = "0.16.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34085c17941e36627a879208083e25d357243812c30e7d7387c3b954f30ade16" +dependencies = [ + "typed-builder-macro", +] + +[[package]] +name = "typed-builder-macro" +version = "0.16.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f03ca4cb38206e2bef0700092660bb74d696f808514dae47fa1467cbfe26e96e" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.43", +] + [[package]] name = "typenum" version = "1.17.0" @@ -3407,9 +3547,9 @@ checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" [[package]] name = "unicode-bidi" -version = "0.3.13" +version = "0.3.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92888ba5573ff080736b3648696b70cafad7d250551175acbaa4e0385b3e1460" +checksum = "6f2528f27a9eb2b21e69c95319b30bd0efd85d09c379741b0f78ea1d86be2416" [[package]] name = "unicode-ident" @@ -3444,11 +3584,17 @@ version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + [[package]] name = "url" -version = "2.4.1" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "143b538f18257fac9cad154828a57c6bf5157e1aa604d4816b5995bf6de87ae5" +checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" dependencies = [ "form_urlencoded", "idna", @@ -3469,11 +3615,12 @@ checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" [[package]] name = "uuid" -version = "1.4.1" +version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79daa5ed5740825c40b389c5e50312b9c86df53fccd33f281df655642b43869d" +checksum = "5e395fcf16a7a3d8127ec99782007af141946b4795001f876d54fb0d55978560" dependencies = [ "getrandom", + "serde", ] [[package]] @@ -3524,9 +3671,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.87" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7706a72ab36d8cb1f80ffbf0e071533974a60d0a308d01a5d0375bf60499a342" +checksum = "0ed0d4f68a3015cc185aff4db9506a015f4b96f95303897bfa23f846db54064e" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -3534,24 +3681,24 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.87" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ef2b6d3c510e9625e5fe6f509ab07d66a760f0885d858736483c32ed7809abd" +checksum = "1b56f625e64f3a1084ded111c4d5f477df9f8c92df113852fa5a374dbda78826" dependencies = [ "bumpalo", "log", "once_cell", "proc-macro2", "quote", - "syn 2.0.37", + "syn 2.0.43", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-futures" -version = "0.4.37" +version = "0.4.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c02dbc21516f9f1f04f187958890d7e6026df8d16540b7ad9492bc34a67cea03" +checksum = "ac36a15a220124ac510204aec1c3e5db8a22ab06fd6706d881dc6149f8ed9a12" dependencies = [ "cfg-if", "js-sys", @@ -3561,9 +3708,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.87" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dee495e55982a3bd48105a7b947fd2a9b4a8ae3010041b9e0faab3f9cd028f1d" +checksum = "0162dbf37223cd2afce98f3d0785506dcb8d266223983e4b5b525859e6e182b2" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -3571,22 +3718,22 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.87" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" +checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283" dependencies = [ "proc-macro2", "quote", - "syn 2.0.37", + "syn 2.0.43", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.87" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca6ad05a4870b2bf5fe995117d3728437bd27d7cd5f06f13c17443ef369775a1" +checksum = "7ab9b36309365056cd639da3134bf87fa8f3d86008abf99e612384a6eecd459f" [[package]] name = "wasm-streams" @@ -3603,9 +3750,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.64" +version = "0.3.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b85cbef8c220a6abc02aefd892dfc0fc23afb1c6a426316ec33253a3877249b" +checksum = "50c24a44ec86bb68fbecd1b3efed7e85ea5621b39b35ef2766b66cd984f8010f" dependencies = [ "js-sys", "wasm-bindgen", @@ -3613,19 +3760,19 @@ dependencies = [ [[package]] name = "webpki" -version = "0.22.1" +version = "0.22.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0e74f82d49d545ad128049b7e88f6576df2da6b02e9ce565c6f533be576957e" +checksum = "ed63aea5ce73d0ff405984102c42de94fc55a6b75765d621c65262469b3c9b53" dependencies = [ - "ring", - "untrusted", + "ring 0.17.7", + "untrusted 0.9.0", ] [[package]] name = "webpki-roots" -version = "0.25.2" +version = "0.25.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14247bb57be4f377dfb94c72830b8ce8fc6beac03cf4bf7b9732eadd414123fc" +checksum = "1778a42e8b3b90bff8d0f5032bf22250792889a5cdc752aa0020c84abe3aaf10" [[package]] name = "winapi" @@ -3659,12 +3806,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] -name = "windows" -version = "0.48.0" +name = "windows-core" +version = "0.51.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e686886bc078bc1b0b600cac0147aadb815089b6e4da64016cbd754b6342700f" +checksum = "f1f8cf84f35d2db49a46868f947758c7a1138116f7fac3bc844f43ade1292e64" dependencies = [ - "windows-targets", + "windows-targets 0.48.5", ] [[package]] @@ -3673,7 +3820,16 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" dependencies = [ - "windows-targets", + "windows-targets 0.48.5", +] + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.0", ] [[package]] @@ -3682,13 +3838,28 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", +] + +[[package]] +name = "windows-targets" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd" +dependencies = [ + "windows_aarch64_gnullvm 0.52.0", + "windows_aarch64_msvc 0.52.0", + "windows_i686_gnu 0.52.0", + "windows_i686_msvc 0.52.0", + "windows_x86_64_gnu 0.52.0", + "windows_x86_64_gnullvm 0.52.0", + "windows_x86_64_msvc 0.52.0", ] [[package]] @@ -3697,42 +3868,84 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" + [[package]] name = "windows_aarch64_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" + [[package]] name = "windows_i686_gnu" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" +[[package]] +name = "windows_i686_gnu" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" + [[package]] name = "windows_i686_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" +[[package]] +name = "windows_i686_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" + [[package]] name = "windows_x86_64_gnu" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" + [[package]] name = "windows_x86_64_gnullvm" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" + [[package]] name = "windows_x86_64_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" + [[package]] name = "winreg" version = "0.50.0" @@ -3740,14 +3953,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1" dependencies = [ "cfg-if", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] name = "xmlparser" -version = "0.13.5" +version = "0.13.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d25c75bf9ea12c4040a97f829154768bbbce366287e2dc044af160cd79a13fd" +checksum = "66fee0b777b0f5ac1c69bb06d361268faafa61cd4682ae064a171c16c433e9e4" [[package]] name = "xz2" @@ -3758,11 +3971,31 @@ dependencies = [ "lzma-sys", ] +[[package]] +name = "zerocopy" +version = "0.7.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74d4d3961e53fa4c9a25a8637fc2bfaf2595b3d3ae34875568a5cf64787716be" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.43", +] + [[package]] name = "zeroize" -version = "1.6.0" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a0956f1ba7c7909bfb66c2e9e4124ab6f6482560f6628b5aaeba39207c9aad9" +checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" [[package]] name = "zstd" @@ -3770,7 +4003,16 @@ version = "0.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a27595e173641171fc74a1232b7b1c7a7cb6e18222c11e9dfb9888fa424c53c" dependencies = [ - "zstd-safe", + "zstd-safe 6.0.6", +] + +[[package]] +name = "zstd" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bffb3309596d527cfcba7dfc6ed6052f1d39dfbd7c867aa2e865e4a449c10110" +dependencies = [ + "zstd-safe 7.0.0", ] [[package]] @@ -3783,13 +4025,21 @@ dependencies = [ "zstd-sys", ] +[[package]] +name = "zstd-safe" +version = "7.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43747c7422e2924c11144d5229878b98180ef8b06cca4ab5af37afc8a8d8ea3e" +dependencies = [ + "zstd-sys", +] + [[package]] name = "zstd-sys" -version = "2.0.8+zstd.1.5.5" +version = "2.0.9+zstd.1.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5556e6ee25d32df2586c098bbfa278803692a20d0ab9565e049480d52707ec8c" +checksum = "9e16efa8a874a0481a574084d34cc26fdb3b99627480f785888deb6386506656" dependencies = [ "cc", - "libc", "pkg-config", ] diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index a550b487509d..eab7c8e0d1f8 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -18,7 +18,7 @@ [package] name = "datafusion-cli" description = "Command Line Client for DataFusion query engine." -version = "31.0.0" +version = "34.0.0" authors = ["Apache Arrow "] edition = "2021" keywords = ["arrow", "datafusion", "query", "sql"] @@ -29,20 +29,23 @@ rust-version = "1.70" readme = "README.md" [dependencies] -arrow = "47.0.0" +arrow = "49.0.0" async-trait = "0.1.41" aws-config = "0.55" aws-credential-types = "0.55" clap = { version = "3", features = ["derive", "cargo"] } -datafusion = { path = "../datafusion/core", version = "31.0.0" } +datafusion = { path = "../datafusion/core", version = "34.0.0", features = ["avro", "crypto_expressions", "encoding_expressions", "parquet", "regex_expressions", "unicode_expressions", "compression"] } +datafusion-common = { path = "../datafusion/common" } dirs = "4.0.0" env_logger = "0.9" +futures = "0.3" mimalloc = { version = "0.1", default-features = false } -object_store = { version = "0.7.0", features = ["aws", "gcp"] } +object_store = { version = "0.8.0", features = ["aws", "gcp"] } parking_lot = { version = "0.12" } +parquet = { version = "49.0.0", default-features = false } regex = "1.8" rustyline = "11.0" -tokio = { version = "1.24", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot"] } +tokio = { version = "1.24", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot", "signal"] } url = "2.2" [dev-dependencies] diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index 100d7bce440c..2320a8c314cf 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -17,6 +17,12 @@ //! Execution functions +use std::io::prelude::*; +use std::io::BufReader; +use std::time::Instant; +use std::{fs::File, sync::Arc}; + +use crate::print_format::PrintFormat; use crate::{ command::{Command, OutputFormat}, helper::{unescape_input, CliHelper}, @@ -26,20 +32,20 @@ use crate::{ }, print_options::{MaxRows, PrintOptions}, }; + +use datafusion::common::{exec_datafusion_err, plan_datafusion_err}; +use datafusion::datasource::listing::ListingTableUrl; +use datafusion::datasource::physical_plan::is_plan_streaming; +use datafusion::error::{DataFusionError, Result}; +use datafusion::logical_expr::{CreateExternalTable, DdlStatement, LogicalPlan}; +use datafusion::physical_plan::{collect, execute_stream}; +use datafusion::prelude::SessionContext; use datafusion::sql::{parser::DFParser, sqlparser::dialect::dialect_from_str}; -use datafusion::{ - datasource::listing::ListingTableUrl, - error::{DataFusionError, Result}, - logical_expr::{CreateExternalTable, DdlStatement}, -}; -use datafusion::{logical_expr::LogicalPlan, prelude::SessionContext}; + use object_store::ObjectStore; use rustyline::error::ReadlineError; use rustyline::Editor; -use std::io::prelude::*; -use std::io::BufReader; -use std::time::Instant; -use std::{fs::File, sync::Arc}; +use tokio::signal; use url::Url; /// run and execute SQL statements and commands, against a context with the given print options @@ -124,8 +130,6 @@ pub async fn exec_from_repl( ))); rl.load_history(".history").ok(); - let mut print_options = print_options.clone(); - loop { match rl.readline("❯ ") { Ok(line) if line.starts_with('\\') => { @@ -137,9 +141,7 @@ pub async fn exec_from_repl( Command::OutputFormat(subcommand) => { if let Some(subcommand) = subcommand { if let Ok(command) = subcommand.parse::() { - if let Err(e) = - command.execute(&mut print_options).await - { + if let Err(e) = command.execute(print_options).await { eprintln!("{e}") } } else { @@ -153,7 +155,7 @@ pub async fn exec_from_repl( } } _ => { - if let Err(e) = cmd.execute(ctx, &mut print_options).await { + if let Err(e) = cmd.execute(ctx, print_options).await { eprintln!("{e}") } } @@ -164,9 +166,15 @@ pub async fn exec_from_repl( } Ok(line) => { rl.add_history_entry(line.trim_end())?; - match exec_and_print(ctx, &print_options, line).await { - Ok(_) => {} - Err(err) => eprintln!("{err}"), + tokio::select! { + res = exec_and_print(ctx, print_options, line) => match res { + Ok(_) => {} + Err(err) => eprintln!("{err}"), + }, + _ = signal::ctrl_c() => { + println!("^C"); + continue + }, } // dialect might have changed rl.helper_mut().unwrap().set_dialect( @@ -197,20 +205,19 @@ async fn exec_and_print( sql: String, ) -> Result<()> { let now = Instant::now(); - let sql = unescape_input(&sql)?; let task_ctx = ctx.task_ctx(); let dialect = &task_ctx.session_config().options().sql_parser.dialect; let dialect = dialect_from_str(dialect).ok_or_else(|| { - DataFusionError::Plan(format!( + plan_datafusion_err!( "Unsupported SQL dialect: {dialect}. Available dialects: \ Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \ MsSQL, ClickHouse, BigQuery, Ansi." - )) + ) })?; let statements = DFParser::parse_sql_with_dialect(&sql, dialect.as_ref())?; for statement in statements { - let plan = ctx.state().statement_to_plan(statement).await?; + let mut plan = ctx.state().statement_to_plan(statement).await?; // For plans like `Explain` ignore `MaxRows` option and always display all rows let should_ignore_maxrows = matches!( @@ -220,25 +227,30 @@ async fn exec_and_print( | LogicalPlan::Analyze(_) ); - let df = match &plan { - LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) => { - create_external_table(ctx, cmd).await?; - ctx.execute_logical_plan(plan).await? - } - _ => ctx.execute_logical_plan(plan).await?, - }; + // Note that cmd is a mutable reference so that create_external_table function can remove all + // datafusion-cli specific options before passing through to datafusion. Otherwise, datafusion + // will raise Configuration errors. + if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { + create_external_table(ctx, cmd).await?; + } - let results = df.collect().await?; + let df = ctx.execute_logical_plan(plan).await?; + let physical_plan = df.create_physical_plan().await?; - let print_options = if should_ignore_maxrows { - PrintOptions { - maxrows: MaxRows::Unlimited, - ..print_options.clone() - } + if is_plan_streaming(&physical_plan)? { + let stream = execute_stream(physical_plan, task_ctx.clone())?; + print_options.print_stream(stream, now).await?; } else { - print_options.clone() - }; - print_options.print_batches(&results, now)?; + let mut print_options = print_options.clone(); + if should_ignore_maxrows { + print_options.maxrows = MaxRows::Unlimited; + } + if print_options.format == PrintFormat::Automatic { + print_options.format = PrintFormat::Table; + } + let results = collect(physical_plan, task_ctx.clone()).await?; + print_options.print_batches(&results, now)?; + } } Ok(()) @@ -246,7 +258,7 @@ async fn exec_and_print( async fn create_external_table( ctx: &SessionContext, - cmd: &CreateExternalTable, + cmd: &mut CreateExternalTable, ) -> Result<()> { let table_path = ListingTableUrl::parse(&cmd.location)?; let scheme = table_path.scheme(); @@ -272,10 +284,7 @@ async fn create_external_table( .object_store_registry .get_store(url) .map_err(|_| { - DataFusionError::Execution(format!( - "Unsupported object store scheme: {}", - scheme - )) + exec_datafusion_err!("Unsupported object store scheme: {}", scheme) })? } }; @@ -287,15 +296,32 @@ async fn create_external_table( #[cfg(test)] mod tests { + use std::str::FromStr; + use super::*; use datafusion::common::plan_err; + use datafusion_common::{file_options::StatementOptions, FileTypeWriterOptions}; async fn create_external_table_test(location: &str, sql: &str) -> Result<()> { let ctx = SessionContext::new(); - let plan = ctx.state().create_logical_plan(sql).await?; + let mut plan = ctx.state().create_logical_plan(sql).await?; - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan { + if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { create_external_table(&ctx, cmd).await?; + let options: Vec<_> = cmd + .options + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect(); + let statement_options = StatementOptions::new(options); + let file_type = + datafusion_common::FileType::from_str(cmd.file_type.as_str())?; + + let _file_type_writer_options = FileTypeWriterOptions::build( + &file_type, + ctx.state().config_options(), + &statement_options, + )?; } else { return plan_err!("LogicalPlan is not a CreateExternalTable"); } @@ -349,7 +375,7 @@ mod tests { async fn create_object_store_table_gcs() -> Result<()> { let service_account_path = "fake_service_account_path"; let service_account_key = - "{\"private_key\": \"fake_private_key.pem\",\"client_email\":\"fake_client_email\"}"; + "{\"private_key\": \"fake_private_key.pem\",\"client_email\":\"fake_client_email\", \"private_key_id\":\"id\"}"; let application_credentials_path = "fake_application_credentials_path"; let location = "gcs://bucket/path/file.parquet"; @@ -365,8 +391,9 @@ mod tests { let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('service_account_key' '{service_account_key}') LOCATION '{location}'"); let err = create_external_table_test(location, &sql) .await - .unwrap_err(); - assert!(err.to_string().contains("No RSA key found in pem file")); + .unwrap_err() + .to_string(); + assert!(err.contains("No RSA key found in pem file"), "{err}"); // for application_credentials_path let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET @@ -386,15 +413,7 @@ mod tests { // Ensure that local files are also registered let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET LOCATION '{location}'"); - let err = create_external_table_test(location, &sql) - .await - .unwrap_err(); - - if let DataFusionError::IoError(e) = err { - assert_eq!(e.kind(), std::io::ErrorKind::NotFound); - } else { - return Err(err); - } + create_external_table_test(location, &sql).await.unwrap(); Ok(()) } diff --git a/datafusion-cli/src/functions.rs b/datafusion-cli/src/functions.rs index eeebe713d716..5390fa9f2271 100644 --- a/datafusion-cli/src/functions.rs +++ b/datafusion-cli/src/functions.rs @@ -16,12 +16,27 @@ // under the License. //! Functions that are query-able and searchable via the `\h` command -use arrow::array::StringArray; -use arrow::datatypes::{DataType, Field, Schema}; +use arrow::array::{Int64Array, StringArray}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; +use async_trait::async_trait; +use datafusion::common::DataFusionError; +use datafusion::common::{plan_err, Column}; +use datafusion::datasource::function::TableFunctionImpl; +use datafusion::datasource::TableProvider; use datafusion::error::Result; +use datafusion::execution::context::SessionState; +use datafusion::logical_expr::Expr; +use datafusion::physical_plan::memory::MemoryExec; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::scalar::ScalarValue; +use parquet::basic::ConvertedType; +use parquet::file::reader::FileReader; +use parquet::file::serialized_reader::SerializedFileReader; +use parquet::file::statistics::Statistics; use std::fmt; +use std::fs::File; use std::str::FromStr; use std::sync::Arc; @@ -196,3 +211,232 @@ pub fn display_all_functions() -> Result<()> { println!("{}", pretty_format_batches(&[batch]).unwrap()); Ok(()) } + +/// PARQUET_META table function +struct ParquetMetadataTable { + schema: SchemaRef, + batch: RecordBatch, +} + +#[async_trait] +impl TableProvider for ParquetMetadataTable { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> arrow::datatypes::SchemaRef { + self.schema.clone() + } + + fn table_type(&self) -> datafusion::logical_expr::TableType { + datafusion::logical_expr::TableType::Base + } + + async fn scan( + &self, + _state: &SessionState, + projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + Ok(Arc::new(MemoryExec::try_new( + &[vec![self.batch.clone()]], + TableProvider::schema(self), + projection.cloned(), + )?)) + } +} + +fn convert_parquet_statistics( + value: &Statistics, + converted_type: ConvertedType, +) -> (String, String) { + match (value, converted_type) { + (Statistics::Boolean(val), _) => (val.min().to_string(), val.max().to_string()), + (Statistics::Int32(val), _) => (val.min().to_string(), val.max().to_string()), + (Statistics::Int64(val), _) => (val.min().to_string(), val.max().to_string()), + (Statistics::Int96(val), _) => (val.min().to_string(), val.max().to_string()), + (Statistics::Float(val), _) => (val.min().to_string(), val.max().to_string()), + (Statistics::Double(val), _) => (val.min().to_string(), val.max().to_string()), + (Statistics::ByteArray(val), ConvertedType::UTF8) => { + let min_bytes = val.min(); + let max_bytes = val.max(); + let min = min_bytes + .as_utf8() + .map(|v| v.to_string()) + .unwrap_or_else(|_| min_bytes.to_string()); + + let max = max_bytes + .as_utf8() + .map(|v| v.to_string()) + .unwrap_or_else(|_| max_bytes.to_string()); + (min, max) + } + (Statistics::ByteArray(val), _) => (val.min().to_string(), val.max().to_string()), + (Statistics::FixedLenByteArray(val), ConvertedType::UTF8) => { + let min_bytes = val.min(); + let max_bytes = val.max(); + let min = min_bytes + .as_utf8() + .map(|v| v.to_string()) + .unwrap_or_else(|_| min_bytes.to_string()); + + let max = max_bytes + .as_utf8() + .map(|v| v.to_string()) + .unwrap_or_else(|_| max_bytes.to_string()); + (min, max) + } + (Statistics::FixedLenByteArray(val), _) => { + (val.min().to_string(), val.max().to_string()) + } + } +} + +pub struct ParquetMetadataFunc {} + +impl TableFunctionImpl for ParquetMetadataFunc { + fn call(&self, exprs: &[Expr]) -> Result> { + let filename = match exprs.first() { + Some(Expr::Literal(ScalarValue::Utf8(Some(s)))) => s, // single quote: parquet_metadata('x.parquet') + Some(Expr::Column(Column { name, .. })) => name, // double quote: parquet_metadata("x.parquet") + _ => { + return plan_err!( + "parquet_metadata requires string argument as its input" + ); + } + }; + + let file = File::open(filename.clone())?; + let reader = SerializedFileReader::new(file)?; + let metadata = reader.metadata(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("filename", DataType::Utf8, true), + Field::new("row_group_id", DataType::Int64, true), + Field::new("row_group_num_rows", DataType::Int64, true), + Field::new("row_group_num_columns", DataType::Int64, true), + Field::new("row_group_bytes", DataType::Int64, true), + Field::new("column_id", DataType::Int64, true), + Field::new("file_offset", DataType::Int64, true), + Field::new("num_values", DataType::Int64, true), + Field::new("path_in_schema", DataType::Utf8, true), + Field::new("type", DataType::Utf8, true), + Field::new("stats_min", DataType::Utf8, true), + Field::new("stats_max", DataType::Utf8, true), + Field::new("stats_null_count", DataType::Int64, true), + Field::new("stats_distinct_count", DataType::Int64, true), + Field::new("stats_min_value", DataType::Utf8, true), + Field::new("stats_max_value", DataType::Utf8, true), + Field::new("compression", DataType::Utf8, true), + Field::new("encodings", DataType::Utf8, true), + Field::new("index_page_offset", DataType::Int64, true), + Field::new("dictionary_page_offset", DataType::Int64, true), + Field::new("data_page_offset", DataType::Int64, true), + Field::new("total_compressed_size", DataType::Int64, true), + Field::new("total_uncompressed_size", DataType::Int64, true), + ])); + + // construct recordbatch from metadata + let mut filename_arr = vec![]; + let mut row_group_id_arr = vec![]; + let mut row_group_num_rows_arr = vec![]; + let mut row_group_num_columns_arr = vec![]; + let mut row_group_bytes_arr = vec![]; + let mut column_id_arr = vec![]; + let mut file_offset_arr = vec![]; + let mut num_values_arr = vec![]; + let mut path_in_schema_arr = vec![]; + let mut type_arr = vec![]; + let mut stats_min_arr = vec![]; + let mut stats_max_arr = vec![]; + let mut stats_null_count_arr = vec![]; + let mut stats_distinct_count_arr = vec![]; + let mut stats_min_value_arr = vec![]; + let mut stats_max_value_arr = vec![]; + let mut compression_arr = vec![]; + let mut encodings_arr = vec![]; + let mut index_page_offset_arr = vec![]; + let mut dictionary_page_offset_arr = vec![]; + let mut data_page_offset_arr = vec![]; + let mut total_compressed_size_arr = vec![]; + let mut total_uncompressed_size_arr = vec![]; + for (rg_idx, row_group) in metadata.row_groups().iter().enumerate() { + for (col_idx, column) in row_group.columns().iter().enumerate() { + filename_arr.push(filename.clone()); + row_group_id_arr.push(rg_idx as i64); + row_group_num_rows_arr.push(row_group.num_rows()); + row_group_num_columns_arr.push(row_group.num_columns() as i64); + row_group_bytes_arr.push(row_group.total_byte_size()); + column_id_arr.push(col_idx as i64); + file_offset_arr.push(column.file_offset()); + num_values_arr.push(column.num_values()); + path_in_schema_arr.push(column.column_path().to_string()); + type_arr.push(column.column_type().to_string()); + let converted_type = column.column_descr().converted_type(); + + if let Some(s) = column.statistics() { + let (min_val, max_val) = if s.has_min_max_set() { + let (min_val, max_val) = + convert_parquet_statistics(s, converted_type); + (Some(min_val), Some(max_val)) + } else { + (None, None) + }; + stats_min_arr.push(min_val.clone()); + stats_max_arr.push(max_val.clone()); + stats_null_count_arr.push(Some(s.null_count() as i64)); + stats_distinct_count_arr.push(s.distinct_count().map(|c| c as i64)); + stats_min_value_arr.push(min_val); + stats_max_value_arr.push(max_val); + } else { + stats_min_arr.push(None); + stats_max_arr.push(None); + stats_null_count_arr.push(None); + stats_distinct_count_arr.push(None); + stats_min_value_arr.push(None); + stats_max_value_arr.push(None); + }; + compression_arr.push(format!("{:?}", column.compression())); + encodings_arr.push(format!("{:?}", column.encodings())); + index_page_offset_arr.push(column.index_page_offset()); + dictionary_page_offset_arr.push(column.dictionary_page_offset()); + data_page_offset_arr.push(column.data_page_offset()); + total_compressed_size_arr.push(column.compressed_size()); + total_uncompressed_size_arr.push(column.uncompressed_size()); + } + } + + let rb = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(filename_arr)), + Arc::new(Int64Array::from(row_group_id_arr)), + Arc::new(Int64Array::from(row_group_num_rows_arr)), + Arc::new(Int64Array::from(row_group_num_columns_arr)), + Arc::new(Int64Array::from(row_group_bytes_arr)), + Arc::new(Int64Array::from(column_id_arr)), + Arc::new(Int64Array::from(file_offset_arr)), + Arc::new(Int64Array::from(num_values_arr)), + Arc::new(StringArray::from(path_in_schema_arr)), + Arc::new(StringArray::from(type_arr)), + Arc::new(StringArray::from(stats_min_arr)), + Arc::new(StringArray::from(stats_max_arr)), + Arc::new(Int64Array::from(stats_null_count_arr)), + Arc::new(Int64Array::from(stats_distinct_count_arr)), + Arc::new(StringArray::from(stats_min_value_arr)), + Arc::new(StringArray::from(stats_max_value_arr)), + Arc::new(StringArray::from(compression_arr)), + Arc::new(StringArray::from(encodings_arr)), + Arc::new(Int64Array::from(index_page_offset_arr)), + Arc::new(Int64Array::from(dictionary_page_offset_arr)), + Arc::new(Int64Array::from(data_page_offset_arr)), + Arc::new(Int64Array::from(total_compressed_size_arr)), + Arc::new(Int64Array::from(total_uncompressed_size_arr)), + ], + )?; + + let parquet_metadata = ParquetMetadataTable { schema, batch: rb }; + Ok(Arc::new(parquet_metadata)) + } +} diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs index 33a1caeb1b5b..563d172f2c95 100644 --- a/datafusion-cli/src/main.rs +++ b/datafusion-cli/src/main.rs @@ -15,25 +15,28 @@ // specific language governing permissions and limitations // under the License. -use clap::Parser; +use std::collections::HashMap; +use std::env; +use std::path::Path; +use std::str::FromStr; +use std::sync::{Arc, OnceLock}; + use datafusion::error::{DataFusionError, Result}; use datafusion::execution::context::SessionConfig; use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool}; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::prelude::SessionContext; use datafusion_cli::catalog::DynamicFileCatalog; +use datafusion_cli::functions::ParquetMetadataFunc; use datafusion_cli::{ exec, print_format::PrintFormat, print_options::{MaxRows, PrintOptions}, DATAFUSION_CLI_VERSION, }; + +use clap::Parser; use mimalloc::MiMalloc; -use std::collections::HashMap; -use std::env; -use std::path::Path; -use std::str::FromStr; -use std::sync::{Arc, OnceLock}; #[global_allocator] static GLOBAL: MiMalloc = MiMalloc; @@ -110,7 +113,7 @@ struct Args { )] rc: Option>, - #[clap(long, arg_enum, default_value_t = PrintFormat::Table)] + #[clap(long, arg_enum, default_value_t = PrintFormat::Automatic)] format: PrintFormat, #[clap( @@ -178,13 +181,15 @@ pub async fn main() -> Result<()> { let runtime_env = create_runtime_env(rn_config.clone())?; let mut ctx = - SessionContext::with_config_rt(session_config.clone(), Arc::new(runtime_env)); + SessionContext::new_with_config_rt(session_config.clone(), Arc::new(runtime_env)); ctx.refresh_catalogs().await?; // install dynamic catalog provider that knows how to open files ctx.register_catalog_list(Arc::new(DynamicFileCatalog::new( ctx.state().catalog_list(), ctx.state_weak_ref(), ))); + // register `parquet_metadata` table function to get metadata from parquet files + ctx.register_udtf("parquet_metadata", Arc::new(ParquetMetadataFunc {})); let mut print_options = PrintOptions { format: args.format, @@ -329,6 +334,7 @@ fn extract_memory_pool_size(size: &str) -> Result { #[cfg(test)] mod tests { use super::*; + use datafusion::assert_batches_eq; fn assert_conversion(input: &str, expected: Result) { let result = extract_memory_pool_size(input); @@ -385,4 +391,58 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_parquet_metadata_works() -> Result<(), DataFusionError> { + let ctx = SessionContext::new(); + ctx.register_udtf("parquet_metadata", Arc::new(ParquetMetadataFunc {})); + + // input with single quote + let sql = + "SELECT * FROM parquet_metadata('../datafusion/core/tests/data/fixed_size_list_array.parquet')"; + let df = ctx.sql(sql).await?; + let rbs = df.collect().await?; + + let excepted = [ + "+-------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+-------+-----------+-----------+------------------+----------------------+-----------------+-----------------+-------------+------------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+", + "| filename | row_group_id | row_group_num_rows | row_group_num_columns | row_group_bytes | column_id | file_offset | num_values | path_in_schema | type | stats_min | stats_max | stats_null_count | stats_distinct_count | stats_min_value | stats_max_value | compression | encodings | index_page_offset | dictionary_page_offset | data_page_offset | total_compressed_size | total_uncompressed_size |", + "+-------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+-------+-----------+-----------+------------------+----------------------+-----------------+-----------------+-------------+------------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+", + "| ../datafusion/core/tests/data/fixed_size_list_array.parquet | 0 | 2 | 1 | 123 | 0 | 125 | 4 | \"f0.list.item\" | INT64 | 1 | 4 | 0 | | 1 | 4 | SNAPPY | [RLE_DICTIONARY, PLAIN, RLE] | | 4 | 46 | 121 | 123 |", + "+-------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+-------+-----------+-----------+------------------+----------------------+-----------------+-----------------+-------------+------------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+", + ]; + assert_batches_eq!(excepted, &rbs); + + // input with double quote + let sql = + "SELECT * FROM parquet_metadata(\"../datafusion/core/tests/data/fixed_size_list_array.parquet\")"; + let df = ctx.sql(sql).await?; + let rbs = df.collect().await?; + assert_batches_eq!(excepted, &rbs); + + Ok(()) + } + + #[tokio::test] + async fn test_parquet_metadata_works_with_strings() -> Result<(), DataFusionError> { + let ctx = SessionContext::new(); + ctx.register_udtf("parquet_metadata", Arc::new(ParquetMetadataFunc {})); + + // input with string columns + let sql = + "SELECT * FROM parquet_metadata('../parquet-testing/data/data_index_bloom_encoding_stats.parquet')"; + let df = ctx.sql(sql).await?; + let rbs = df.collect().await?; + + let excepted = [ + +"+-----------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+------------+-----------+-----------+------------------+----------------------+-----------------+-----------------+--------------------+--------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+", +"| filename | row_group_id | row_group_num_rows | row_group_num_columns | row_group_bytes | column_id | file_offset | num_values | path_in_schema | type | stats_min | stats_max | stats_null_count | stats_distinct_count | stats_min_value | stats_max_value | compression | encodings | index_page_offset | dictionary_page_offset | data_page_offset | total_compressed_size | total_uncompressed_size |", +"+-----------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+------------+-----------+-----------+------------------+----------------------+-----------------+-----------------+--------------------+--------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+", +"| ../parquet-testing/data/data_index_bloom_encoding_stats.parquet | 0 | 14 | 1 | 163 | 0 | 4 | 14 | \"String\" | BYTE_ARRAY | Hello | today | 0 | | Hello | today | GZIP(GzipLevel(6)) | [BIT_PACKED, RLE, PLAIN] | | | 4 | 152 | 163 |", +"+-----------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+------------+-----------+-----------+------------------+----------------------+-----------------+-----------------+--------------------+--------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+" + ]; + assert_batches_eq!(excepted, &rbs); + + Ok(()) + } } diff --git a/datafusion-cli/src/object_storage.rs b/datafusion-cli/src/object_storage.rs index c39d1915eb43..9d79c7e0ec78 100644 --- a/datafusion-cli/src/object_storage.rs +++ b/datafusion-cli/src/object_storage.rs @@ -30,20 +30,23 @@ use url::Url; pub async fn get_s3_object_store_builder( url: &Url, - cmd: &CreateExternalTable, + cmd: &mut CreateExternalTable, ) -> Result { let bucket_name = get_bucket_name(url)?; let mut builder = AmazonS3Builder::from_env().with_bucket_name(bucket_name); if let (Some(access_key_id), Some(secret_access_key)) = ( - cmd.options.get("access_key_id"), - cmd.options.get("secret_access_key"), + // These options are datafusion-cli specific and must be removed before passing through to datafusion. + // Otherwise, a Configuration error will be raised. + cmd.options.remove("access_key_id"), + cmd.options.remove("secret_access_key"), ) { + println!("removing secret access key!"); builder = builder .with_access_key_id(access_key_id) .with_secret_access_key(secret_access_key); - if let Some(session_token) = cmd.options.get("session_token") { + if let Some(session_token) = cmd.options.remove("session_token") { builder = builder.with_token(session_token); } } else { @@ -66,7 +69,7 @@ pub async fn get_s3_object_store_builder( builder = builder.with_credentials(credentials); } - if let Some(region) = cmd.options.get("region") { + if let Some(region) = cmd.options.remove("region") { builder = builder.with_region(region); } @@ -99,7 +102,7 @@ impl CredentialProvider for S3CredentialProvider { pub fn get_oss_object_store_builder( url: &Url, - cmd: &CreateExternalTable, + cmd: &mut CreateExternalTable, ) -> Result { let bucket_name = get_bucket_name(url)?; let mut builder = AmazonS3Builder::from_env() @@ -109,15 +112,15 @@ pub fn get_oss_object_store_builder( .with_region("do_not_care"); if let (Some(access_key_id), Some(secret_access_key)) = ( - cmd.options.get("access_key_id"), - cmd.options.get("secret_access_key"), + cmd.options.remove("access_key_id"), + cmd.options.remove("secret_access_key"), ) { builder = builder .with_access_key_id(access_key_id) .with_secret_access_key(secret_access_key); } - if let Some(endpoint) = cmd.options.get("endpoint") { + if let Some(endpoint) = cmd.options.remove("endpoint") { builder = builder.with_endpoint(endpoint); } @@ -126,21 +129,21 @@ pub fn get_oss_object_store_builder( pub fn get_gcs_object_store_builder( url: &Url, - cmd: &CreateExternalTable, + cmd: &mut CreateExternalTable, ) -> Result { let bucket_name = get_bucket_name(url)?; let mut builder = GoogleCloudStorageBuilder::from_env().with_bucket_name(bucket_name); - if let Some(service_account_path) = cmd.options.get("service_account_path") { + if let Some(service_account_path) = cmd.options.remove("service_account_path") { builder = builder.with_service_account_path(service_account_path); } - if let Some(service_account_key) = cmd.options.get("service_account_key") { + if let Some(service_account_key) = cmd.options.remove("service_account_key") { builder = builder.with_service_account_key(service_account_key); } if let Some(application_credentials_path) = - cmd.options.get("application_credentials_path") + cmd.options.remove("application_credentials_path") { builder = builder.with_application_credentials(application_credentials_path); } @@ -180,9 +183,9 @@ mod tests { let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('access_key_id' '{access_key_id}', 'secret_access_key' '{secret_access_key}', 'region' '{region}', 'session_token' {session_token}) LOCATION '{location}'"); let ctx = SessionContext::new(); - let plan = ctx.state().create_logical_plan(&sql).await?; + let mut plan = ctx.state().create_logical_plan(&sql).await?; - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan { + if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { let builder = get_s3_object_store_builder(table_url.as_ref(), cmd).await?; // get the actual configuration information, then assert_eq! let config = [ @@ -212,9 +215,9 @@ mod tests { let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('access_key_id' '{access_key_id}', 'secret_access_key' '{secret_access_key}', 'endpoint' '{endpoint}') LOCATION '{location}'"); let ctx = SessionContext::new(); - let plan = ctx.state().create_logical_plan(&sql).await?; + let mut plan = ctx.state().create_logical_plan(&sql).await?; - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan { + if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { let builder = get_oss_object_store_builder(table_url.as_ref(), cmd)?; // get the actual configuration information, then assert_eq! let config = [ @@ -244,9 +247,9 @@ mod tests { let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('service_account_path' '{service_account_path}', 'service_account_key' '{service_account_key}', 'application_credentials_path' '{application_credentials_path}') LOCATION '{location}'"); let ctx = SessionContext::new(); - let plan = ctx.state().create_logical_plan(&sql).await?; + let mut plan = ctx.state().create_logical_plan(&sql).await?; - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan { + if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { let builder = get_gcs_object_store_builder(table_url.as_ref(), cmd)?; // get the actual configuration information, then assert_eq! let config = [ diff --git a/datafusion-cli/src/print_format.rs b/datafusion-cli/src/print_format.rs index e2994bc14034..ea418562495d 100644 --- a/datafusion-cli/src/print_format.rs +++ b/datafusion-cli/src/print_format.rs @@ -16,23 +16,27 @@ // under the License. //! Print format variants + +use std::str::FromStr; + use crate::print_options::MaxRows; + use arrow::csv::writer::WriterBuilder; use arrow::json::{ArrayWriter, LineDelimitedWriter}; +use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches_with_options; -use datafusion::arrow::record_batch::RecordBatch; use datafusion::common::format::DEFAULT_FORMAT_OPTIONS; -use datafusion::error::{DataFusionError, Result}; -use std::str::FromStr; +use datafusion::error::Result; /// Allow records to be printed in different formats -#[derive(Debug, PartialEq, Eq, clap::ArgEnum, Clone)] +#[derive(Debug, PartialEq, Eq, clap::ArgEnum, Clone, Copy)] pub enum PrintFormat { Csv, Tsv, Table, Json, NdJson, + Automatic, } impl FromStr for PrintFormat { @@ -44,31 +48,44 @@ impl FromStr for PrintFormat { } macro_rules! batches_to_json { - ($WRITER: ident, $batches: expr) => {{ - let mut bytes = vec![]; + ($WRITER: ident, $writer: expr, $batches: expr) => {{ { - let mut writer = $WRITER::new(&mut bytes); - $batches.iter().try_for_each(|batch| writer.write(batch))?; - writer.finish()?; + if !$batches.is_empty() { + let mut json_writer = $WRITER::new(&mut *$writer); + for batch in $batches { + json_writer.write(batch)?; + } + json_writer.finish()?; + json_finish!($WRITER, $writer); + } } - String::from_utf8(bytes).map_err(|e| DataFusionError::External(Box::new(e)))? + Ok(()) as Result<()> }}; } -fn print_batches_with_sep(batches: &[RecordBatch], delimiter: u8) -> Result { - let mut bytes = vec![]; - { - let builder = WriterBuilder::new() - .has_headers(true) - .with_delimiter(delimiter); - let mut writer = builder.build(&mut bytes); - for batch in batches { - writer.write(batch)?; - } +macro_rules! json_finish { + (ArrayWriter, $writer: expr) => {{ + writeln!($writer)?; + }}; + (LineDelimitedWriter, $writer: expr) => {{}}; +} + +fn print_batches_with_sep( + writer: &mut W, + batches: &[RecordBatch], + delimiter: u8, + with_header: bool, +) -> Result<()> { + let builder = WriterBuilder::new() + .with_header(with_header) + .with_delimiter(delimiter); + let mut csv_writer = builder.build(writer); + + for batch in batches { + csv_writer.write(batch)?; } - let formatted = - String::from_utf8(bytes).map_err(|e| DataFusionError::External(Box::new(e)))?; - Ok(formatted) + + Ok(()) } fn keep_only_maxrows(s: &str, maxrows: usize) -> String { @@ -88,97 +105,118 @@ fn keep_only_maxrows(s: &str, maxrows: usize) -> String { result.join("\n") } -fn format_batches_with_maxrows( +fn format_batches_with_maxrows( + writer: &mut W, batches: &[RecordBatch], maxrows: MaxRows, -) -> Result { +) -> Result<()> { match maxrows { MaxRows::Limited(maxrows) => { - // Only format enough batches for maxrows + // Filter batches to meet the maxrows condition let mut filtered_batches = Vec::new(); - let mut batches = batches; - let row_count: usize = batches.iter().map(|b| b.num_rows()).sum(); - if row_count > maxrows { - let mut accumulated_rows = 0; - - for batch in batches { + let mut row_count: usize = 0; + let mut over_limit = false; + for batch in batches { + if row_count + batch.num_rows() > maxrows { + // If adding this batch exceeds maxrows, slice the batch + let limit = maxrows - row_count; + let sliced_batch = batch.slice(0, limit); + filtered_batches.push(sliced_batch); + over_limit = true; + break; + } else { filtered_batches.push(batch.clone()); - if accumulated_rows + batch.num_rows() > maxrows { - break; - } - accumulated_rows += batch.num_rows(); + row_count += batch.num_rows(); } - - batches = &filtered_batches; } - let mut formatted = format!( - "{}", - pretty_format_batches_with_options(batches, &DEFAULT_FORMAT_OPTIONS)?, - ); - - if row_count > maxrows { - formatted = keep_only_maxrows(&formatted, maxrows); + let formatted = pretty_format_batches_with_options( + &filtered_batches, + &DEFAULT_FORMAT_OPTIONS, + )?; + if over_limit { + let mut formatted_str = format!("{}", formatted); + formatted_str = keep_only_maxrows(&formatted_str, maxrows); + writeln!(writer, "{}", formatted_str)?; + } else { + writeln!(writer, "{}", formatted)?; } - - Ok(formatted) } MaxRows::Unlimited => { - // maxrows not specified, print all rows - Ok(format!( - "{}", - pretty_format_batches_with_options(batches, &DEFAULT_FORMAT_OPTIONS)?, - )) + let formatted = + pretty_format_batches_with_options(batches, &DEFAULT_FORMAT_OPTIONS)?; + writeln!(writer, "{}", formatted)?; } } + + Ok(()) } impl PrintFormat { - /// print the batches to stdout using the specified format - /// `maxrows` option is only used for `Table` format: - /// If `maxrows` is Some(n), then at most n rows will be displayed - /// If `maxrows` is None, then every row will be displayed - pub fn print_batches(&self, batches: &[RecordBatch], maxrows: MaxRows) -> Result<()> { - if batches.is_empty() { + /// Print the batches to a writer using the specified format + pub fn print_batches( + &self, + writer: &mut W, + batches: &[RecordBatch], + maxrows: MaxRows, + with_header: bool, + ) -> Result<()> { + if batches.is_empty() || batches[0].num_rows() == 0 { return Ok(()); } match self { - Self::Csv => println!("{}", print_batches_with_sep(batches, b',')?), - Self::Tsv => println!("{}", print_batches_with_sep(batches, b'\t')?), + Self::Csv | Self::Automatic => { + print_batches_with_sep(writer, batches, b',', with_header) + } + Self::Tsv => print_batches_with_sep(writer, batches, b'\t', with_header), Self::Table => { if maxrows == MaxRows::Limited(0) { return Ok(()); } - println!("{}", format_batches_with_maxrows(batches, maxrows)?,) - } - Self::Json => println!("{}", batches_to_json!(ArrayWriter, batches)), - Self::NdJson => { - println!("{}", batches_to_json!(LineDelimitedWriter, batches)) + format_batches_with_maxrows(writer, batches, maxrows) } + Self::Json => batches_to_json!(ArrayWriter, writer, batches), + Self::NdJson => batches_to_json!(LineDelimitedWriter, writer, batches), } - Ok(()) } } #[cfg(test)] mod tests { + use std::io::{Cursor, Read, Write}; + use std::sync::Arc; + use super::*; + use arrow::array::Int32Array; use arrow::datatypes::{DataType, Field, Schema}; - use std::sync::Arc; + use datafusion::error::Result; + + fn run_test(batches: &[RecordBatch], test_fn: F) -> Result + where + F: Fn(&mut Cursor>, &[RecordBatch]) -> Result<()>, + { + let mut buffer = Cursor::new(Vec::new()); + test_fn(&mut buffer, batches)?; + buffer.set_position(0); + let mut contents = String::new(); + buffer.read_to_string(&mut contents)?; + Ok(contents) + } #[test] - fn test_print_batches_with_sep() { - let batches = vec![]; - assert_eq!("", print_batches_with_sep(&batches, b',').unwrap()); + fn test_print_batches_with_sep() -> Result<()> { + let contents = run_test(&[], |buffer, batches| { + print_batches_with_sep(buffer, batches, b',', true) + })?; + assert_eq!(contents, ""); let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, false), Field::new("b", DataType::Int32, false), Field::new("c", DataType::Int32, false), ])); - let batch = RecordBatch::try_new( schema, vec![ @@ -186,29 +224,33 @@ mod tests { Arc::new(Int32Array::from(vec![4, 5, 6])), Arc::new(Int32Array::from(vec![7, 8, 9])), ], - ) - .unwrap(); + )?; - let batches = vec![batch]; - let r = print_batches_with_sep(&batches, b',').unwrap(); - assert_eq!("a,b,c\n1,4,7\n2,5,8\n3,6,9\n", r); + let contents = run_test(&[batch], |buffer, batches| { + print_batches_with_sep(buffer, batches, b',', true) + })?; + assert_eq!(contents, "a,b,c\n1,4,7\n2,5,8\n3,6,9\n"); + + Ok(()) } #[test] fn test_print_batches_to_json_empty() -> Result<()> { - let batches = vec![]; - let r = batches_to_json!(ArrayWriter, &batches); - assert_eq!("", r); + let contents = run_test(&[], |buffer, batches| { + batches_to_json!(ArrayWriter, buffer, batches) + })?; + assert_eq!(contents, ""); - let r = batches_to_json!(LineDelimitedWriter, &batches); - assert_eq!("", r); + let contents = run_test(&[], |buffer, batches| { + batches_to_json!(LineDelimitedWriter, buffer, batches) + })?; + assert_eq!(contents, ""); let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, false), Field::new("b", DataType::Int32, false), Field::new("c", DataType::Int32, false), ])); - let batch = RecordBatch::try_new( schema, vec![ @@ -216,25 +258,29 @@ mod tests { Arc::new(Int32Array::from(vec![4, 5, 6])), Arc::new(Int32Array::from(vec![7, 8, 9])), ], - ) - .unwrap(); - + )?; let batches = vec![batch]; - let r = batches_to_json!(ArrayWriter, &batches); - assert_eq!("[{\"a\":1,\"b\":4,\"c\":7},{\"a\":2,\"b\":5,\"c\":8},{\"a\":3,\"b\":6,\"c\":9}]", r); - let r = batches_to_json!(LineDelimitedWriter, &batches); - assert_eq!("{\"a\":1,\"b\":4,\"c\":7}\n{\"a\":2,\"b\":5,\"c\":8}\n{\"a\":3,\"b\":6,\"c\":9}\n", r); + let contents = run_test(&batches, |buffer, batches| { + batches_to_json!(ArrayWriter, buffer, batches) + })?; + assert_eq!(contents, "[{\"a\":1,\"b\":4,\"c\":7},{\"a\":2,\"b\":5,\"c\":8},{\"a\":3,\"b\":6,\"c\":9}]\n"); + + let contents = run_test(&batches, |buffer, batches| { + batches_to_json!(LineDelimitedWriter, buffer, batches) + })?; + assert_eq!(contents, "{\"a\":1,\"b\":4,\"c\":7}\n{\"a\":2,\"b\":5,\"c\":8}\n{\"a\":3,\"b\":6,\"c\":9}\n"); + Ok(()) } #[test] fn test_format_batches_with_maxrows() -> Result<()> { let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); - - let batch = - RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![1, 2, 3]))]) - .unwrap(); + let batch = RecordBatch::try_new( + schema, + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + )?; #[rustfmt::skip] let all_rows_expected = [ @@ -244,7 +290,7 @@ mod tests { "| 1 |", "| 2 |", "| 3 |", - "+---+", + "+---+\n", ].join("\n"); #[rustfmt::skip] @@ -256,7 +302,7 @@ mod tests { "| . |", "| . |", "| . |", - "+---+", + "+---+\n", ].join("\n"); #[rustfmt::skip] @@ -272,26 +318,36 @@ mod tests { "| . |", "| . |", "| . |", - "+---+", + "+---+\n", ].join("\n"); - let no_limit = format_batches_with_maxrows(&[batch.clone()], MaxRows::Unlimited)?; - assert_eq!(all_rows_expected, no_limit); - - let maxrows_less_than_actual = - format_batches_with_maxrows(&[batch.clone()], MaxRows::Limited(1))?; - assert_eq!(one_row_expected, maxrows_less_than_actual); - let maxrows_more_than_actual = - format_batches_with_maxrows(&[batch.clone()], MaxRows::Limited(5))?; - assert_eq!(all_rows_expected, maxrows_more_than_actual); - let maxrows_equals_actual = - format_batches_with_maxrows(&[batch.clone()], MaxRows::Limited(3))?; - assert_eq!(all_rows_expected, maxrows_equals_actual); - let multi_batches = format_batches_with_maxrows( + let no_limit = run_test(&[batch.clone()], |buffer, batches| { + format_batches_with_maxrows(buffer, batches, MaxRows::Unlimited) + })?; + assert_eq!(no_limit, all_rows_expected); + + let maxrows_less_than_actual = run_test(&[batch.clone()], |buffer, batches| { + format_batches_with_maxrows(buffer, batches, MaxRows::Limited(1)) + })?; + assert_eq!(maxrows_less_than_actual, one_row_expected); + + let maxrows_more_than_actual = run_test(&[batch.clone()], |buffer, batches| { + format_batches_with_maxrows(buffer, batches, MaxRows::Limited(5)) + })?; + assert_eq!(maxrows_more_than_actual, all_rows_expected); + + let maxrows_equals_actual = run_test(&[batch.clone()], |buffer, batches| { + format_batches_with_maxrows(buffer, batches, MaxRows::Limited(3)) + })?; + assert_eq!(maxrows_equals_actual, all_rows_expected); + + let multi_batches = run_test( &[batch.clone(), batch.clone(), batch.clone()], - MaxRows::Limited(5), + |buffer, batches| { + format_batches_with_maxrows(buffer, batches, MaxRows::Limited(5)) + }, )?; - assert_eq!(multi_batches_expected, multi_batches); + assert_eq!(multi_batches, multi_batches_expected); Ok(()) } diff --git a/datafusion-cli/src/print_options.rs b/datafusion-cli/src/print_options.rs index 0a6c8d4c36fc..b382eb34f62c 100644 --- a/datafusion-cli/src/print_options.rs +++ b/datafusion-cli/src/print_options.rs @@ -15,13 +15,21 @@ // specific language governing permissions and limitations // under the License. -use crate::print_format::PrintFormat; -use datafusion::arrow::record_batch::RecordBatch; -use datafusion::error::Result; use std::fmt::{Display, Formatter}; +use std::io::Write; +use std::pin::Pin; use std::str::FromStr; use std::time::Instant; +use crate::print_format::PrintFormat; + +use arrow::record_batch::RecordBatch; +use datafusion::common::DataFusionError; +use datafusion::error::Result; +use datafusion::physical_plan::RecordBatchStream; + +use futures::StreamExt; + #[derive(Debug, Clone, PartialEq, Copy)] pub enum MaxRows { /// show all rows in the output @@ -85,20 +93,71 @@ fn get_timing_info_str( } impl PrintOptions { - /// print the batches to stdout using the specified format + /// Print the batches to stdout using the specified format pub fn print_batches( &self, batches: &[RecordBatch], query_start_time: Instant, ) -> Result<()> { + let stdout = std::io::stdout(); + let mut writer = stdout.lock(); + + self.format + .print_batches(&mut writer, batches, self.maxrows, true)?; + let row_count: usize = batches.iter().map(|b| b.num_rows()).sum(); - // Elapsed time should not count time for printing batches - let timing_info = get_timing_info_str(row_count, self.maxrows, query_start_time); + let timing_info = get_timing_info_str( + row_count, + if self.format == PrintFormat::Table { + self.maxrows + } else { + MaxRows::Unlimited + }, + query_start_time, + ); + + if !self.quiet { + writeln!(writer, "{timing_info}")?; + } + + Ok(()) + } + + /// Print the stream to stdout using the specified format + pub async fn print_stream( + &self, + mut stream: Pin>, + query_start_time: Instant, + ) -> Result<()> { + if self.format == PrintFormat::Table { + return Err(DataFusionError::External( + "PrintFormat::Table is not implemented".to_string().into(), + )); + }; + + let stdout = std::io::stdout(); + let mut writer = stdout.lock(); + + let mut row_count = 0_usize; + let mut with_header = true; + + while let Some(maybe_batch) = stream.next().await { + let batch = maybe_batch?; + row_count += batch.num_rows(); + self.format.print_batches( + &mut writer, + &[batch], + MaxRows::Unlimited, + with_header, + )?; + with_header = false; + } - self.format.print_batches(batches, self.maxrows)?; + let timing_info = + get_timing_info_str(row_count, MaxRows::Unlimited, query_start_time); if !self.quiet { - println!("{timing_info}"); + writeln!(writer, "{timing_info}")?; } Ok(()) diff --git a/datafusion-cli/tests/cli_integration.rs b/datafusion-cli/tests/cli_integration.rs index 28344ffa94f8..119a0aa39d3c 100644 --- a/datafusion-cli/tests/cli_integration.rs +++ b/datafusion-cli/tests/cli_integration.rs @@ -43,7 +43,7 @@ fn init() { )] #[case::set_batch_size( ["--command", "show datafusion.execution.batch_size", "--format", "json", "-q", "-b", "1"], - "[{\"name\":\"datafusion.execution.batch_size\",\"setting\":\"1\"}]\n" + "[{\"name\":\"datafusion.execution.batch_size\",\"value\":\"1\"}]\n" )] #[test] fn cli_quick_test<'a>( diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index e5146c7fd94e..59580bcb6a05 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -20,9 +20,9 @@ name = "datafusion-examples" description = "DataFusion usage examples" keywords = ["arrow", "query", "sql"] publish = false +readme = "README.md" version = { workspace = true } edition = { workspace = true } -readme = { workspace = true } homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } @@ -33,26 +33,26 @@ rust-version = { workspace = true } arrow = { workspace = true } arrow-flight = { workspace = true } arrow-schema = { workspace = true } -async-trait = "0.1.41" -bytes = "1.4" -dashmap = "5.4" -datafusion = { path = "../datafusion/core" } +async-trait = { workspace = true } +bytes = { workspace = true } +dashmap = { workspace = true } +datafusion = { path = "../datafusion/core", features = ["avro"] } datafusion-common = { path = "../datafusion/common" } datafusion-expr = { path = "../datafusion/expr" } datafusion-optimizer = { path = "../datafusion/optimizer" } datafusion-sql = { path = "../datafusion/sql" } -env_logger = "0.10" -futures = "0.3" -log = "0.4" +env_logger = { workspace = true } +futures = { workspace = true } +log = { workspace = true } mimalloc = { version = "0.1", default-features = false } -num_cpus = "1.13.0" -object_store = { version = "0.7.0", features = ["aws", "http"] } +num_cpus = { workspace = true } +object_store = { workspace = true, features = ["aws", "http"] } prost = { version = "0.12", default-features = false } -prost-derive = { version = "0.11", default-features = false } +prost-derive = { version = "0.12", default-features = false } serde = { version = "1.0.136", features = ["derive"] } -serde_json = "1.0.82" -tempfile = "3" +serde_json = { workspace = true } +tempfile = { workspace = true } tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot"] } tonic = "0.10" -url = "2.2" +url = { workspace = true } uuid = "1.2" diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index bfed3976c946..aae451add9e7 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -44,22 +44,27 @@ cargo run --example csv_sql - [`avro_sql.rs`](examples/avro_sql.rs): Build and run a query plan from a SQL statement against a local AVRO file - [`csv_sql.rs`](examples/csv_sql.rs): Build and run a query plan from a SQL statement against a local CSV file +- [`catalog.rs`](examples/external_dependency/catalog.rs): Register the table into a custom catalog - [`custom_datasource.rs`](examples/custom_datasource.rs): Run queries against a custom datasource (TableProvider) - [`dataframe.rs`](examples/dataframe.rs): Run a query using a DataFrame against a local parquet file +- [`dataframe-to-s3.rs`](examples/external_dependency/dataframe-to-s3.rs): Run a query using a DataFrame against a parquet file from s3 and writing back to s3 +- [`dataframe_output.rs`](examples/dataframe_output.rs): Examples of methods which write data out from a DataFrame - [`dataframe_in_memory.rs`](examples/dataframe_in_memory.rs): Run a query using a DataFrame against data in memory - [`deserialize_to_struct.rs`](examples/deserialize_to_struct.rs): Convert query results into rust structs using serde -- [`expr_api.rs`](examples/expr_api.rs): Use the `Expr` construction and simplification API -- [`flight_sql_server.rs`](examples/flight_sql_server.rs): Run DataFusion as a standalone process and execute SQL queries from JDBC clients +- [`expr_api.rs`](examples/expr_api.rs): Create, execute, simplify and anaylze `Expr`s +- [`flight_sql_server.rs`](examples/flight/flight_sql_server.rs): Run DataFusion as a standalone process and execute SQL queries from JDBC clients - [`memtable.rs`](examples/memtable.rs): Create an query data in memory using SQL and `RecordBatch`es - [`parquet_sql.rs`](examples/parquet_sql.rs): Build and run a query plan from a SQL statement against a local Parquet file - [`parquet_sql_multiple_files.rs`](examples/parquet_sql_multiple_files.rs): Build and run a query plan from a SQL statement against multiple local Parquet files -- [`query-aws-s3.rs`](examples/query-aws-s3.rs): Configure `object_store` and run a query against files stored in AWS S3 +- [`query-aws-s3.rs`](examples/external_dependency/query-aws-s3.rs): Configure `object_store` and run a query against files stored in AWS S3 - [`query-http-csv.rs`](examples/query-http-csv.rs): Configure `object_store` and run a query against files vi HTTP - [`rewrite_expr.rs`](examples/rewrite_expr.rs): Define and invoke a custom Query Optimizer pass +- [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined Scalar Function (UDF) +- [`advanced_udf.rs`](examples/advanced_udf.rs): Define and invoke a more complicated User Defined Scalar Function (UDF) - [`simple_udaf.rs`](examples/simple_udaf.rs): Define and invoke a User Defined Aggregate Function (UDAF) -- [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined (scalar) Function (UDF) - [`simple_udfw.rs`](examples/simple_udwf.rs): Define and invoke a User Defined Window Function (UDWF) +- [`advanced_udwf.rs`](examples/advanced_udwf.rs): Define and invoke a more complicated User Defined Window Function (UDWF) ## Distributed -- [`flight_client.rs`](examples/flight_client.rs) and [`flight_server.rs`](examples/flight_server.rs): Run DataFusion as a standalone process and execute SQL queries from a client using the Flight protocol. +- [`flight_client.rs`](examples/flight/flight_client.rs) and [`flight_server.rs`](examples/flight/flight_server.rs): Run DataFusion as a standalone process and execute SQL queries from a client using the Flight protocol. diff --git a/datafusion-examples/examples/advanced_udf.rs b/datafusion-examples/examples/advanced_udf.rs new file mode 100644 index 000000000000..d530b9abe030 --- /dev/null +++ b/datafusion-examples/examples/advanced_udf.rs @@ -0,0 +1,244 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::{ + arrow::{ + array::{ArrayRef, Float32Array, Float64Array}, + datatypes::DataType, + record_batch::RecordBatch, + }, + logical_expr::Volatility, +}; +use std::any::Any; + +use arrow::array::{new_null_array, Array, AsArray}; +use arrow::compute; +use arrow::datatypes::Float64Type; +use datafusion::error::Result; +use datafusion::prelude::*; +use datafusion_common::{internal_err, ScalarValue}; +use datafusion_expr::{ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature}; +use std::sync::Arc; + +/// This example shows how to use the full ScalarUDFImpl API to implement a user +/// defined function. As in the `simple_udf.rs` example, this struct implements +/// a function that takes two arguments and returns the first argument raised to +/// the power of the second argument `a^b`. +/// +/// To do so, we must implement the `ScalarUDFImpl` trait. +#[derive(Debug, Clone)] +struct PowUdf { + signature: Signature, + aliases: Vec, +} + +impl PowUdf { + /// Create a new instance of the `PowUdf` struct + fn new() -> Self { + Self { + signature: Signature::exact( + // this function will always take two arguments of type f64 + vec![DataType::Float64, DataType::Float64], + // this function is deterministic and will always return the same + // result for the same input + Volatility::Immutable, + ), + // we will also add an alias of "my_pow" + aliases: vec!["my_pow".to_string()], + } + } +} + +impl ScalarUDFImpl for PowUdf { + /// We implement as_any so that we can downcast the ScalarUDFImpl trait object + fn as_any(&self) -> &dyn Any { + self + } + + /// Return the name of this function + fn name(&self) -> &str { + "pow" + } + + /// Return the "signature" of this function -- namely what types of arguments it will take + fn signature(&self) -> &Signature { + &self.signature + } + + /// What is the type of value that will be returned by this function? In + /// this case it will always be a constant value, but it could also be a + /// function of the input types. + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + /// This is the function that actually calculates the results. + /// + /// This is the same way that functions built into DataFusion are invoked, + /// which permits important special cases when one or both of the arguments + /// are single values (constants). For example `pow(a, 2)` + /// + /// However, it also means the implementation is more complex than when + /// using `create_udf`. + fn invoke(&self, args: &[ColumnarValue]) -> Result { + // DataFusion has arranged for the correct inputs to be passed to this + // function, but we check again to make sure + assert_eq!(args.len(), 2); + let (base, exp) = (&args[0], &args[1]); + assert_eq!(base.data_type(), DataType::Float64); + assert_eq!(exp.data_type(), DataType::Float64); + + match (base, exp) { + // For demonstration purposes we also implement the scalar / scalar + // case here, but it is not typically required for high performance. + // + // For performance it is most important to optimize cases where at + // least one argument is an array. If all arguments are constants, + // the DataFusion expression simplification logic will often invoke + // this path once during planning, and simply use the result during + // execution. + ( + ColumnarValue::Scalar(ScalarValue::Float64(base)), + ColumnarValue::Scalar(ScalarValue::Float64(exp)), + ) => { + // compute the output. Note DataFusion treats `None` as NULL. + let res = match (base, exp) { + (Some(base), Some(exp)) => Some(base.powf(*exp)), + // one or both arguments were NULL + _ => None, + }; + Ok(ColumnarValue::Scalar(ScalarValue::from(res))) + } + // special case if the exponent is a constant + ( + ColumnarValue::Array(base_array), + ColumnarValue::Scalar(ScalarValue::Float64(exp)), + ) => { + let result_array = match exp { + // a ^ null = null + None => new_null_array(base_array.data_type(), base_array.len()), + // a ^ exp + Some(exp) => { + // DataFusion has ensured both arguments are Float64: + let base_array = base_array.as_primitive::(); + // calculate the result for every row. The `unary` + // kernel creates very fast "vectorized" code and + // handles things like null values for us. + let res: Float64Array = + compute::unary(base_array, |base| base.powf(*exp)); + Arc::new(res) + } + }; + Ok(ColumnarValue::Array(result_array)) + } + + // special case if the base is a constant (note this code is quite + // similar to the previous case, so we omit comments) + ( + ColumnarValue::Scalar(ScalarValue::Float64(base)), + ColumnarValue::Array(exp_array), + ) => { + let res = match base { + None => new_null_array(exp_array.data_type(), exp_array.len()), + Some(base) => { + let exp_array = exp_array.as_primitive::(); + let res: Float64Array = + compute::unary(exp_array, |exp| base.powf(exp)); + Arc::new(res) + } + }; + Ok(ColumnarValue::Array(res)) + } + // Both arguments are arrays so we have to perform the calculation for every row + (ColumnarValue::Array(base_array), ColumnarValue::Array(exp_array)) => { + let res: Float64Array = compute::binary( + base_array.as_primitive::(), + exp_array.as_primitive::(), + |base, exp| base.powf(exp), + )?; + Ok(ColumnarValue::Array(Arc::new(res))) + } + // if the types were not float, it is a bug in DataFusion + _ => { + use datafusion_common::DataFusionError; + internal_err!("Invalid argument types to pow function") + } + } + } + + /// We will also add an alias of "my_pow" + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// In this example we register `PowUdf` as a user defined function +/// and invoke it via the DataFrame API and SQL +#[tokio::main] +async fn main() -> Result<()> { + let ctx = create_context()?; + + // create the UDF + let pow = ScalarUDF::from(PowUdf::new()); + + // register the UDF with the context so it can be invoked by name and from SQL + ctx.register_udf(pow.clone()); + + // get a DataFrame from the context for scanning the "t" table + let df = ctx.table("t").await?; + + // Call pow(a, 10) using the DataFrame API + let df = df.select(vec![pow.call(vec![col("a"), lit(10i32)])])?; + + // note that the second argument is passed as an i32, not f64. DataFusion + // automatically coerces the types to match the UDF's defined signature. + + // print the results + df.show().await?; + + // You can also invoke both pow(2, 10) and its alias my_pow(a, b) using SQL + let sql_df = ctx.sql("SELECT pow(2, 10), my_pow(a, b) FROM t").await?; + sql_df.show().await?; + + Ok(()) +} + +/// create local execution context with an in-memory table: +/// +/// ```text +/// +-----+-----+ +/// | a | b | +/// +-----+-----+ +/// | 2.1 | 1.0 | +/// | 3.1 | 2.0 | +/// | 4.1 | 3.0 | +/// | 5.1 | 4.0 | +/// +-----+-----+ +/// ``` +fn create_context() -> Result { + // define data. + let a: ArrayRef = Arc::new(Float32Array::from(vec![2.1, 3.1, 4.1, 5.1])); + let b: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])); + let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)])?; + + // declare a new context. In Spark API, this corresponds to a new SparkSession + let ctx = SessionContext::new(); + + // declare a table in memory. In Spark API, this corresponds to createDataFrame(...). + ctx.register_batch("t", batch)?; + Ok(ctx) +} diff --git a/datafusion-examples/examples/advanced_udwf.rs b/datafusion-examples/examples/advanced_udwf.rs new file mode 100644 index 000000000000..f46031434fc9 --- /dev/null +++ b/datafusion-examples/examples/advanced_udwf.rs @@ -0,0 +1,231 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; +use std::any::Any; + +use arrow::{ + array::{ArrayRef, AsArray, Float64Array}, + datatypes::Float64Type, +}; +use datafusion::error::Result; +use datafusion::prelude::*; +use datafusion_common::ScalarValue; +use datafusion_expr::{ + PartitionEvaluator, Signature, WindowFrame, WindowUDF, WindowUDFImpl, +}; + +/// This example shows how to use the full WindowUDFImpl API to implement a user +/// defined window function. As in the `simple_udwf.rs` example, this struct implements +/// a function `partition_evaluator` that returns the `MyPartitionEvaluator` instance. +/// +/// To do so, we must implement the `WindowUDFImpl` trait. +#[derive(Debug, Clone)] +struct SmoothItUdf { + signature: Signature, +} + +impl SmoothItUdf { + /// Create a new instance of the SmoothItUdf struct + fn new() -> Self { + Self { + signature: Signature::exact( + // this function will always take one arguments of type f64 + vec![DataType::Float64], + // this function is deterministic and will always return the same + // result for the same input + Volatility::Immutable, + ), + } + } +} + +impl WindowUDFImpl for SmoothItUdf { + /// We implement as_any so that we can downcast the WindowUDFImpl trait object + fn as_any(&self) -> &dyn Any { + self + } + + /// Return the name of this function + fn name(&self) -> &str { + "smooth_it" + } + + /// Return the "signature" of this function -- namely that types of arguments it will take + fn signature(&self) -> &Signature { + &self.signature + } + + /// What is the type of value that will be returned by this function. + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + /// Create a `PartitionEvalutor` to evaluate this function on a new + /// partition. + fn partition_evaluator(&self) -> Result> { + Ok(Box::new(MyPartitionEvaluator::new())) + } +} + +/// This implements the lowest level evaluation for a window function +/// +/// It handles calculating the value of the window function for each +/// distinct values of `PARTITION BY` (each car type in our example) +#[derive(Clone, Debug)] +struct MyPartitionEvaluator {} + +impl MyPartitionEvaluator { + fn new() -> Self { + Self {} + } +} + +/// Different evaluation methods are called depending on the various +/// settings of WindowUDF. This example uses the simplest and most +/// general, `evaluate`. See `PartitionEvaluator` for the other more +/// advanced uses. +impl PartitionEvaluator for MyPartitionEvaluator { + /// Tell DataFusion the window function varies based on the value + /// of the window frame. + fn uses_window_frame(&self) -> bool { + true + } + + /// This function is called once per input row. + /// + /// `range`specifies which indexes of `values` should be + /// considered for the calculation. + /// + /// Note this is the SLOWEST, but simplest, way to evaluate a + /// window function. It is much faster to implement + /// evaluate_all or evaluate_all_with_rank, if possible + fn evaluate( + &mut self, + values: &[ArrayRef], + range: &std::ops::Range, + ) -> Result { + // Again, the input argument is an array of floating + // point numbers to calculate a moving average + let arr: &Float64Array = values[0].as_ref().as_primitive::(); + + let range_len = range.end - range.start; + + // our smoothing function will average all the values in the + let output = if range_len > 0 { + let sum: f64 = arr.values().iter().skip(range.start).take(range_len).sum(); + Some(sum / range_len as f64) + } else { + None + }; + + Ok(ScalarValue::Float64(output)) + } +} + +// create local execution context with `cars.csv` registered as a table named `cars` +async fn create_context() -> Result { + // declare a new context. In spark API, this corresponds to a new spark SQL session + let ctx = SessionContext::new(); + + // declare a table in memory. In spark API, this corresponds to createDataFrame(...). + println!("pwd: {}", std::env::current_dir().unwrap().display()); + let csv_path = "../../datafusion/core/tests/data/cars.csv".to_string(); + let read_options = CsvReadOptions::default().has_header(true); + + ctx.register_csv("cars", &csv_path, read_options).await?; + Ok(ctx) +} + +#[tokio::main] +async fn main() -> Result<()> { + let ctx = create_context().await?; + let smooth_it = WindowUDF::from(SmoothItUdf::new()); + ctx.register_udwf(smooth_it.clone()); + + // Use SQL to run the new window function + let df = ctx.sql("SELECT * from cars").await?; + // print the results + df.show().await?; + + // Use SQL to run the new window function: + // + // `PARTITION BY car`:each distinct value of car (red, and green) + // should be treated as a separate partition (and will result in + // creating a new `PartitionEvaluator`) + // + // `ORDER BY time`: within each partition ('green' or 'red') the + // rows will be be ordered by the value in the `time` column + // + // `evaluate_inside_range` is invoked with a window defined by the + // SQL. In this case: + // + // The first invocation will be passed row 0, the first row in the + // partition. + // + // The second invocation will be passed rows 0 and 1, the first + // two rows in the partition. + // + // etc. + let df = ctx + .sql( + "SELECT \ + car, \ + speed, \ + smooth_it(speed) OVER (PARTITION BY car ORDER BY time) AS smooth_speed,\ + time \ + from cars \ + ORDER BY \ + car", + ) + .await?; + // print the results + df.show().await?; + + // this time, call the new widow function with an explicit + // window so evaluate will be invoked with each window. + // + // `ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING`: each invocation + // sees at most 3 rows: the row before, the current row, and the 1 + // row afterward. + let df = ctx.sql( + "SELECT \ + car, \ + speed, \ + smooth_it(speed) OVER (PARTITION BY car ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS smooth_speed,\ + time \ + from cars \ + ORDER BY \ + car", + ).await?; + // print the results + df.show().await?; + + // Now, run the function using the DataFrame API: + let window_expr = smooth_it.call( + vec![col("speed")], // smooth_it(speed) + vec![col("car")], // PARTITION BY car + vec![col("time").sort(true, true)], // ORDER BY time ASC + WindowFrame::new(false), + ); + let df = ctx.table("cars").await?.window(vec![window_expr])?; + + // print the results + df.show().await?; + + Ok(()) +} diff --git a/datafusion-examples/examples/csv_opener.rs b/datafusion-examples/examples/csv_opener.rs index 5126666d5e73..96753c8c5260 100644 --- a/datafusion-examples/examples/csv_opener.rs +++ b/datafusion-examples/examples/csv_opener.rs @@ -17,6 +17,7 @@ use std::{sync::Arc, vec}; +use datafusion::common::Statistics; use datafusion::{ assert_batches_eq, datasource::{ @@ -29,6 +30,7 @@ use datafusion::{ physical_plan::metrics::ExecutionPlanMetricsSet, test_util::aggr_test_schema, }; + use futures::StreamExt; use object_store::local::LocalFileSystem; @@ -60,12 +62,11 @@ async fn main() -> Result<()> { object_store_url: ObjectStoreUrl::local_filesystem(), file_schema: schema.clone(), file_groups: vec![vec![PartitionedFile::new(path.display().to_string(), 10)]], - statistics: Default::default(), + statistics: Statistics::new_unknown(&schema), projection: Some(vec![12, 0]), limit: Some(5), table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }; let result = diff --git a/datafusion-examples/examples/custom_datasource.rs b/datafusion-examples/examples/custom_datasource.rs index a24573c860bb..69f9c9530e87 100644 --- a/datafusion-examples/examples/custom_datasource.rs +++ b/datafusion-examples/examples/custom_datasource.rs @@ -15,28 +15,29 @@ // specific language governing permissions and limitations // under the License. -use async_trait::async_trait; +use std::any::Any; +use std::collections::{BTreeMap, HashMap}; +use std::fmt::{self, Debug, Formatter}; +use std::sync::{Arc, Mutex}; +use std::time::Duration; + use datafusion::arrow::array::{UInt64Builder, UInt8Builder}; use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::arrow::record_batch::RecordBatch; use datafusion::dataframe::DataFrame; -use datafusion::datasource::provider_as_source; -use datafusion::datasource::{TableProvider, TableType}; +use datafusion::datasource::{provider_as_source, TableProvider, TableType}; use datafusion::error::Result; use datafusion::execution::context::{SessionState, TaskContext}; use datafusion::physical_plan::expressions::PhysicalSortExpr; use datafusion::physical_plan::memory::MemoryStream; use datafusion::physical_plan::{ project_schema, DisplayAs, DisplayFormatType, ExecutionPlan, - SendableRecordBatchStream, Statistics, + SendableRecordBatchStream, }; use datafusion::prelude::*; use datafusion_expr::{Expr, LogicalPlanBuilder}; -use std::any::Any; -use std::collections::{BTreeMap, HashMap}; -use std::fmt::{self, Debug, Formatter}; -use std::sync::{Arc, Mutex}; -use std::time::Duration; + +use async_trait::async_trait; use tokio::time::timeout; /// This example demonstrates executing a simple query against a custom datasource @@ -79,7 +80,7 @@ async fn search_accounts( timeout(Duration::from_secs(10), async move { let result = dataframe.collect().await.unwrap(); - let record_batch = result.get(0).unwrap(); + let record_batch = result.first().unwrap(); assert_eq!(expected_result_length, record_batch.column(1).len()); dbg!(record_batch.columns()); @@ -269,8 +270,4 @@ impl ExecutionPlan for CustomExec { None, )?)) } - - fn statistics(&self) -> Statistics { - Statistics::default() - } } diff --git a/datafusion-examples/examples/dataframe.rs b/datafusion-examples/examples/dataframe.rs index 26fddcd226a9..ea01c53b1c62 100644 --- a/datafusion-examples/examples/dataframe.rs +++ b/datafusion-examples/examples/dataframe.rs @@ -18,7 +18,9 @@ use datafusion::arrow::datatypes::{DataType, Field, Schema}; use datafusion::error::Result; use datafusion::prelude::*; -use std::fs; +use std::fs::File; +use std::io::Write; +use tempfile::tempdir; /// This example demonstrates executing a simple query against an Arrow data source (Parquet) and /// fetching results, using the DataFrame trait @@ -41,12 +43,19 @@ async fn main() -> Result<()> { // print the results df.show().await?; + // create a csv file waiting to be written + let dir = tempdir()?; + let file_path = dir.path().join("example.csv"); + let file = File::create(&file_path)?; + write_csv_file(file); + // Reading CSV file with inferred schema example - let csv_df = example_read_csv_file_with_inferred_schema().await; + let csv_df = + example_read_csv_file_with_inferred_schema(file_path.to_str().unwrap()).await; csv_df.show().await?; // Reading CSV file with defined schema - let csv_df = example_read_csv_file_with_schema().await; + let csv_df = example_read_csv_file_with_schema(file_path.to_str().unwrap()).await; csv_df.show().await?; // Reading PARQUET file and print describe @@ -59,31 +68,28 @@ async fn main() -> Result<()> { } // Function to create an test CSV file -fn create_csv_file(path: String) { +fn write_csv_file(mut file: File) { // Create the data to put into the csv file with headers let content = r#"id,time,vote,unixtime,rating a1,"10 6, 2013",3,1381017600,5.0 a2,"08 9, 2013",2,1376006400,4.5"#; // write the data - fs::write(path, content).expect("Problem with writing file!"); + file.write_all(content.as_ref()) + .expect("Problem with writing file!"); } // Example to read data from a csv file with inferred schema -async fn example_read_csv_file_with_inferred_schema() -> DataFrame { - let path = "example.csv"; - // Create a csv file using the predefined function - create_csv_file(path.to_string()); +async fn example_read_csv_file_with_inferred_schema(file_path: &str) -> DataFrame { // Create a session context let ctx = SessionContext::new(); // Register a lazy DataFrame using the context - ctx.read_csv(path, CsvReadOptions::default()).await.unwrap() + ctx.read_csv(file_path, CsvReadOptions::default()) + .await + .unwrap() } // Example to read csv file with a defined schema for the csv file -async fn example_read_csv_file_with_schema() -> DataFrame { - let path = "example.csv"; - // Create a csv file using the predefined function - create_csv_file(path.to_string()); +async fn example_read_csv_file_with_schema(file_path: &str) -> DataFrame { // Create a session context let ctx = SessionContext::new(); // Define the schema @@ -101,5 +107,5 @@ async fn example_read_csv_file_with_schema() -> DataFrame { ..Default::default() }; // Register a lazy DataFrame by using the context and option provider - ctx.read_csv(path, csv_read_option).await.unwrap() + ctx.read_csv(file_path, csv_read_option).await.unwrap() } diff --git a/datafusion-examples/examples/dataframe_output.rs b/datafusion-examples/examples/dataframe_output.rs new file mode 100644 index 000000000000..c773384dfcd5 --- /dev/null +++ b/datafusion-examples/examples/dataframe_output.rs @@ -0,0 +1,76 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::{dataframe::DataFrameWriteOptions, prelude::*}; +use datafusion_common::{parsers::CompressionTypeVariant, DataFusionError}; + +/// This example demonstrates the various methods to write out a DataFrame to local storage. +/// See datafusion-examples/examples/external_dependency/dataframe-to-s3.rs for an example +/// using a remote object store. +#[tokio::main] +async fn main() -> Result<(), DataFusionError> { + let ctx = SessionContext::new(); + + let mut df = ctx.sql("values ('a'), ('b'), ('c')").await.unwrap(); + + // Ensure the column names and types match the target table + df = df.with_column_renamed("column1", "tablecol1").unwrap(); + + ctx.sql( + "create external table + test(tablecol1 varchar) + stored as parquet + location './datafusion-examples/test_table/'", + ) + .await? + .collect() + .await?; + + // This is equivalent to INSERT INTO test VALUES ('a'), ('b'), ('c'). + // The behavior of write_table depends on the TableProvider's implementation + // of the insert_into method. + df.clone() + .write_table("test", DataFrameWriteOptions::new()) + .await?; + + df.clone() + .write_parquet( + "./datafusion-examples/test_parquet/", + DataFrameWriteOptions::new(), + None, + ) + .await?; + + df.clone() + .write_csv( + "./datafusion-examples/test_csv/", + // DataFrameWriteOptions contains options which control how data is written + // such as compression codec + DataFrameWriteOptions::new().with_compression(CompressionTypeVariant::GZIP), + None, + ) + .await?; + + df.clone() + .write_json( + "./datafusion-examples/test_json/", + DataFrameWriteOptions::new(), + ) + .await?; + + Ok(()) +} diff --git a/datafusion-examples/examples/dataframe_subquery.rs b/datafusion-examples/examples/dataframe_subquery.rs index 94049e59b3ab..9fb61008b9f6 100644 --- a/datafusion-examples/examples/dataframe_subquery.rs +++ b/datafusion-examples/examples/dataframe_subquery.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use arrow_schema::DataType; use std::sync::Arc; use datafusion::error::Result; @@ -38,7 +39,7 @@ async fn main() -> Result<()> { Ok(()) } -//select c1,c2 from t1 where (select avg(t2.c2) from t2 where t1.c1 = t2.c1)>0 limit 10; +//select c1,c2 from t1 where (select avg(t2.c2) from t2 where t1.c1 = t2.c1)>0 limit 3; async fn where_scalar_subquery(ctx: &SessionContext) -> Result<()> { ctx.table("t1") .await? @@ -46,7 +47,7 @@ async fn where_scalar_subquery(ctx: &SessionContext) -> Result<()> { scalar_subquery(Arc::new( ctx.table("t2") .await? - .filter(col("t1.c1").eq(col("t2.c1")))? + .filter(out_ref_col(DataType::Utf8, "t1.c1").eq(col("t2.c1")))? .aggregate(vec![], vec![avg(col("t2.c2"))])? .select(vec![avg(col("t2.c2"))])? .into_unoptimized_plan(), @@ -60,7 +61,7 @@ async fn where_scalar_subquery(ctx: &SessionContext) -> Result<()> { Ok(()) } -//SELECT t1.c1, t1.c2 FROM t1 WHERE t1.c2 in (select max(t2.c2) from t2 where t2.c1 > 0 ) limit 10 +//SELECT t1.c1, t1.c2 FROM t1 WHERE t1.c2 in (select max(t2.c2) from t2 where t2.c1 > 0 ) limit 3; async fn where_in_subquery(ctx: &SessionContext) -> Result<()> { ctx.table("t1") .await? @@ -82,14 +83,14 @@ async fn where_in_subquery(ctx: &SessionContext) -> Result<()> { Ok(()) } -//SELECT t1.c1, t1.c2 FROM t1 WHERE EXISTS (select t2.c2 from t2 where t1.c1 = t2.c1) limit 10 +//SELECT t1.c1, t1.c2 FROM t1 WHERE EXISTS (select t2.c2 from t2 where t1.c1 = t2.c1) limit 3; async fn where_exist_subquery(ctx: &SessionContext) -> Result<()> { ctx.table("t1") .await? .filter(exists(Arc::new( ctx.table("t2") .await? - .filter(col("t1.c1").eq(col("t2.c1")))? + .filter(out_ref_col(DataType::Utf8, "t1.c1").eq(col("t2.c1")))? .select(vec![col("t2.c2")])? .into_unoptimized_plan(), )))? diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index 97abf4d552a9..715e1ff2dce6 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -15,28 +15,43 @@ // specific language governing permissions and limitations // under the License. +use arrow::array::{BooleanArray, Int32Array}; +use arrow::record_batch::RecordBatch; use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use datafusion::error::Result; use datafusion::optimizer::simplify_expressions::{ExprSimplifier, SimplifyContext}; use datafusion::physical_expr::execution_props::ExecutionProps; +use datafusion::physical_expr::{ + analyze, create_physical_expr, AnalysisContext, ExprBoundaries, PhysicalExpr, +}; use datafusion::prelude::*; use datafusion_common::{ScalarValue, ToDFSchema}; use datafusion_expr::expr::BinaryExpr; -use datafusion_expr::Operator; +use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::{ColumnarValue, ExprSchemable, Operator}; +use std::sync::Arc; /// This example demonstrates the DataFusion [`Expr`] API. /// /// DataFusion comes with a powerful and extensive system for /// representing and manipulating expressions such as `A + 5` and `X -/// IN ('foo', 'bar', 'baz')` and many other constructs. +/// IN ('foo', 'bar', 'baz')`. +/// +/// In addition to building and manipulating [`Expr`]s, DataFusion +/// also comes with APIs for evaluation, simplification, and analysis. +/// +/// The code in this example shows how to: +/// 1. Create [`Exprs`] using different APIs: [`main`]` +/// 2. Evaluate [`Exprs`] against data: [`evaluate_demo`] +/// 3. Simplify expressions: [`simplify_demo`] +/// 4. Analyze predicates for boundary ranges: [`range_analysis_demo`] #[tokio::main] async fn main() -> Result<()> { // The easiest way to do create expressions is to use the - // "fluent"-style API, like this: + // "fluent"-style API: let expr = col("a") + lit(5); - // this creates the same expression as the following though with - // much less code, + // The same same expression can be created directly, with much more code: let expr2 = Expr::BinaryExpr(BinaryExpr::new( Box::new(col("a")), Operator::Plus, @@ -44,15 +59,51 @@ async fn main() -> Result<()> { )); assert_eq!(expr, expr2); + // See how to evaluate expressions + evaluate_demo()?; + + // See how to simplify expressions simplify_demo()?; + // See how to analyze ranges in expressions + range_analysis_demo()?; + + Ok(()) +} + +/// DataFusion can also evaluate arbitrary expressions on Arrow arrays. +fn evaluate_demo() -> Result<()> { + // For example, let's say you have some integers in an array + let batch = RecordBatch::try_from_iter([( + "a", + Arc::new(Int32Array::from(vec![4, 5, 6, 7, 8, 7, 4])) as _, + )])?; + + // If you want to find all rows where the expression `a < 5 OR a = 8` is true + let expr = col("a").lt(lit(5)).or(col("a").eq(lit(8))); + + // First, you make a "physical expression" from the logical `Expr` + let physical_expr = physical_expr(&batch.schema(), expr)?; + + // Now, you can evaluate the expression against the RecordBatch + let result = physical_expr.evaluate(&batch)?; + + // The result contain an array that is true only for where `a < 5 OR a = 8` + let expected_result = Arc::new(BooleanArray::from(vec![ + true, false, false, false, true, false, true, + ])) as _; + assert!( + matches!(&result, ColumnarValue::Array(r) if r == &expected_result), + "result: {:?}", + result + ); + Ok(()) } -/// In addition to easy construction, DataFusion exposes APIs for -/// working with and simplifying such expressions that call into the -/// same powerful and extensive implementation used for the query -/// engine. +/// In addition to easy construction, DataFusion exposes APIs for simplifying +/// such expression so they are more efficient to evaluate. This code is also +/// used by the query engine to optimize queries. fn simplify_demo() -> Result<()> { // For example, lets say you have has created an expression such // ts = to_timestamp("2020-09-08T12:00:00+00:00") @@ -94,7 +145,7 @@ fn simplify_demo() -> Result<()> { make_field("b", DataType::Boolean), ]) .to_dfschema_ref()?; - let context = SimplifyContext::new(&props).with_schema(schema); + let context = SimplifyContext::new(&props).with_schema(schema.clone()); let simplifier = ExprSimplifier::new(context); // basic arithmetic simplification @@ -120,6 +171,64 @@ fn simplify_demo() -> Result<()> { col("i").lt(lit(10)) ); + // String --> Date simplification + // `cast('2020-09-01' as date)` --> 18500 + assert_eq!( + simplifier.simplify(lit("2020-09-01").cast_to(&DataType::Date32, &schema)?)?, + lit(ScalarValue::Date32(Some(18506))) + ); + + Ok(()) +} + +/// DataFusion also has APIs for analyzing predicates (boolean expressions) to +/// determine any ranges restrictions on the inputs required for the predicate +/// evaluate to true. +fn range_analysis_demo() -> Result<()> { + // For example, let's say you are interested in finding data for all days + // in the month of September, 2020 + let september_1 = ScalarValue::Date32(Some(18506)); // 2020-09-01 + let october_1 = ScalarValue::Date32(Some(18536)); // 2020-10-01 + + // The predicate to find all such days could be + // `date > '2020-09-01' AND date < '2020-10-01'` + let expr = col("date") + .gt(lit(september_1.clone())) + .and(col("date").lt(lit(october_1.clone()))); + + // Using the analysis API, DataFusion can determine that the value of `date` + // must be in the range `['2020-09-01', '2020-10-01']`. If your data is + // organized in files according to day, this information permits skipping + // entire files without reading them. + // + // While this simple example could be handled with a special case, the + // DataFusion API handles arbitrary expressions (so for example, you don't + // have to handle the case where the predicate clauses are reversed such as + // `date < '2020-10-01' AND date > '2020-09-01'` + + // As always, we need to tell DataFusion the type of column "date" + let schema = Schema::new(vec![make_field("date", DataType::Date32)]); + + // You can provide DataFusion any known boundaries on the values of `date` + // (for example, maybe you know you only have data up to `2020-09-15`), but + // in this case, let's say we don't know any boundaries beforehand so we use + // `try_new_unknown` + let boundaries = ExprBoundaries::try_new_unbounded(&schema)?; + + // Now, we invoke the analysis code to perform the range analysis + let physical_expr = physical_expr(&schema, expr)?; + let analysis_result = + analyze(&physical_expr, AnalysisContext::new(boundaries), &schema)?; + + // The results of the analysis is an range, encoded as an `Interval`, for + // each column in the schema, that must be true in order for the predicate + // to be true. + // + // In this case, we can see that, as expected, `analyze` has figured out + // that in this case, `date` must be in the range `['2020-09-01', '2020-10-01']` + let expected_range = Interval::try_new(september_1, october_1)?; + assert_eq!(analysis_result.boundaries[0].interval, expected_range); + Ok(()) } @@ -132,3 +241,18 @@ fn make_ts_field(name: &str) -> Field { let tz = None; make_field(name, DataType::Timestamp(TimeUnit::Nanosecond, tz)) } + +/// Build a physical expression from a logical one, after applying simplification and type coercion +pub fn physical_expr(schema: &Schema, expr: Expr) -> Result> { + let df_schema = schema.clone().to_dfschema_ref()?; + + // Simplify + let props = ExecutionProps::new(); + let simplifier = + ExprSimplifier::new(SimplifyContext::new(&props).with_schema(df_schema.clone())); + + // apply type coercion here to ensure types match + let expr = simplifier.coerce(expr, df_schema.clone())?; + + create_physical_expr(&expr, df_schema.as_ref(), schema, &props) +} diff --git a/datafusion-examples/examples/catalog.rs b/datafusion-examples/examples/external_dependency/catalog.rs similarity index 100% rename from datafusion-examples/examples/catalog.rs rename to datafusion-examples/examples/external_dependency/catalog.rs diff --git a/datafusion-examples/examples/dataframe-to-s3.rs b/datafusion-examples/examples/external_dependency/dataframe-to-s3.rs similarity index 100% rename from datafusion-examples/examples/dataframe-to-s3.rs rename to datafusion-examples/examples/external_dependency/dataframe-to-s3.rs diff --git a/datafusion-examples/examples/query-aws-s3.rs b/datafusion-examples/examples/external_dependency/query-aws-s3.rs similarity index 100% rename from datafusion-examples/examples/query-aws-s3.rs rename to datafusion-examples/examples/external_dependency/query-aws-s3.rs diff --git a/datafusion-examples/examples/flight_client.rs b/datafusion-examples/examples/flight/flight_client.rs similarity index 100% rename from datafusion-examples/examples/flight_client.rs rename to datafusion-examples/examples/flight/flight_client.rs diff --git a/datafusion-examples/examples/flight_server.rs b/datafusion-examples/examples/flight/flight_server.rs similarity index 100% rename from datafusion-examples/examples/flight_server.rs rename to datafusion-examples/examples/flight/flight_server.rs diff --git a/datafusion-examples/examples/flight_sql_server.rs b/datafusion-examples/examples/flight/flight_sql_server.rs similarity index 99% rename from datafusion-examples/examples/flight_sql_server.rs rename to datafusion-examples/examples/flight/flight_sql_server.rs index aad63fc2bca8..ed5b86d0b66c 100644 --- a/datafusion-examples/examples/flight_sql_server.rs +++ b/datafusion-examples/examples/flight/flight_sql_server.rs @@ -105,7 +105,7 @@ impl FlightSqlServiceImpl { let session_config = SessionConfig::from_env() .map_err(|e| Status::internal(format!("Error building plan: {e}")))? .with_information_schema(true); - let ctx = Arc::new(SessionContext::with_config(session_config)); + let ctx = Arc::new(SessionContext::new_with_config(session_config)); let testdata = datafusion::test_util::parquet_test_data(); diff --git a/datafusion-examples/examples/json_opener.rs b/datafusion-examples/examples/json_opener.rs index 74ba6f3852a8..ee33f969caa9 100644 --- a/datafusion-examples/examples/json_opener.rs +++ b/datafusion-examples/examples/json_opener.rs @@ -29,6 +29,8 @@ use datafusion::{ error::Result, physical_plan::metrics::ExecutionPlanMetricsSet, }; +use datafusion_common::Statistics; + use futures::StreamExt; use object_store::ObjectStore; @@ -63,12 +65,11 @@ async fn main() -> Result<()> { object_store_url: ObjectStoreUrl::local_filesystem(), file_schema: schema.clone(), file_groups: vec![vec![PartitionedFile::new(path.to_string(), 10)]], - statistics: Default::default(), + statistics: Statistics::new_unknown(&schema), projection: Some(vec![1, 0]), limit: Some(5), table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }; let result = diff --git a/datafusion-examples/examples/memtable.rs b/datafusion-examples/examples/memtable.rs index bef8f3e5bb8f..5cce578039e7 100644 --- a/datafusion-examples/examples/memtable.rs +++ b/datafusion-examples/examples/memtable.rs @@ -40,7 +40,7 @@ async fn main() -> Result<()> { timeout(Duration::from_secs(10), async move { let result = dataframe.collect().await.unwrap(); - let record_batch = result.get(0).unwrap(); + let record_batch = result.first().unwrap(); assert_eq!(1, record_batch.column(0).len()); dbg!(record_batch.columns()); diff --git a/datafusion-examples/examples/rewrite_expr.rs b/datafusion-examples/examples/rewrite_expr.rs index e657baab3df8..5e95562033e6 100644 --- a/datafusion-examples/examples/rewrite_expr.rs +++ b/datafusion-examples/examples/rewrite_expr.rs @@ -191,7 +191,7 @@ struct MyContextProvider { } impl ContextProvider for MyContextProvider { - fn get_table_provider(&self, name: TableReference) -> Result> { + fn get_table_source(&self, name: TableReference) -> Result> { if name.table() == "person" { Ok(Arc::new(MyTableSource { schema: Arc::new(Schema::new(vec![ diff --git a/datafusion-examples/examples/simple_udaf.rs b/datafusion-examples/examples/simple_udaf.rs index 7aec9698d92f..2c797f221b2c 100644 --- a/datafusion-examples/examples/simple_udaf.rs +++ b/datafusion-examples/examples/simple_udaf.rs @@ -154,6 +154,10 @@ async fn main() -> Result<()> { // This is the description of the state. `state()` must match the types here. Arc::new(vec![DataType::Float64, DataType::UInt32]), ); + ctx.register_udaf(geometric_mean.clone()); + + let sql_df = ctx.sql("SELECT geo_mean(a) FROM t").await?; + sql_df.show().await?; // get a DataFrame from the context // this table has 1 column `a` f32 with values {2,4,8,64}, whose geometric mean is 8.0. diff --git a/datafusion-examples/examples/simple_udf.rs b/datafusion-examples/examples/simple_udf.rs index dba4385b8eea..39e1e13ce39a 100644 --- a/datafusion-examples/examples/simple_udf.rs +++ b/datafusion-examples/examples/simple_udf.rs @@ -29,23 +29,23 @@ use datafusion::{error::Result, physical_plan::functions::make_scalar_function}; use datafusion_common::cast::as_float64_array; use std::sync::Arc; -// create local execution context with an in-memory table +/// create local execution context with an in-memory table: +/// +/// ```text +/// +-----+-----+ +/// | a | b | +/// +-----+-----+ +/// | 2.1 | 1.0 | +/// | 3.1 | 2.0 | +/// | 4.1 | 3.0 | +/// | 5.1 | 4.0 | +/// +-----+-----+ +/// ``` fn create_context() -> Result { - use datafusion::arrow::datatypes::{Field, Schema}; - // define a schema. - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Float32, false), - Field::new("b", DataType::Float64, false), - ])); - // define data. - let batch = RecordBatch::try_new( - schema, - vec![ - Arc::new(Float32Array::from(vec![2.1, 3.1, 4.1, 5.1])), - Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])), - ], - )?; + let a: ArrayRef = Arc::new(Float32Array::from(vec![2.1, 3.1, 4.1, 5.1])); + let b: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])); + let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)])?; // declare a new context. In spark API, this corresponds to a new spark SQLsession let ctx = SessionContext::new(); @@ -140,5 +140,11 @@ async fn main() -> Result<()> { // print the results df.show().await?; + // Given that `pow` is registered in the context, we can also use it in SQL: + let sql_df = ctx.sql("SELECT pow(a, b) FROM t").await?; + + // print the results + sql_df.show().await?; + Ok(()) } diff --git a/datafusion-examples/examples/simple_udtf.rs b/datafusion-examples/examples/simple_udtf.rs new file mode 100644 index 000000000000..f1d763ba6e41 --- /dev/null +++ b/datafusion-examples/examples/simple_udtf.rs @@ -0,0 +1,178 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::csv::reader::Format; +use arrow::csv::ReaderBuilder; +use async_trait::async_trait; +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::arrow::record_batch::RecordBatch; +use datafusion::datasource::function::TableFunctionImpl; +use datafusion::datasource::TableProvider; +use datafusion::error::Result; +use datafusion::execution::context::{ExecutionProps, SessionState}; +use datafusion::physical_plan::memory::MemoryExec; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::SessionContext; +use datafusion_common::{plan_err, DataFusionError, ScalarValue}; +use datafusion_expr::{Expr, TableType}; +use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyContext}; +use std::fs::File; +use std::io::Seek; +use std::path::Path; +use std::sync::Arc; + +// To define your own table function, you only need to do the following 3 things: +// 1. Implement your own [`TableProvider`] +// 2. Implement your own [`TableFunctionImpl`] and return your [`TableProvider`] +// 3. Register the function using [`SessionContext::register_udtf`] + +/// This example demonstrates how to register a TableFunction +#[tokio::main] +async fn main() -> Result<()> { + // create local execution context + let ctx = SessionContext::new(); + + // register the table function that will be called in SQL statements by `read_csv` + ctx.register_udtf("read_csv", Arc::new(LocalCsvTableFunc {})); + + let testdata = datafusion::test_util::arrow_test_data(); + let csv_file = format!("{testdata}/csv/aggregate_test_100.csv"); + + // Pass 2 arguments, read csv with at most 2 rows (simplify logic makes 1+1 --> 2) + let df = ctx + .sql(format!("SELECT * FROM read_csv('{csv_file}', 1 + 1);").as_str()) + .await?; + df.show().await?; + + // just run, return all rows + let df = ctx + .sql(format!("SELECT * FROM read_csv('{csv_file}');").as_str()) + .await?; + df.show().await?; + + Ok(()) +} + +/// Table Function that mimics the [`read_csv`] function in DuckDB. +/// +/// Usage: `read_csv(filename, [limit])` +/// +/// [`read_csv`]: https://duckdb.org/docs/data/csv/overview.html +struct LocalCsvTable { + schema: SchemaRef, + limit: Option, + batches: Vec, +} + +#[async_trait] +impl TableProvider for LocalCsvTable { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + _state: &SessionState, + projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + let batches = if let Some(max_return_lines) = self.limit { + // get max return rows from self.batches + let mut batches = vec![]; + let mut lines = 0; + for batch in &self.batches { + let batch_lines = batch.num_rows(); + if lines + batch_lines > max_return_lines { + let batch_lines = max_return_lines - lines; + batches.push(batch.slice(0, batch_lines)); + break; + } else { + batches.push(batch.clone()); + lines += batch_lines; + } + } + batches + } else { + self.batches.clone() + }; + Ok(Arc::new(MemoryExec::try_new( + &[batches], + TableProvider::schema(self), + projection.cloned(), + )?)) + } +} + +struct LocalCsvTableFunc {} + +impl TableFunctionImpl for LocalCsvTableFunc { + fn call(&self, exprs: &[Expr]) -> Result> { + let Some(Expr::Literal(ScalarValue::Utf8(Some(ref path)))) = exprs.first() else { + return plan_err!("read_csv requires at least one string argument"); + }; + + let limit = exprs + .get(1) + .map(|expr| { + // try to simpify the expression, so 1+2 becomes 3, for example + let execution_props = ExecutionProps::new(); + let info = SimplifyContext::new(&execution_props); + let expr = ExprSimplifier::new(info).simplify(expr.clone())?; + + if let Expr::Literal(ScalarValue::Int64(Some(limit))) = expr { + Ok(limit as usize) + } else { + plan_err!("Limit must be an integer") + } + }) + .transpose()?; + + let (schema, batches) = read_csv_batches(path)?; + + let table = LocalCsvTable { + schema, + limit, + batches, + }; + Ok(Arc::new(table)) + } +} + +fn read_csv_batches(csv_path: impl AsRef) -> Result<(SchemaRef, Vec)> { + let mut file = File::open(csv_path)?; + let (schema, _) = Format::default().infer_schema(&mut file, None)?; + file.rewind()?; + + let reader = ReaderBuilder::new(Arc::new(schema.clone())) + .with_header(true) + .build(file)?; + let mut batches = vec![]; + for bacth in reader { + batches.push(bacth?); + } + let schema = Arc::new(schema); + Ok((schema, batches)) +} diff --git a/datafusion-examples/examples/simple_udwf.rs b/datafusion-examples/examples/simple_udwf.rs index 39042a35629b..0d04c093e147 100644 --- a/datafusion-examples/examples/simple_udwf.rs +++ b/datafusion-examples/examples/simple_udwf.rs @@ -36,7 +36,7 @@ async fn create_context() -> Result { // declare a table in memory. In spark API, this corresponds to createDataFrame(...). println!("pwd: {}", std::env::current_dir().unwrap().display()); - let csv_path = "datafusion/core/tests/data/cars.csv".to_string(); + let csv_path = "../../datafusion/core/tests/data/cars.csv".to_string(); let read_options = CsvReadOptions::default().has_header(true); ctx.register_csv("cars", &csv_path, read_options).await?; @@ -89,7 +89,7 @@ async fn main() -> Result<()> { "SELECT \ car, \ speed, \ - smooth_it(speed) OVER (PARTITION BY car ORDER BY time),\ + smooth_it(speed) OVER (PARTITION BY car ORDER BY time) AS smooth_speed,\ time \ from cars \ ORDER BY \ @@ -109,7 +109,7 @@ async fn main() -> Result<()> { "SELECT \ car, \ speed, \ - smooth_it(speed) OVER (PARTITION BY car ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\ + smooth_it(speed) OVER (PARTITION BY car ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS smooth_speed,\ time \ from cars \ ORDER BY \ diff --git a/datafusion/CHANGELOG.md b/datafusion/CHANGELOG.md index 0a25b747c739..d64bbeda877d 100644 --- a/datafusion/CHANGELOG.md +++ b/datafusion/CHANGELOG.md @@ -19,6 +19,9 @@ # Changelog +- [34.0.0](../dev/changelog/34.0.0.md) +- [33.0.0](../dev/changelog/33.0.0.md) +- [32.0.0](../dev/changelog/32.0.0.md) - [31.0.0](../dev/changelog/31.0.0.md) - [30.0.0](../dev/changelog/30.0.0.md) - [29.0.0](../dev/changelog/29.0.0.md) diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index f2b8f1a1e4be..b69e1f7f3d10 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -19,9 +19,9 @@ name = "datafusion-common" description = "Common functionality for DataFusion query engine" keywords = ["arrow", "query", "sql"] +readme = "README.md" version = { workspace = true } edition = { workspace = true } -readme = { workspace = true } homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } @@ -35,18 +35,29 @@ path = "src/lib.rs" [features] avro = ["apache-avro"] backtrace = [] -default = ["parquet"] -pyarrow = ["pyo3", "arrow/pyarrow"] +pyarrow = ["pyo3", "arrow/pyarrow", "parquet"] [dependencies] -apache-avro = { version = "0.16", default-features = false, features = ["snappy"], optional = true } +ahash = { version = "0.8", default-features = false, features = [ + "runtime-rng", +] } +apache-avro = { version = "0.16", default-features = false, features = [ + "bzip", + "snappy", + "xz", + "zstandard", +], optional = true } arrow = { workspace = true } arrow-array = { workspace = true } +arrow-buffer = { workspace = true } +arrow-schema = { workspace = true } chrono = { workspace = true } -num_cpus = "1.13.0" -object_store = { version = "0.7.0", default-features = false, optional = true } -parquet = { workspace = true, optional = true } -pyo3 = { version = "0.19.0", optional = true } +half = { version = "2.1", default-features = false } +libc = "0.2.140" +num_cpus = { workspace = true } +object_store = { workspace = true, optional = true } +parquet = { workspace = true, optional = true, default-features = true } +pyo3 = { version = "0.20.0", optional = true } sqlparser = { workspace = true } [dev-dependencies] diff --git a/datafusion/common/README.md b/datafusion/common/README.md index 9bccf3f18b7f..524ab4420d2a 100644 --- a/datafusion/common/README.md +++ b/datafusion/common/README.md @@ -19,7 +19,7 @@ # DataFusion Common -[DataFusion](df) is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. This crate is a submodule of DataFusion that provides common data types and utilities. diff --git a/datafusion/common/src/cast.rs b/datafusion/common/src/cast.rs index 4356f36b18d8..088f03e002ed 100644 --- a/datafusion/common/src/cast.rs +++ b/datafusion/common/src/cast.rs @@ -181,23 +181,17 @@ pub fn as_timestamp_second_array(array: &dyn Array) -> Result<&TimestampSecondAr } // Downcast ArrayRef to IntervalYearMonthArray -pub fn as_interval_ym_array( - array: &dyn Array, -) -> Result<&IntervalYearMonthArray, DataFusionError> { +pub fn as_interval_ym_array(array: &dyn Array) -> Result<&IntervalYearMonthArray> { Ok(downcast_value!(array, IntervalYearMonthArray)) } // Downcast ArrayRef to IntervalDayTimeArray -pub fn as_interval_dt_array( - array: &dyn Array, -) -> Result<&IntervalDayTimeArray, DataFusionError> { +pub fn as_interval_dt_array(array: &dyn Array) -> Result<&IntervalDayTimeArray> { Ok(downcast_value!(array, IntervalDayTimeArray)) } // Downcast ArrayRef to IntervalMonthDayNanoArray -pub fn as_interval_mdn_array( - array: &dyn Array, -) -> Result<&IntervalMonthDayNanoArray, DataFusionError> { +pub fn as_interval_mdn_array(array: &dyn Array) -> Result<&IntervalMonthDayNanoArray> { Ok(downcast_value!(array, IntervalMonthDayNanoArray)) } diff --git a/datafusion/common/src/column.rs b/datafusion/common/src/column.rs index 2e729c128e73..f0edc7175948 100644 --- a/datafusion/common/src/column.rs +++ b/datafusion/common/src/column.rs @@ -17,6 +17,7 @@ //! Column +use crate::error::_schema_err; use crate::utils::{parse_identifiers_normalized, quote_identifier}; use crate::{DFSchema, DataFusionError, OwnedTableReference, Result, SchemaError}; use std::collections::HashSet; @@ -211,13 +212,13 @@ impl Column { } } - Err(DataFusionError::SchemaError(SchemaError::FieldNotFound { + _schema_err!(SchemaError::FieldNotFound { field: Box::new(Column::new(self.relation.clone(), self.name)), valid_fields: schemas .iter() .flat_map(|s| s.fields().iter().map(|f| f.qualified_column())) .collect(), - })) + }) } /// Qualify column if not done yet. @@ -299,23 +300,21 @@ impl Column { } // If not due to USING columns then due to ambiguous column name - return Err(DataFusionError::SchemaError( - SchemaError::AmbiguousReference { - field: Column::new_unqualified(self.name), - }, - )); + return _schema_err!(SchemaError::AmbiguousReference { + field: Column::new_unqualified(self.name), + }); } } } - Err(DataFusionError::SchemaError(SchemaError::FieldNotFound { + _schema_err!(SchemaError::FieldNotFound { field: Box::new(self), valid_fields: schemas .iter() .flat_map(|s| s.iter()) .flat_map(|s| s.fields().iter().map(|f| f.qualified_column())) .collect(), - })) + }) } } diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index b34c64ff8893..cc60f5d1ed07 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -254,6 +254,35 @@ config_namespace! { /// Number of files to read in parallel when inferring schema and statistics pub meta_fetch_concurrency: usize, default = 32 + + /// Guarantees a minimum level of output files running in parallel. + /// RecordBatches will be distributed in round robin fashion to each + /// parallel writer. Each writer is closed and a new file opened once + /// soft_max_rows_per_output_file is reached. + pub minimum_parallel_output_files: usize, default = 4 + + /// Target number of rows in output files when writing multiple. + /// This is a soft max, so it can be exceeded slightly. There also + /// will be one file smaller than the limit if the total + /// number of rows written is not roughly divisible by the soft max + pub soft_max_rows_per_output_file: usize, default = 50000000 + + /// This is the maximum number of RecordBatches buffered + /// for each output file being worked. Higher values can potentially + /// give faster write performance at the cost of higher peak + /// memory consumption + pub max_buffered_batches_per_output_file: usize, default = 2 + + /// Should sub directories be ignored when scanning directories for data + /// files. Defaults to true (ignores subdirectories), consistent with + /// Hive. Note that this setting does not affect reading partitioned + /// tables (e.g. `/table/year=2021/month=01/data.parquet`). + pub listing_table_ignore_subdirectory: bool, default = true + + /// Should DataFusion support recursive CTEs + /// Defaults to false since this feature is a work in progress and may not + /// behave as expected + pub enable_recursive_ctes: bool, default = false } } @@ -307,7 +336,7 @@ config_namespace! { /// lzo, brotli(level), lz4, zstd(level), and lz4_raw. /// These values are not case sensitive. If NULL, uses /// default parquet writer setting - pub compression: Option, default = None + pub compression: Option, default = Some("zstd(3)".into()) /// Sets if dictionary encoding is enabled. If NULL, uses /// default parquet writer setting @@ -332,7 +361,7 @@ config_namespace! { /// Sets "created by" property pub created_by: String, default = concat!("datafusion version ", env!("CARGO_PKG_VERSION")).into() - /// Sets column index trucate length + /// Sets column index truncate length pub column_index_truncate_length: Option, default = None /// Sets best effort maximum number of rows in data page @@ -358,12 +387,32 @@ config_namespace! { pub bloom_filter_ndv: Option, default = None /// Controls whether DataFusion will attempt to speed up writing - /// large parquet files by first writing multiple smaller files - /// and then stitching them together into a single large file. - /// This will result in faster write speeds, but higher memory usage. - /// Also currently unsupported are bloom filters and column indexes - /// when single_file_parallelism is enabled. - pub allow_single_file_parallelism: bool, default = false + /// parquet files by serializing them in parallel. Each column + /// in each row group in each output file are serialized in parallel + /// leveraging a maximum possible core count of n_files*n_row_groups*n_columns. + pub allow_single_file_parallelism: bool, default = true + + /// By default parallel parquet writer is tuned for minimum + /// memory usage in a streaming execution plan. You may see + /// a performance benefit when writing large parquet files + /// by increasing maximum_parallel_row_group_writers and + /// maximum_buffered_record_batches_per_stream if your system + /// has idle cores and can tolerate additional memory usage. + /// Boosting these values is likely worthwhile when + /// writing out already in-memory data, such as from a cached + /// data frame. + pub maximum_parallel_row_group_writers: usize, default = 1 + + /// By default parallel parquet writer is tuned for minimum + /// memory usage in a streaming execution plan. You may see + /// a performance benefit when writing large parquet files + /// by increasing maximum_parallel_row_group_writers and + /// maximum_buffered_record_batches_per_stream if your system + /// has idle cores and can tolerate additional memory usage. + /// Boosting these values is likely worthwhile when + /// writing out already in-memory data, such as from a cached + /// data frame. + pub maximum_buffered_record_batches_per_stream: usize, default = 2 } } @@ -388,6 +437,11 @@ config_namespace! { config_namespace! { /// Options related to query optimization pub struct OptimizerOptions { + /// When set to true, the optimizer will push a limit operation into + /// grouped aggregations which have no aggregate expressions, as a soft limit, + /// emitting groups once the limit is reached, before all rows in the group are read. + pub enable_distinct_aggregation_soft_limit: bool, default = true + /// When set to true, the physical plan optimizer will try to add round robin /// repartitioning to increase parallelism to leverage more CPU cores pub enable_round_robin_repartition: bool, default = true @@ -453,11 +507,13 @@ config_namespace! { /// ``` pub repartition_sorts: bool, default = true - /// When true, DataFusion will opportunistically remove sorts by replacing - /// `RepartitionExec` with `SortPreservingRepartitionExec`, and - /// `CoalescePartitionsExec` with `SortPreservingMergeExec`, - /// even when the query is bounded. - pub bounded_order_preserving_variants: bool, default = false + /// When true, DataFusion will opportunistically remove sorts when the data is already sorted, + /// (i.e. setting `preserve_order` to true on `RepartitionExec` and + /// using `SortPreservingMergeExec`) + /// + /// When false, DataFusion will maximize plan parallelism using + /// `RepartitionExec` even if this requires subsequently resorting data using a `SortExec`. + pub prefer_existing_sort: bool, default = false /// When set to true, the logical plan optimizer will produce warning /// messages if any optimization rules produce errors and then proceed to the next @@ -478,6 +534,11 @@ config_namespace! { /// The maximum estimated size in bytes for one input side of a HashJoin /// will be collected into a single partition pub hash_join_single_partition_threshold: usize, default = 1024 * 1024 + + /// The default filter selectivity used by Filter Statistics + /// when an exact selectivity cannot be determined. Valid values are + /// between 0 (no selectivity) and 100 (all rows are selected). + pub default_filter_selectivity: u8, default = 20 } } @@ -831,6 +892,7 @@ config_field!(String); config_field!(bool); config_field!(usize); config_field!(f64); +config_field!(u8); config_field!(u64); /// An implementation trait used to recursively walk configuration diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index e015ef5c4082..85b97aac037d 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -24,7 +24,10 @@ use std::fmt::{Display, Formatter}; use std::hash::Hash; use std::sync::Arc; -use crate::error::{unqualified_field_not_found, DataFusionError, Result, SchemaError}; +use crate::error::{ + unqualified_field_not_found, DataFusionError, Result, SchemaError, _plan_err, + _schema_err, +}; use crate::{ field_not_found, Column, FunctionalDependencies, OwnedTableReference, TableReference, }; @@ -32,10 +35,75 @@ use crate::{ use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field, FieldRef, Fields, Schema, SchemaRef}; -/// A reference-counted reference to a `DFSchema`. +/// A reference-counted reference to a [DFSchema]. pub type DFSchemaRef = Arc; -/// DFSchema wraps an Arrow schema and adds relation names +/// DFSchema wraps an Arrow schema and adds relation names. +/// +/// The schema may hold the fields across multiple tables. Some fields may be +/// qualified and some unqualified. A qualified field is a field that has a +/// relation name associated with it. +/// +/// Unqualified fields must be unique not only amongst themselves, but also must +/// have a distinct name from any qualified field names. This allows finding a +/// qualified field by name to be possible, so long as there aren't multiple +/// qualified fields with the same name. +/// +/// There is an alias to `Arc` named [DFSchemaRef]. +/// +/// # Creating qualified schemas +/// +/// Use [DFSchema::try_from_qualified_schema] to create a qualified schema from +/// an Arrow schema. +/// +/// ```rust +/// use datafusion_common::{DFSchema, Column}; +/// use arrow_schema::{DataType, Field, Schema}; +/// +/// let arrow_schema = Schema::new(vec![ +/// Field::new("c1", DataType::Int32, false), +/// ]); +/// +/// let df_schema = DFSchema::try_from_qualified_schema("t1", &arrow_schema).unwrap(); +/// let column = Column::from_qualified_name("t1.c1"); +/// assert!(df_schema.has_column(&column)); +/// +/// // Can also access qualified fields with unqualified name, if it's unambiguous +/// let column = Column::from_qualified_name("c1"); +/// assert!(df_schema.has_column(&column)); +/// ``` +/// +/// # Creating unqualified schemas +/// +/// Create an unqualified schema using TryFrom: +/// +/// ```rust +/// use datafusion_common::{DFSchema, Column}; +/// use arrow_schema::{DataType, Field, Schema}; +/// +/// let arrow_schema = Schema::new(vec![ +/// Field::new("c1", DataType::Int32, false), +/// ]); +/// +/// let df_schema = DFSchema::try_from(arrow_schema).unwrap(); +/// let column = Column::new_unqualified("c1"); +/// assert!(df_schema.has_column(&column)); +/// ``` +/// +/// # Converting back to Arrow schema +/// +/// Use the `Into` trait to convert `DFSchema` into an Arrow schema: +/// +/// ```rust +/// use datafusion_common::{DFSchema, DFField}; +/// use arrow_schema::Schema; +/// +/// let df_schema = DFSchema::new(vec![ +/// DFField::new_unqualified("c1", arrow::datatypes::DataType::Int32, false), +/// ]).unwrap(); +/// let schema = Schema::from(df_schema); +/// assert_eq!(schema.fields().len(), 1); +/// ``` #[derive(Debug, Clone, PartialEq, Eq)] pub struct DFSchema { /// Fields @@ -74,11 +142,9 @@ impl DFSchema { if let Some(qualifier) = field.qualifier() { qualified_names.insert((qualifier, field.name())); } else if !unqualified_names.insert(field.name()) { - return Err(DataFusionError::SchemaError( - SchemaError::DuplicateUnqualifiedField { - name: field.name().to_string(), - }, - )); + return _schema_err!(SchemaError::DuplicateUnqualifiedField { + name: field.name().to_string(), + }); } } @@ -92,14 +158,12 @@ impl DFSchema { qualified_names.sort(); for (qualifier, name) in &qualified_names { if unqualified_names.contains(name) { - return Err(DataFusionError::SchemaError( - SchemaError::AmbiguousReference { - field: Column { - relation: Some((*qualifier).clone()), - name: name.to_string(), - }, - }, - )); + return _schema_err!(SchemaError::AmbiguousReference { + field: Column { + relation: Some((*qualifier).clone()), + name: name.to_string(), + } + }); } } Ok(Self { @@ -110,6 +174,9 @@ impl DFSchema { } /// Create a `DFSchema` from an Arrow schema and a given qualifier + /// + /// To create a schema from an Arrow schema without a qualifier, use + /// `DFSchema::try_from`. pub fn try_from_qualified_schema<'a>( qualifier: impl Into>, schema: &Schema, @@ -129,9 +196,16 @@ impl DFSchema { pub fn with_functional_dependencies( mut self, functional_dependencies: FunctionalDependencies, - ) -> Self { - self.functional_dependencies = functional_dependencies; - self + ) -> Result { + if functional_dependencies.is_valid(self.fields.len()) { + self.functional_dependencies = functional_dependencies; + Ok(self) + } else { + _plan_err!( + "Invalid functional dependency: {:?}", + functional_dependencies + ) + } } /// Create a new schema that contains the fields from this schema followed by the fields @@ -153,9 +227,9 @@ impl DFSchema { for field in other_schema.fields() { // skip duplicate columns let duplicated_field = match field.qualifier() { - Some(q) => self.field_with_name(Some(q), field.name()).is_ok(), + Some(q) => self.has_column_with_qualified_name(q, field.name()), // for unqualified columns, check as unqualified name - None => self.field_with_unqualified_name(field.name()).is_ok(), + None => self.has_column_with_unqualified_name(field.name()), }; if !duplicated_field { self.fields.push(field.clone()); @@ -187,10 +261,10 @@ impl DFSchema { match &self.fields[i].qualifier { Some(qualifier) => { if (qualifier.to_string() + "." + self.fields[i].name()) == name { - return Err(DataFusionError::Plan(format!( + return _plan_err!( "Fully qualified field name '{name}' was supplied to `index_of` \ which is deprecated. Please use `index_of_column_by_name` instead" - ))); + ); } } None => (), @@ -270,6 +344,22 @@ impl DFSchema { .collect() } + /// Find all fields indices having the given qualifier + pub fn fields_indices_with_qualified( + &self, + qualifier: &TableReference, + ) -> Vec { + self.fields + .iter() + .enumerate() + .filter_map(|(idx, field)| { + field + .qualifier() + .and_then(|q| q.eq(qualifier).then_some(idx)) + }) + .collect() + } + /// Find all fields match the given name pub fn fields_with_unqualified_name(&self, name: &str) -> Vec<&DFField> { self.fields @@ -299,14 +389,12 @@ impl DFSchema { if fields_without_qualifier.len() == 1 { Ok(fields_without_qualifier[0]) } else { - Err(DataFusionError::SchemaError( - SchemaError::AmbiguousReference { - field: Column { - relation: None, - name: name.to_string(), - }, + _schema_err!(SchemaError::AmbiguousReference { + field: Column { + relation: None, + name: name.to_string(), }, - )) + }) } } } @@ -378,23 +466,44 @@ impl DFSchema { .zip(arrow_schema.fields().iter()) .try_for_each(|(l_field, r_field)| { if !can_cast_types(r_field.data_type(), l_field.data_type()) { - Err(DataFusionError::Plan( - format!("Column {} (type: {}) is not compatible with column {} (type: {})", + _plan_err!("Column {} (type: {}) is not compatible with column {} (type: {})", r_field.name(), r_field.data_type(), l_field.name(), - l_field.data_type()))) + l_field.data_type()) } else { Ok(()) } }) } + /// Returns true if the two schemas have the same qualified named + /// fields with logically equivalent data types. Returns false otherwise. + /// + /// Use [DFSchema]::equivalent_names_and_types for stricter semantic type + /// equivalence checking. + pub fn logically_equivalent_names_and_types(&self, other: &Self) -> bool { + if self.fields().len() != other.fields().len() { + return false; + } + let self_fields = self.fields().iter(); + let other_fields = other.fields().iter(); + self_fields.zip(other_fields).all(|(f1, f2)| { + f1.qualifier() == f2.qualifier() + && f1.name() == f2.name() + && Self::datatype_is_logically_equal(f1.data_type(), f2.data_type()) + }) + } + /// Returns true if the two schemas have the same qualified named /// fields with the same data types. Returns false otherwise. /// /// This is a specialized version of Eq that ignores differences /// in nullability and metadata. + /// + /// Use [DFSchema]::logically_equivalent_names_and_types for a weaker + /// logical type checking, which for example would consider a dictionary + /// encoded UTF8 array to be equivalent to a plain UTF8 array. pub fn equivalent_names_and_types(&self, other: &Self) -> bool { if self.fields().len() != other.fields().len() { return false; @@ -408,6 +517,46 @@ impl DFSchema { }) } + /// Checks if two [`DataType`]s are logically equal. This is a notably weaker constraint + /// than datatype_is_semantically_equal in that a Dictionary type is logically + /// equal to a plain V type, but not semantically equal. Dictionary is also + /// logically equal to Dictionary. + fn datatype_is_logically_equal(dt1: &DataType, dt2: &DataType) -> bool { + // check nested fields + match (dt1, dt2) { + (DataType::Dictionary(_, v1), DataType::Dictionary(_, v2)) => { + v1.as_ref() == v2.as_ref() + } + (DataType::Dictionary(_, v1), othertype) => v1.as_ref() == othertype, + (othertype, DataType::Dictionary(_, v1)) => v1.as_ref() == othertype, + (DataType::List(f1), DataType::List(f2)) + | (DataType::LargeList(f1), DataType::LargeList(f2)) + | (DataType::FixedSizeList(f1, _), DataType::FixedSizeList(f2, _)) + | (DataType::Map(f1, _), DataType::Map(f2, _)) => { + Self::field_is_logically_equal(f1, f2) + } + (DataType::Struct(fields1), DataType::Struct(fields2)) => { + let iter1 = fields1.iter(); + let iter2 = fields2.iter(); + fields1.len() == fields2.len() && + // all fields have to be the same + iter1 + .zip(iter2) + .all(|(f1, f2)| Self::field_is_logically_equal(f1, f2)) + } + (DataType::Union(fields1, _), DataType::Union(fields2, _)) => { + let iter1 = fields1.iter(); + let iter2 = fields2.iter(); + fields1.len() == fields2.len() && + // all fields have to be the same + iter1 + .zip(iter2) + .all(|((t1, f1), (t2, f2))| t1 == t2 && Self::field_is_logically_equal(f1, f2)) + } + _ => dt1 == dt2, + } + } + /// Returns true of two [`DataType`]s are semantically equal (same /// name and type), ignoring both metadata and nullability. /// @@ -443,10 +592,23 @@ impl DFSchema { .zip(iter2) .all(|((t1, f1), (t2, f2))| t1 == t2 && Self::field_is_semantically_equal(f1, f2)) } + ( + DataType::Decimal128(_l_precision, _l_scale), + DataType::Decimal128(_r_precision, _r_scale), + ) => true, + ( + DataType::Decimal256(_l_precision, _l_scale), + DataType::Decimal256(_r_precision, _r_scale), + ) => true, _ => dt1 == dt2, } } + fn field_is_logically_equal(f1: &Field, f2: &Field) -> bool { + f1.name() == f2.name() + && Self::datatype_is_logically_equal(f1.data_type(), f2.data_type()) + } + fn field_is_semantically_equal(f1: &Field, f2: &Field) -> bool { f1.name() == f2.name() && Self::datatype_is_semantically_equal(f1.data_type(), f2.data_type()) @@ -777,6 +939,13 @@ pub trait SchemaExt { /// /// It works the same as [`DFSchema::equivalent_names_and_types`]. fn equivalent_names_and_types(&self, other: &Self) -> bool; + + /// Returns true if the two schemas have the same qualified named + /// fields with logically equivalent data types. Returns false otherwise. + /// + /// Use [DFSchema]::equivalent_names_and_types for stricter semantic type + /// equivalence checking. + fn logically_equivalent_names_and_types(&self, other: &Self) -> bool; } impl SchemaExt for Schema { @@ -796,6 +965,23 @@ impl SchemaExt for Schema { ) }) } + + fn logically_equivalent_names_and_types(&self, other: &Self) -> bool { + if self.fields().len() != other.fields().len() { + return false; + } + + self.fields() + .iter() + .zip(other.fields().iter()) + .all(|(f1, f2)| { + f1.name() == f2.name() + && DFSchema::datatype_is_logically_equal( + f1.data_type(), + f2.data_type(), + ) + }) + } } #[cfg(test)] @@ -1308,8 +1494,8 @@ mod tests { DFSchema::new_with_metadata([a, b].to_vec(), HashMap::new()).unwrap(), ); let schema: Schema = df_schema.as_ref().clone().into(); - let a_df = df_schema.fields.get(0).unwrap().field(); - let a_arrow = schema.fields.get(0).unwrap(); + let a_df = df_schema.fields.first().unwrap().field(); + let a_arrow = schema.fields.first().unwrap(); assert_eq!(a_df.metadata(), a_arrow.metadata()) } diff --git a/datafusion/common/src/display/mod.rs b/datafusion/common/src/display/mod.rs index 766b37ce2891..4d1d48bf9fcc 100644 --- a/datafusion/common/src/display/mod.rs +++ b/datafusion/common/src/display/mod.rs @@ -47,6 +47,8 @@ pub enum PlanType { FinalLogicalPlan, /// The initial physical plan, prepared for execution InitialPhysicalPlan, + /// The initial physical plan with stats, prepared for execution + InitialPhysicalPlanWithStats, /// The ExecutionPlan which results from applying an optimizer pass OptimizedPhysicalPlan { /// The name of the optimizer which produced this plan @@ -54,6 +56,8 @@ pub enum PlanType { }, /// The final, fully optimized physical which would be executed FinalPhysicalPlan, + /// The final with stats, fully optimized physical which would be executed + FinalPhysicalPlanWithStats, } impl Display for PlanType { @@ -69,10 +73,14 @@ impl Display for PlanType { } PlanType::FinalLogicalPlan => write!(f, "logical_plan"), PlanType::InitialPhysicalPlan => write!(f, "initial_physical_plan"), + PlanType::InitialPhysicalPlanWithStats => { + write!(f, "initial_physical_plan_with_stats") + } PlanType::OptimizedPhysicalPlan { optimizer_name } => { write!(f, "physical_plan after {optimizer_name}") } PlanType::FinalPhysicalPlan => write!(f, "physical_plan"), + PlanType::FinalPhysicalPlanWithStats => write!(f, "physical_plan_with_stats"), } } } diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index d7a0e1b59dba..978938809c1b 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -47,7 +47,8 @@ pub type GenericError = Box; #[derive(Debug)] pub enum DataFusionError { /// Error returned by arrow. - ArrowError(ArrowError), + /// 2nd argument is for optional backtrace + ArrowError(ArrowError, Option), /// Wraps an error from the Parquet crate #[cfg(feature = "parquet")] ParquetError(ParquetError), @@ -60,7 +61,8 @@ pub enum DataFusionError { /// Error associated to I/O operations and associated traits. IoError(io::Error), /// Error returned when SQL is syntactically incorrect. - SQL(ParserError), + /// 2nd argument is for optional backtrace + SQL(ParserError, Option), /// Error returned on a branch that we know it is possible /// but to which we still have no implementation for. /// Often, these errors are tracked in our issue tracker. @@ -80,7 +82,9 @@ pub enum DataFusionError { Configuration(String), /// This error happens with schema-related errors, such as schema inference not possible /// and non-unique column names. - SchemaError(SchemaError), + /// 2nd argument is for optional backtrace + /// Boxing the optional backtrace to prevent + SchemaError(SchemaError, Box>), /// Error returned during execution of the query. /// Examples include files not found, errors in parsing certain types. Execution(String), @@ -123,34 +127,6 @@ pub enum SchemaError { }, } -/// Create a "field not found" DataFusion::SchemaError -pub fn field_not_found>( - qualifier: Option, - name: &str, - schema: &DFSchema, -) -> DataFusionError { - DataFusionError::SchemaError(SchemaError::FieldNotFound { - field: Box::new(Column::new(qualifier, name)), - valid_fields: schema - .fields() - .iter() - .map(|f| f.qualified_column()) - .collect(), - }) -} - -/// Convenience wrapper over [`field_not_found`] for when there is no qualifier -pub fn unqualified_field_not_found(name: &str, schema: &DFSchema) -> DataFusionError { - DataFusionError::SchemaError(SchemaError::FieldNotFound { - field: Box::new(Column::new_unqualified(name)), - valid_fields: schema - .fields() - .iter() - .map(|f| f.qualified_column()) - .collect(), - }) -} - impl Display for SchemaError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { @@ -223,14 +199,14 @@ impl From for DataFusionError { impl From for DataFusionError { fn from(e: ArrowError) -> Self { - DataFusionError::ArrowError(e) + DataFusionError::ArrowError(e, None) } } impl From for ArrowError { fn from(e: DataFusionError) -> Self { match e { - DataFusionError::ArrowError(e) => e, + DataFusionError::ArrowError(e, _) => e, DataFusionError::External(e) => ArrowError::ExternalError(e), other => ArrowError::ExternalError(Box::new(other)), } @@ -267,7 +243,7 @@ impl From for DataFusionError { impl From for DataFusionError { fn from(e: ParserError) -> Self { - DataFusionError::SQL(e) + DataFusionError::SQL(e, None) } } @@ -280,8 +256,9 @@ impl From for DataFusionError { impl Display for DataFusionError { fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { match *self { - DataFusionError::ArrowError(ref desc) => { - write!(f, "Arrow error: {desc}") + DataFusionError::ArrowError(ref desc, ref backtrace) => { + let backtrace = backtrace.clone().unwrap_or("".to_owned()); + write!(f, "Arrow error: {desc}{backtrace}") } #[cfg(feature = "parquet")] DataFusionError::ParquetError(ref desc) => { @@ -294,8 +271,9 @@ impl Display for DataFusionError { DataFusionError::IoError(ref desc) => { write!(f, "IO error: {desc}") } - DataFusionError::SQL(ref desc) => { - write!(f, "SQL error: {desc:?}") + DataFusionError::SQL(ref desc, ref backtrace) => { + let backtrace: String = backtrace.clone().unwrap_or("".to_owned()); + write!(f, "SQL error: {desc:?}{backtrace}") } DataFusionError::Configuration(ref desc) => { write!(f, "Invalid or Unsupported Configuration: {desc}") @@ -310,8 +288,10 @@ impl Display for DataFusionError { DataFusionError::Plan(ref desc) => { write!(f, "Error during planning: {desc}") } - DataFusionError::SchemaError(ref desc) => { - write!(f, "Schema error: {desc}") + DataFusionError::SchemaError(ref desc, ref backtrace) => { + let backtrace: &str = + &backtrace.as_ref().clone().unwrap_or("".to_owned()); + write!(f, "Schema error: {desc}{backtrace}") } DataFusionError::Execution(ref desc) => { write!(f, "Execution error: {desc}") @@ -339,7 +319,7 @@ impl Display for DataFusionError { impl Error for DataFusionError { fn source(&self) -> Option<&(dyn Error + 'static)> { match self { - DataFusionError::ArrowError(e) => Some(e), + DataFusionError::ArrowError(e, _) => Some(e), #[cfg(feature = "parquet")] DataFusionError::ParquetError(e) => Some(e), #[cfg(feature = "avro")] @@ -347,12 +327,12 @@ impl Error for DataFusionError { #[cfg(feature = "object_store")] DataFusionError::ObjectStore(e) => Some(e), DataFusionError::IoError(e) => Some(e), - DataFusionError::SQL(e) => Some(e), + DataFusionError::SQL(e, _) => Some(e), DataFusionError::NotImplemented(_) => None, DataFusionError::Internal(_) => None, DataFusionError::Configuration(_) => None, DataFusionError::Plan(_) => None, - DataFusionError::SchemaError(e) => Some(e), + DataFusionError::SchemaError(e, _) => Some(e), DataFusionError::Execution(_) => None, DataFusionError::ResourcesExhausted(_) => None, DataFusionError::External(e) => Some(e.as_ref()), @@ -369,7 +349,7 @@ impl From for io::Error { } impl DataFusionError { - const BACK_TRACE_SEP: &str = "\n\nbacktrace: "; + const BACK_TRACE_SEP: &'static str = "\n\nbacktrace: "; /// Get deepest underlying [`DataFusionError`] /// @@ -477,12 +457,25 @@ macro_rules! with_dollar_sign { /// plan_err!("Error {:?}", val) /// plan_err!("Error {val}") /// plan_err!("Error {val:?}") +/// +/// `NAME_ERR` - macro name for wrapping Err(DataFusionError::*) +/// `NAME_DF_ERR` - macro name for wrapping DataFusionError::*. Needed to keep backtrace opportunity +/// in construction where DataFusionError::* used directly, like `map_err`, `ok_or_else`, etc macro_rules! make_error { - ($NAME:ident, $ERR:ident) => { + ($NAME_ERR:ident, $NAME_DF_ERR: ident, $ERR:ident) => { with_dollar_sign! { ($d:tt) => { + /// Macro wraps `$ERR` to add backtrace feature + #[macro_export] + macro_rules! $NAME_DF_ERR { + ($d($d args:expr),*) => { + DataFusionError::$ERR(format!("{}{}", format!($d($d args),*), DataFusionError::get_back_trace()).into()) + } + } + + /// Macro wraps Err(`$ERR`) to add backtrace feature #[macro_export] - macro_rules! $NAME { + macro_rules! $NAME_ERR { ($d($d args:expr),*) => { Err(DataFusionError::$ERR(format!("{}{}", format!($d($d args),*), DataFusionError::get_back_trace()).into())) } @@ -492,31 +485,110 @@ macro_rules! make_error { }; } -// Exposes a macro to create `DataFusionError::Plan` -make_error!(plan_err, Plan); +// Exposes a macro to create `DataFusionError::Plan` with optional backtrace +make_error!(plan_err, plan_datafusion_err, Plan); -// Exposes a macro to create `DataFusionError::Internal` -make_error!(internal_err, Internal); +// Exposes a macro to create `DataFusionError::Internal` with optional backtrace +make_error!(internal_err, internal_datafusion_err, Internal); -// Exposes a macro to create `DataFusionError::NotImplemented` -make_error!(not_impl_err, NotImplemented); +// Exposes a macro to create `DataFusionError::NotImplemented` with optional backtrace +make_error!(not_impl_err, not_impl_datafusion_err, NotImplemented); -// Exposes a macro to create `DataFusionError::Execution` -make_error!(exec_err, Execution); +// Exposes a macro to create `DataFusionError::Execution` with optional backtrace +make_error!(exec_err, exec_datafusion_err, Execution); + +// Exposes a macro to create `DataFusionError::Substrait` with optional backtrace +make_error!(substrait_err, substrait_datafusion_err, Substrait); + +// Exposes a macro to create `DataFusionError::SQL` with optional backtrace +#[macro_export] +macro_rules! sql_datafusion_err { + ($ERR:expr) => { + DataFusionError::SQL($ERR, Some(DataFusionError::get_back_trace())) + }; +} -// Exposes a macro to create `DataFusionError::SQL` +// Exposes a macro to create `Err(DataFusionError::SQL)` with optional backtrace #[macro_export] macro_rules! sql_err { ($ERR:expr) => { - Err(DataFusionError::SQL($ERR)) + Err(datafusion_common::sql_datafusion_err!($ERR)) + }; +} + +// Exposes a macro to create `DataFusionError::ArrowError` with optional backtrace +#[macro_export] +macro_rules! arrow_datafusion_err { + ($ERR:expr) => { + DataFusionError::ArrowError($ERR, Some(DataFusionError::get_back_trace())) + }; +} + +// Exposes a macro to create `Err(DataFusionError::ArrowError)` with optional backtrace +#[macro_export] +macro_rules! arrow_err { + ($ERR:expr) => { + Err(datafusion_common::arrow_datafusion_err!($ERR)) + }; +} + +// Exposes a macro to create `DataFusionError::SchemaError` with optional backtrace +#[macro_export] +macro_rules! schema_datafusion_err { + ($ERR:expr) => { + DataFusionError::SchemaError( + $ERR, + Box::new(Some(DataFusionError::get_back_trace())), + ) + }; +} + +// Exposes a macro to create `Err(DataFusionError::SchemaError)` with optional backtrace +#[macro_export] +macro_rules! schema_err { + ($ERR:expr) => { + Err(DataFusionError::SchemaError( + $ERR, + Box::new(Some(DataFusionError::get_back_trace())), + )) }; } // To avoid compiler error when using macro in the same crate: // macros from the current crate cannot be referred to by absolute paths -pub use exec_err as _exec_err; +pub use internal_datafusion_err as _internal_datafusion_err; pub use internal_err as _internal_err; pub use not_impl_err as _not_impl_err; +pub use plan_err as _plan_err; +pub use schema_err as _schema_err; + +/// Create a "field not found" DataFusion::SchemaError +pub fn field_not_found>( + qualifier: Option, + name: &str, + schema: &DFSchema, +) -> DataFusionError { + schema_datafusion_err!(SchemaError::FieldNotFound { + field: Box::new(Column::new(qualifier, name)), + valid_fields: schema + .fields() + .iter() + .map(|f| f.qualified_column()) + .collect(), + }) +} + +/// Convenience wrapper over [`field_not_found`] for when there is no qualifier +pub fn unqualified_field_not_found(name: &str, schema: &DFSchema) -> DataFusionError { + schema_datafusion_err!(SchemaError::FieldNotFound { + field: Box::new(Column::new_unqualified(name)), + valid_fields: schema + .fields() + .iter() + .map(|f| f.qualified_column()) + .collect(), + }) +} #[cfg(test)] mod test { @@ -550,18 +622,16 @@ mod test { assert_eq!( err.split(DataFusionError::BACK_TRACE_SEP) .collect::>() - .get(0) + .first() .unwrap(), &"Error during planning: Err" ); - assert!( - err.split(DataFusionError::BACK_TRACE_SEP) - .collect::>() - .get(1) - .unwrap() - .len() - > 0 - ); + assert!(!err + .split(DataFusionError::BACK_TRACE_SEP) + .collect::>() + .get(1) + .unwrap() + .is_empty()); } #[cfg(not(feature = "backtrace"))] @@ -585,9 +655,12 @@ mod test { ); do_root_test( - DataFusionError::ArrowError(ArrowError::ExternalError(Box::new( - DataFusionError::ResourcesExhausted("foo".to_string()), - ))), + DataFusionError::ArrowError( + ArrowError::ExternalError(Box::new(DataFusionError::ResourcesExhausted( + "foo".to_string(), + ))), + None, + ), DataFusionError::ResourcesExhausted("foo".to_string()), ); @@ -606,11 +679,12 @@ mod test { ); do_root_test( - DataFusionError::ArrowError(ArrowError::ExternalError(Box::new( - ArrowError::ExternalError(Box::new(DataFusionError::ResourcesExhausted( - "foo".to_string(), - ))), - ))), + DataFusionError::ArrowError( + ArrowError::ExternalError(Box::new(ArrowError::ExternalError(Box::new( + DataFusionError::ResourcesExhausted("foo".to_string()), + )))), + None, + ), DataFusionError::ResourcesExhausted("foo".to_string()), ); diff --git a/datafusion/common/src/file_options/csv_writer.rs b/datafusion/common/src/file_options/csv_writer.rs index b69e778431cc..d6046f0219dd 100644 --- a/datafusion/common/src/file_options/csv_writer.rs +++ b/datafusion/common/src/file_options/csv_writer.rs @@ -37,13 +37,6 @@ pub struct CsvWriterOptions { /// Compression to apply after ArrowWriter serializes RecordBatches. /// This compression is applied by DataFusion not the ArrowWriter itself. pub compression: CompressionTypeVariant, - /// Indicates whether WriterBuilder.has_header() is set to true. - /// This is duplicative as WriterBuilder also stores this information. - /// However, WriterBuilder does not allow public read access to the - /// has_header parameter. - pub has_header: bool, - // TODO: expose a way to read has_header in arrow create - // https://github.com/apache/arrow-rs/issues/4735 } impl CsvWriterOptions { @@ -54,7 +47,6 @@ impl CsvWriterOptions { Self { writer_options, compression, - has_header: true, } } } @@ -65,29 +57,20 @@ impl TryFrom<(&ConfigOptions, &StatementOptions)> for CsvWriterOptions { fn try_from(value: (&ConfigOptions, &StatementOptions)) -> Result { let _configs = value.0; let statement_options = value.1; - let mut has_header = true; let mut builder = WriterBuilder::default(); let mut compression = CompressionTypeVariant::UNCOMPRESSED; for (option, value) in &statement_options.options { builder = match option.to_lowercase().as_str(){ "header" => { - has_header = value.parse() + let has_header = value.parse() .map_err(|_| DataFusionError::Configuration(format!("Unable to parse {value} as bool as required for {option}!")))?; - builder.has_headers(has_header) + builder.with_header(has_header) }, "date_format" => builder.with_date_format(value.to_owned()), "datetime_format" => builder.with_datetime_format(value.to_owned()), "timestamp_format" => builder.with_timestamp_format(value.to_owned()), "time_format" => builder.with_time_format(value.to_owned()), - "rfc3339" => { - let value_bool = value.parse() - .map_err(|_| DataFusionError::Configuration(format!("Unable to parse {value} as bool as required for {option}!")))?; - if value_bool{ - builder.with_rfc3339() - } else{ - builder - } - }, + "rfc3339" => builder, // No-op "null_value" => builder.with_null(value.to_owned()), "compression" => { compression = CompressionTypeVariant::from_str(value.replace('\'', "").as_str())?; @@ -108,11 +91,16 @@ impl TryFrom<(&ConfigOptions, &StatementOptions)> for CsvWriterOptions { ) })?) }, + "quote" | "escape" => { + // https://github.com/apache/arrow-rs/issues/5146 + // These two attributes are only available when reading csv files. + // To avoid error + builder + }, _ => return Err(DataFusionError::Configuration(format!("Found unsupported option {option} with value {value} for CSV format!"))) } } Ok(CsvWriterOptions { - has_header, writer_options: builder, compression, }) diff --git a/datafusion/common/src/file_options/file_type.rs b/datafusion/common/src/file_options/file_type.rs index a07f2e0cb847..97362bdad3cc 100644 --- a/datafusion/common/src/file_options/file_type.rs +++ b/datafusion/common/src/file_options/file_type.rs @@ -103,6 +103,7 @@ impl FromStr for FileType { } #[cfg(test)] +#[cfg(feature = "parquet")] mod tests { use crate::error::DataFusionError; use crate::file_options::FileType; diff --git a/datafusion/common/src/file_options/mod.rs b/datafusion/common/src/file_options/mod.rs index 45b105dfadae..1d661b17eb1c 100644 --- a/datafusion/common/src/file_options/mod.rs +++ b/datafusion/common/src/file_options/mod.rs @@ -296,6 +296,7 @@ impl Display for FileTypeWriterOptions { } #[cfg(test)] +#[cfg(feature = "parquet")] mod tests { use std::collections::HashMap; @@ -506,6 +507,7 @@ mod tests { } #[test] + // for StatementOptions fn test_writeroptions_csv_from_statement_options() -> Result<()> { let mut option_map: HashMap = HashMap::new(); option_map.insert("header".to_owned(), "true".to_owned()); @@ -523,9 +525,9 @@ mod tests { let csv_options = CsvWriterOptions::try_from((&config, &options))?; let builder = csv_options.writer_options; + assert!(builder.header()); let buff = Vec::new(); let _properties = builder.build(buff); - assert!(csv_options.has_header); assert_eq!(csv_options.compression, CompressionTypeVariant::GZIP); // TODO expand unit test if csv::WriterBuilder allows public read access to properties @@ -533,6 +535,7 @@ mod tests { } #[test] + // for StatementOptions fn test_writeroptions_json_from_statement_options() -> Result<()> { let mut option_map: HashMap = HashMap::new(); option_map.insert("compression".to_owned(), "gzip".to_owned()); diff --git a/datafusion/common/src/file_options/parquet_writer.rs b/datafusion/common/src/file_options/parquet_writer.rs index 93611114cf8e..80fa023587ee 100644 --- a/datafusion/common/src/file_options/parquet_writer.rs +++ b/datafusion/common/src/file_options/parquet_writer.rs @@ -342,7 +342,7 @@ pub(crate) fn parse_version_string(str_setting: &str) -> Result { "2.0" => Ok(WriterVersion::PARQUET_2_0), _ => Err(DataFusionError::Configuration(format!( "Unknown or unsupported parquet writer version {str_setting} \ - valid options are '1.0' and '2.0'" + valid options are 1.0 and 2.0" ))), } } @@ -355,7 +355,7 @@ pub(crate) fn parse_statistics_string(str_setting: &str) -> Result Ok(EnabledStatistics::Page), _ => Err(DataFusionError::Configuration(format!( "Unknown or unsupported parquet statistics setting {str_setting} \ - valid options are 'none', 'page', and 'chunk'" + valid options are none, page, and chunk" ))), } } diff --git a/datafusion/common/src/functional_dependencies.rs b/datafusion/common/src/functional_dependencies.rs index 6a08c4fd3589..1cb1751d713e 100644 --- a/datafusion/common/src/functional_dependencies.rs +++ b/datafusion/common/src/functional_dependencies.rs @@ -18,11 +18,16 @@ //! FunctionalDependencies keeps track of functional dependencies //! inside DFSchema. -use crate::{DFSchema, DFSchemaRef, DataFusionError, JoinType, Result}; -use sqlparser::ast::TableConstraint; use std::collections::HashSet; use std::fmt::{Display, Formatter}; use std::ops::Deref; +use std::vec::IntoIter; + +use crate::error::_plan_err; +use crate::utils::{merge_and_order_indices, set_difference}; +use crate::{DFSchema, DFSchemaRef, DataFusionError, JoinType, Result}; + +use sqlparser::ast::TableConstraint; /// This object defines a constraint on a table. #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -43,13 +48,15 @@ pub struct Constraints { impl Constraints { /// Create empty constraints pub fn empty() -> Self { - Constraints::new(vec![]) + Constraints::new_unverified(vec![]) } - // This method is private. - // Outside callers can either create empty constraint using `Constraints::empty` API. - // or create constraint from table constraints using `Constraints::new_from_table_constraints` API. - fn new(constraints: Vec) -> Self { + /// Create a new `Constraints` object from the given `constraints`. + /// Users should use the `empty` or `new_from_table_constraints` functions + /// for constructing `Constraints`. This constructor is for internal + /// purposes only and does not check whether the argument is valid. The user + /// is responsible for supplying a valid vector of `Constraint` objects. + pub fn new_unverified(constraints: Vec) -> Self { Self { inner: constraints } } @@ -90,21 +97,21 @@ impl Constraints { Constraint::Unique(indices) }) } - TableConstraint::ForeignKey { .. } => Err(DataFusionError::Plan( - "Foreign key constraints are not currently supported".to_string(), - )), - TableConstraint::Check { .. } => Err(DataFusionError::Plan( - "Check constraints are not currently supported".to_string(), - )), - TableConstraint::Index { .. } => Err(DataFusionError::Plan( - "Indexes are not currently supported".to_string(), - )), - TableConstraint::FulltextOrSpatial { .. } => Err(DataFusionError::Plan( - "Indexes are not currently supported".to_string(), - )), + TableConstraint::ForeignKey { .. } => { + _plan_err!("Foreign key constraints are not currently supported") + } + TableConstraint::Check { .. } => { + _plan_err!("Check constraints are not currently supported") + } + TableConstraint::Index { .. } => { + _plan_err!("Indexes are not currently supported") + } + TableConstraint::FulltextOrSpatial { .. } => { + _plan_err!("Indexes are not currently supported") + } }) .collect::>>()?; - Ok(Constraints::new(constraints)) + Ok(Constraints::new_unverified(constraints)) } /// Check whether constraints is empty @@ -113,6 +120,15 @@ impl Constraints { } } +impl IntoIterator for Constraints { + type Item = Constraint; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.inner.into_iter() + } +} + impl Display for Constraints { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { let pk: Vec = self.inner.iter().map(|c| format!("{:?}", c)).collect(); @@ -256,6 +272,29 @@ impl FunctionalDependencies { self.deps.extend(other.deps); } + /// Sanity checks if functional dependencies are valid. For example, if + /// there are 10 fields, we cannot receive any index further than 9. + pub fn is_valid(&self, n_field: usize) -> bool { + self.deps.iter().all( + |FunctionalDependence { + source_indices, + target_indices, + .. + }| { + source_indices + .iter() + .max() + .map(|&max_index| max_index < n_field) + .unwrap_or(true) + && target_indices + .iter() + .max() + .map(|&max_index| max_index < n_field) + .unwrap_or(true) + }, + ) + } + /// Adds the `offset` value to `source_indices` and `target_indices` for /// each functional dependency. pub fn add_offset(&mut self, offset: usize) { @@ -398,6 +437,14 @@ impl FunctionalDependencies { } } +impl Deref for FunctionalDependencies { + type Target = [FunctionalDependence]; + + fn deref(&self) -> &Self::Target { + self.deps.as_slice() + } +} + /// Calculates functional dependencies for aggregate output, when there is a GROUP BY expression. pub fn aggregate_functional_dependencies( aggr_input_schema: &DFSchema, @@ -419,44 +466,56 @@ pub fn aggregate_functional_dependencies( } in &func_dependencies.deps { // Keep source indices in a `HashSet` to prevent duplicate entries: - let mut new_source_indices = HashSet::new(); + let mut new_source_indices = vec![]; + let mut new_source_field_names = vec![]; let source_field_names = source_indices .iter() .map(|&idx| aggr_input_fields[idx].qualified_name()) .collect::>(); + for (idx, group_by_expr_name) in group_by_expr_names.iter().enumerate() { // When one of the input determinant expressions matches with // the GROUP BY expression, add the index of the GROUP BY // expression as a new determinant key: if source_field_names.contains(group_by_expr_name) { - new_source_indices.insert(idx); + new_source_indices.push(idx); + new_source_field_names.push(group_by_expr_name.clone()); } } + let existing_target_indices = + get_target_functional_dependencies(aggr_input_schema, group_by_expr_names); + let new_target_indices = get_target_functional_dependencies( + aggr_input_schema, + &new_source_field_names, + ); + let mode = if existing_target_indices == new_target_indices + && new_target_indices.is_some() + { + // If dependency covers all GROUP BY expressions, mode will be `Single`: + Dependency::Single + } else { + // Otherwise, existing mode is preserved: + *mode + }; // All of the composite indices occur in the GROUP BY expression: if new_source_indices.len() == source_indices.len() { aggregate_func_dependencies.push( FunctionalDependence::new( - new_source_indices.into_iter().collect(), + new_source_indices, target_indices.clone(), *nullable, ) - // input uniqueness stays the same when GROUP BY matches with input functional dependence determinants - .with_mode(*mode), + .with_mode(mode), ); } } + // If we have a single GROUP BY key, we can guarantee uniqueness after // aggregation: if group_by_expr_names.len() == 1 { // If `source_indices` contain 0, delete this functional dependency // as it will be added anyway with mode `Dependency::Single`: - if let Some(idx) = aggregate_func_dependencies - .iter() - .position(|item| item.source_indices.contains(&0)) - { - // Delete the functional dependency that contains zeroth idx: - aggregate_func_dependencies.remove(idx); - } + aggregate_func_dependencies.retain(|item| !item.source_indices.contains(&0)); // Add a new functional dependency associated with the whole table: aggregate_func_dependencies.push( // Use nullable property of the group by expression @@ -504,8 +563,61 @@ pub fn get_target_functional_dependencies( combined_target_indices.extend(target_indices.iter()); } } - (!combined_target_indices.is_empty()) - .then_some(combined_target_indices.iter().cloned().collect::>()) + (!combined_target_indices.is_empty()).then_some({ + let mut result = combined_target_indices.into_iter().collect::>(); + result.sort(); + result + }) +} + +/// Returns indices for the minimal subset of GROUP BY expressions that are +/// functionally equivalent to the original set of GROUP BY expressions. +pub fn get_required_group_by_exprs_indices( + schema: &DFSchema, + group_by_expr_names: &[String], +) -> Option> { + let dependencies = schema.functional_dependencies(); + let field_names = schema + .fields() + .iter() + .map(|item| item.qualified_name()) + .collect::>(); + let mut groupby_expr_indices = group_by_expr_names + .iter() + .map(|group_by_expr_name| { + field_names + .iter() + .position(|field_name| field_name == group_by_expr_name) + }) + .collect::>>()?; + + groupby_expr_indices.sort(); + for FunctionalDependence { + source_indices, + target_indices, + .. + } in &dependencies.deps + { + if source_indices + .iter() + .all(|source_idx| groupby_expr_indices.contains(source_idx)) + { + // If all source indices are among GROUP BY expression indices, we + // can remove target indices from GROUP BY expression indices and + // use source indices instead. + groupby_expr_indices = set_difference(&groupby_expr_indices, target_indices); + groupby_expr_indices = + merge_and_order_indices(groupby_expr_indices, source_indices); + } + } + groupby_expr_indices + .iter() + .map(|idx| { + group_by_expr_names + .iter() + .position(|name| &field_names[*idx] == name) + }) + .collect() } /// Updates entries inside the `entries` vector with their corresponding @@ -534,7 +646,7 @@ mod tests { #[test] fn constraints_iter() { - let constraints = Constraints::new(vec![ + let constraints = Constraints::new_unverified(vec![ Constraint::PrimaryKey(vec![10]), Constraint::Unique(vec![20]), ]); @@ -543,4 +655,21 @@ mod tests { assert_eq!(iter.next(), Some(&Constraint::Unique(vec![20]))); assert_eq!(iter.next(), None); } + + #[test] + fn test_get_updated_id_keys() { + let fund_dependencies = + FunctionalDependencies::new(vec![FunctionalDependence::new( + vec![1], + vec![0, 1, 2], + true, + )]); + let res = fund_dependencies.project_functional_dependencies(&[1, 2], 2); + let expected = FunctionalDependencies::new(vec![FunctionalDependence::new( + vec![0], + vec![0, 1], + true, + )]); + assert_eq!(res, expected); + } } diff --git a/datafusion/physical-expr/src/hash_utils.rs b/datafusion/common/src/hash_utils.rs similarity index 80% rename from datafusion/physical-expr/src/hash_utils.rs rename to datafusion/common/src/hash_utils.rs index 379e0eba5277..8dcc00ca1c29 100644 --- a/datafusion/physical-expr/src/hash_utils.rs +++ b/datafusion/common/src/hash_utils.rs @@ -17,19 +17,20 @@ //! Functionality used both on logical and physical plans +use std::sync::Arc; + use ahash::RandomState; use arrow::array::*; use arrow::datatypes::*; use arrow::row::Rows; use arrow::{downcast_dictionary_array, downcast_primitive_array}; use arrow_buffer::i256; -use datafusion_common::{ - cast::{ - as_boolean_array, as_generic_binary_array, as_primitive_array, as_string_array, - }, - internal_err, DataFusionError, Result, + +use crate::cast::{ + as_boolean_array, as_generic_binary_array, as_large_list_array, as_list_array, + as_primitive_array, as_string_array, as_struct_array, }; -use std::sync::Arc; +use crate::error::{DataFusionError, Result, _internal_err}; // Combines two hashes into one hash #[inline] @@ -51,7 +52,7 @@ fn hash_null(random_state: &RandomState, hashes_buffer: &'_ mut [u64], mul_col: } } -pub(crate) trait HashValue { +pub trait HashValue { fn hash_one(&self, state: &RandomState) -> u64; } @@ -207,6 +208,32 @@ fn hash_dictionary( Ok(()) } +fn hash_struct_array( + array: &StructArray, + random_state: &RandomState, + hashes_buffer: &mut [u64], +) -> Result<()> { + let nulls = array.nulls(); + let row_len = array.len(); + + let valid_row_indices: Vec = if let Some(nulls) = nulls { + nulls.valid_indices().collect() + } else { + (0..row_len).collect() + }; + + // Create hashes for each row that combines the hashes over all the column at that row. + let mut values_hashes = vec![0u64; row_len]; + create_hashes(array.columns(), random_state, &mut values_hashes)?; + + for i in valid_row_indices { + let hash = &mut hashes_buffer[i]; + *hash = combine_hashes(*hash, values_hashes[i]); + } + + Ok(()) +} + fn hash_list_array( array: &GenericListArray, random_state: &RandomState, @@ -327,17 +354,21 @@ pub fn create_hashes<'a>( array => hash_dictionary(array, random_state, hashes_buffer, rehash)?, _ => unreachable!() } + DataType::Struct(_) => { + let array = as_struct_array(array)?; + hash_struct_array(array, random_state, hashes_buffer)?; + } DataType::List(_) => { - let array = as_list_array(array); + let array = as_list_array(array)?; hash_list_array(array, random_state, hashes_buffer)?; } DataType::LargeList(_) => { - let array = as_large_list_array(array); + let array = as_large_list_array(array)?; hash_list_array(array, random_state, hashes_buffer)?; } _ => { // This is internal because we should have caught this before. - return internal_err!( + return _internal_err!( "Unsupported data type in hasher: {}", col.data_type() ); @@ -515,6 +546,91 @@ mod tests { assert_eq!(hashes[2], hashes[3]); } + #[test] + // Tests actual values of hashes, which are different if forcing collisions + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_struct_arrays() { + use arrow_buffer::Buffer; + + let boolarr = Arc::new(BooleanArray::from(vec![ + false, false, true, true, true, true, + ])); + let i32arr = Arc::new(Int32Array::from(vec![10, 10, 20, 20, 30, 31])); + + let struct_array = StructArray::from(( + vec![ + ( + Arc::new(Field::new("bool", DataType::Boolean, false)), + boolarr.clone() as ArrayRef, + ), + ( + Arc::new(Field::new("i32", DataType::Int32, false)), + i32arr.clone() as ArrayRef, + ), + ( + Arc::new(Field::new("i32", DataType::Int32, false)), + i32arr.clone() as ArrayRef, + ), + ( + Arc::new(Field::new("bool", DataType::Boolean, false)), + boolarr.clone() as ArrayRef, + ), + ], + Buffer::from(&[0b001011]), + )); + + assert!(struct_array.is_valid(0)); + assert!(struct_array.is_valid(1)); + assert!(struct_array.is_null(2)); + assert!(struct_array.is_valid(3)); + assert!(struct_array.is_null(4)); + assert!(struct_array.is_null(5)); + + let array = Arc::new(struct_array) as ArrayRef; + + let random_state = RandomState::with_seeds(0, 0, 0, 0); + let mut hashes = vec![0; array.len()]; + create_hashes(&[array], &random_state, &mut hashes).unwrap(); + assert_eq!(hashes[0], hashes[1]); + // same value but the third row ( hashes[2] ) is null + assert_ne!(hashes[2], hashes[3]); + // different values but both are null + assert_eq!(hashes[4], hashes[5]); + } + + #[test] + // Tests actual values of hashes, which are different if forcing collisions + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_struct_arrays_more_column_than_row() { + let struct_array = StructArray::from(vec![ + ( + Arc::new(Field::new("bool", DataType::Boolean, false)), + Arc::new(BooleanArray::from(vec![false, false])) as ArrayRef, + ), + ( + Arc::new(Field::new("i32-1", DataType::Int32, false)), + Arc::new(Int32Array::from(vec![10, 10])) as ArrayRef, + ), + ( + Arc::new(Field::new("i32-2", DataType::Int32, false)), + Arc::new(Int32Array::from(vec![10, 10])) as ArrayRef, + ), + ( + Arc::new(Field::new("i32-3", DataType::Int32, false)), + Arc::new(Int32Array::from(vec![10, 10])) as ArrayRef, + ), + ]); + + assert!(struct_array.is_valid(0)); + assert!(struct_array.is_valid(1)); + + let array = Arc::new(struct_array) as ArrayRef; + let random_state = RandomState::with_seeds(0, 0, 0, 0); + let mut hashes = vec![0; array.len()]; + create_hashes(&[array], &random_state, &mut hashes).unwrap(); + assert_eq!(hashes[0], hashes[1]); + } + #[test] // Tests actual values of hashes, which are different if forcing collisions #[cfg(not(feature = "force_hash_collisions"))] diff --git a/datafusion/common/src/join_type.rs b/datafusion/common/src/join_type.rs index 8d4657f1dc56..0a00a57ba45f 100644 --- a/datafusion/common/src/join_type.rs +++ b/datafusion/common/src/join_type.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! [`JoinType`] and [`JoinConstraint`] +//! Defines the [`JoinType`], [`JoinConstraint`] and [`JoinSide`] types. use std::{ fmt::{self, Display, Formatter}, @@ -95,3 +95,32 @@ pub enum JoinConstraint { /// Join USING Using, } + +impl Display for JoinSide { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + JoinSide::Left => write!(f, "left"), + JoinSide::Right => write!(f, "right"), + } + } +} + +/// Join side. +/// Stores the referred table side during calculations +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum JoinSide { + /// Left side of the join + Left, + /// Right side of the join + Right, +} + +impl JoinSide { + /// Inverse the join side + pub fn negate(&self) -> Self { + match self { + JoinSide::Left => JoinSide::Right, + JoinSide::Right => JoinSide::Left, + } + } +} diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index eeb5b2681370..ed547782e4a5 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -15,46 +15,53 @@ // specific language governing permissions and limitations // under the License. -pub mod alias; -pub mod cast; mod column; -pub mod config; mod dfschema; -pub mod display; mod error; -pub mod file_options; -pub mod format; mod functional_dependencies; mod join_type; -pub mod parsers; +mod param_value; #[cfg(feature = "pyarrow")] mod pyarrow; -pub mod scalar; mod schema_reference; -pub mod stats; mod table_reference; +mod unnest; + +pub mod alias; +pub mod cast; +pub mod config; +pub mod display; +pub mod file_options; +pub mod format; +pub mod hash_utils; +pub mod parsers; +pub mod rounding; +pub mod scalar; +pub mod stats; pub mod test_util; pub mod tree_node; -mod unnest; pub mod utils; +/// Reexport arrow crate +pub use arrow; pub use column::Column; pub use dfschema::{DFField, DFSchema, DFSchemaRef, ExprSchema, SchemaExt, ToDFSchema}; pub use error::{ field_not_found, unqualified_field_not_found, DataFusionError, Result, SchemaError, SharedResult, }; - pub use file_options::file_type::{ FileType, GetExt, DEFAULT_ARROW_EXTENSION, DEFAULT_AVRO_EXTENSION, DEFAULT_CSV_EXTENSION, DEFAULT_JSON_EXTENSION, DEFAULT_PARQUET_EXTENSION, }; pub use file_options::FileTypeWriterOptions; pub use functional_dependencies::{ - aggregate_functional_dependencies, get_target_functional_dependencies, Constraints, - Dependency, FunctionalDependence, FunctionalDependencies, + aggregate_functional_dependencies, get_required_group_by_exprs_indices, + get_target_functional_dependencies, Constraint, Constraints, Dependency, + FunctionalDependence, FunctionalDependencies, }; -pub use join_type::{JoinConstraint, JoinType}; +pub use join_type::{JoinConstraint, JoinSide, JoinType}; +pub use param_value::ParamValues; pub use scalar::{ScalarType, ScalarValue}; pub use schema_reference::{OwnedSchemaReference, SchemaReference}; pub use stats::{ColumnStatistics, Statistics}; @@ -62,9 +69,6 @@ pub use table_reference::{OwnedTableReference, ResolvedTableReference, TableRefe pub use unnest::UnnestOptions; pub use utils::project_schema; -/// Reexport arrow crate -pub use arrow; - /// Downcast an Arrow Array to a concrete type, return an `DataFusionError::Internal` if the cast is /// not possible. In normal usage of DataFusion the downcast should always succeed. /// diff --git a/datafusion/common/src/param_value.rs b/datafusion/common/src/param_value.rs new file mode 100644 index 000000000000..3fe2ba99ab83 --- /dev/null +++ b/datafusion/common/src/param_value.rs @@ -0,0 +1,152 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::error::{_internal_err, _plan_err}; +use crate::{DataFusionError, Result, ScalarValue}; +use arrow_schema::DataType; +use std::collections::HashMap; + +/// The parameter value corresponding to the placeholder +#[derive(Debug, Clone)] +pub enum ParamValues { + /// For positional query parameters, like `SELECT * FROM test WHERE a > $1 AND b = $2` + List(Vec), + /// For named query parameters, like `SELECT * FROM test WHERE a > $foo AND b = $goo` + Map(HashMap), +} + +impl ParamValues { + /// Verify parameter list length and type + pub fn verify(&self, expect: &[DataType]) -> Result<()> { + match self { + ParamValues::List(list) => { + // Verify if the number of params matches the number of values + if expect.len() != list.len() { + return _plan_err!( + "Expected {} parameters, got {}", + expect.len(), + list.len() + ); + } + + // Verify if the types of the params matches the types of the values + let iter = expect.iter().zip(list.iter()); + for (i, (param_type, value)) in iter.enumerate() { + if *param_type != value.data_type() { + return _plan_err!( + "Expected parameter of type {:?}, got {:?} at index {}", + param_type, + value.data_type(), + i + ); + } + } + Ok(()) + } + ParamValues::Map(_) => { + // If it is a named query, variables can be reused, + // but the lengths are not necessarily equal + Ok(()) + } + } + } + + pub fn get_placeholders_with_values( + &self, + id: &str, + data_type: Option<&DataType>, + ) -> Result { + match self { + ParamValues::List(list) => { + if id.is_empty() { + return _plan_err!("Empty placeholder id"); + } + // convert id (in format $1, $2, ..) to idx (0, 1, ..) + let idx = id[1..] + .parse::() + .map_err(|e| { + DataFusionError::Internal(format!( + "Failed to parse placeholder id: {e}" + )) + })? + .checked_sub(1); + // value at the idx-th position in param_values should be the value for the placeholder + let value = idx.and_then(|idx| list.get(idx)).ok_or_else(|| { + DataFusionError::Internal(format!( + "No value found for placeholder with id {id}" + )) + })?; + // check if the data type of the value matches the data type of the placeholder + if Some(&value.data_type()) != data_type { + return _internal_err!( + "Placeholder value type mismatch: expected {:?}, got {:?}", + data_type, + value.data_type() + ); + } + Ok(value.clone()) + } + ParamValues::Map(map) => { + // convert name (in format $a, $b, ..) to mapped values (a, b, ..) + let name = &id[1..]; + // value at the name position in param_values should be the value for the placeholder + let value = map.get(name).ok_or_else(|| { + DataFusionError::Internal(format!( + "No value found for placeholder with name {id}" + )) + })?; + // check if the data type of the value matches the data type of the placeholder + if Some(&value.data_type()) != data_type { + return _internal_err!( + "Placeholder value type mismatch: expected {:?}, got {:?}", + data_type, + value.data_type() + ); + } + Ok(value.clone()) + } + } + } +} + +impl From> for ParamValues { + fn from(value: Vec) -> Self { + Self::List(value) + } +} + +impl From> for ParamValues +where + K: Into, +{ + fn from(value: Vec<(K, ScalarValue)>) -> Self { + let value: HashMap = + value.into_iter().map(|(k, v)| (k.into(), v)).collect(); + Self::Map(value) + } +} + +impl From> for ParamValues +where + K: Into, +{ + fn from(value: HashMap) -> Self { + let value: HashMap = + value.into_iter().map(|(k, v)| (k.into(), v)).collect(); + Self::Map(value) + } +} diff --git a/datafusion/common/src/pyarrow.rs b/datafusion/common/src/pyarrow.rs index d18782e037ae..f4356477532f 100644 --- a/datafusion/common/src/pyarrow.rs +++ b/datafusion/common/src/pyarrow.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! PyArrow +//! Conversions between PyArrow and DataFusion types use arrow::array::ArrayData; use arrow::pyarrow::{FromPyArrow, ToPyArrow}; @@ -54,7 +54,7 @@ impl FromPyArrow for ScalarValue { impl ToPyArrow for ScalarValue { fn to_pyarrow(&self, py: Python) -> PyResult { - let array = self.to_array(); + let array = self.to_array()?; // convert to pyarrow array using C data interface let pyarray = array.to_data().to_pyarrow(py)?; let pyscalar = pyarray.call_method1(py, "__getitem__", (0,))?; @@ -94,10 +94,11 @@ mod tests { Some(locals), ) .expect("Couldn't get python info"); - let executable: String = - locals.get_item("executable").unwrap().extract().unwrap(); - let python_path: Vec<&str> = - locals.get_item("python_path").unwrap().extract().unwrap(); + let executable = locals.get_item("executable").unwrap().unwrap(); + let executable: String = executable.extract().unwrap(); + + let python_path = locals.get_item("python_path").unwrap().unwrap(); + let python_path: Vec<&str> = python_path.extract().unwrap(); panic!("pyarrow not found\nExecutable: {executable}\nPython path: {python_path:?}\n\ HINT: try `pip install pyarrow`\n\ @@ -118,7 +119,7 @@ mod tests { ScalarValue::Boolean(Some(true)), ScalarValue::Int32(Some(23)), ScalarValue::Float64(Some(12.34)), - ScalarValue::Utf8(Some("Hello!".to_string())), + ScalarValue::from("Hello!"), ScalarValue::Date32(Some(1234)), ]; diff --git a/datafusion/physical-expr/src/intervals/rounding.rs b/datafusion/common/src/rounding.rs similarity index 98% rename from datafusion/physical-expr/src/intervals/rounding.rs rename to datafusion/common/src/rounding.rs index c1172fba9152..413067ecd61e 100644 --- a/datafusion/physical-expr/src/intervals/rounding.rs +++ b/datafusion/common/src/rounding.rs @@ -22,8 +22,8 @@ use std::ops::{Add, BitAnd, Sub}; -use datafusion_common::Result; -use datafusion_common::ScalarValue; +use crate::Result; +use crate::ScalarValue; // Define constants for ARM #[cfg(all(target_arch = "aarch64", not(target_os = "windows")))] @@ -162,7 +162,7 @@ impl FloatBits for f64 { /// # Examples /// /// ``` -/// use datafusion_physical_expr::intervals::rounding::next_up; +/// use datafusion_common::rounding::next_up; /// /// let f: f32 = 1.0; /// let next_f = next_up(f); @@ -195,7 +195,7 @@ pub fn next_up(float: F) -> F { /// # Examples /// /// ``` -/// use datafusion_physical_expr::intervals::rounding::next_down; +/// use datafusion_common::rounding::next_down; /// /// let f: f32 = 1.0; /// let next_f = next_down(f); diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 32343b98fa24..48878aa9bd99 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -24,32 +24,78 @@ use std::convert::{Infallible, TryInto}; use std::str::FromStr; use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; +use crate::arrow_datafusion_err; use crate::cast::{ as_decimal128_array, as_decimal256_array, as_dictionary_array, - as_fixed_size_binary_array, as_fixed_size_list_array, as_list_array, as_struct_array, + as_fixed_size_binary_array, as_fixed_size_list_array, as_struct_array, }; use crate::error::{DataFusionError, Result, _internal_err, _not_impl_err}; -use arrow::buffer::NullBuffer; +use crate::hash_utils::create_hashes; +use crate::utils::{array_into_large_list_array, array_into_list_array}; use arrow::compute::kernels::numeric::*; -use arrow::compute::nullif; -use arrow::datatypes::{i256, FieldRef, Fields, SchemaBuilder}; +use arrow::datatypes::{i256, Fields, SchemaBuilder}; +use arrow::util::display::{ArrayFormatter, FormatOptions}; use arrow::{ array::*, compute::kernels::cast::{cast_with_options, CastOptions}, datatypes::{ - ArrowDictionaryKeyType, ArrowNativeType, DataType, Field, Float32Type, - Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTimeType, - IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, TimeUnit, - TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, - TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, - DECIMAL128_MAX_PRECISION, + ArrowDictionaryKeyType, ArrowNativeType, DataType, Field, Float32Type, Int16Type, + Int32Type, Int64Type, Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, + IntervalUnit, IntervalYearMonthType, TimeUnit, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, + UInt16Type, UInt32Type, UInt64Type, UInt8Type, DECIMAL128_MAX_PRECISION, }, }; +use arrow_array::cast::as_list_array; +use arrow_array::types::ArrowTimestampType; use arrow_array::{ArrowNativeTypeOp, Scalar}; -/// Represents a dynamically typed, nullable single value. -/// This is the single-valued counter-part to arrow's [`Array`]. +/// A dynamically typed, nullable single value, (the single-valued counter-part +/// to arrow's [`Array`]) /// +/// # Performance +/// +/// In general, please use arrow [`Array`]s rather than [`ScalarValue`] whenever +/// possible, as it is far more efficient for multiple values. +/// +/// # Example +/// ``` +/// # use datafusion_common::ScalarValue; +/// // Create single scalar value for an Int32 value +/// let s1 = ScalarValue::Int32(Some(10)); +/// +/// // You can also create values using the From impl: +/// let s2 = ScalarValue::from(10i32); +/// assert_eq!(s1, s2); +/// ``` +/// +/// # Null Handling +/// +/// `ScalarValue` represents null values in the same way as Arrow. Nulls are +/// "typed" in the sense that a null value in an [`Int32Array`] is different +/// than a null value in a [`Float64Array`], and is different than the values in +/// a [`NullArray`]. +/// +/// ``` +/// # fn main() -> datafusion_common::Result<()> { +/// # use std::collections::hash_set::Difference; +/// # use datafusion_common::ScalarValue; +/// # use arrow::datatypes::DataType; +/// // You can create a 'null' Int32 value directly: +/// let s1 = ScalarValue::Int32(None); +/// +/// // You can also create a null value for a given datatype: +/// let s2 = ScalarValue::try_from(&DataType::Int32)?; +/// assert_eq!(s1, s2); +/// +/// // Note that this is DIFFERENT than a `ScalarValue::Null` +/// let s3 = ScalarValue::Null; +/// assert_ne!(s1, s3); +/// # Ok(()) +/// # } +/// ``` +/// +/// # Further Reading /// See [datatypes](https://arrow.apache.org/docs/python/api/datatypes.html) for /// details on datatypes and the [format](https://github.com/apache/arrow/blob/master/format/Schema.fbs#L354-L375) /// for the definitive reference. @@ -93,10 +139,16 @@ pub enum ScalarValue { FixedSizeBinary(i32, Option>), /// large binary LargeBinary(Option>), - /// Fixed size list of nested ScalarValue - Fixedsizelist(Option>, FieldRef, i32), - /// List of nested ScalarValue - List(Option>, FieldRef), + /// Fixed size list scalar. + /// + /// The array must be a FixedSizeListArray with length 1. + FixedSizeList(ArrayRef), + /// Represents a single element of a [`ListArray`] as an [`ArrayRef`] + /// + /// The array must be a ListArray with length 1. + List(ArrayRef), + /// The array must be a LargeListArray with length 1. + LargeList(ArrayRef), /// Date stored as a signed 32bit int days since UNIX epoch 1970-01-01 Date32(Option), /// Date stored as a signed 64bit int milliseconds since UNIX epoch 1970-01-01 @@ -194,12 +246,12 @@ impl PartialEq for ScalarValue { (FixedSizeBinary(_, _), _) => false, (LargeBinary(v1), LargeBinary(v2)) => v1.eq(v2), (LargeBinary(_), _) => false, - (Fixedsizelist(v1, t1, l1), Fixedsizelist(v2, t2, l2)) => { - v1.eq(v2) && t1.eq(t2) && l1.eq(l2) - } - (Fixedsizelist(_, _, _), _) => false, - (List(v1, t1), List(v2, t2)) => v1.eq(v2) && t1.eq(t2), - (List(_, _), _) => false, + (FixedSizeList(v1), FixedSizeList(v2)) => v1.eq(v2), + (FixedSizeList(_), _) => false, + (List(v1), List(v2)) => v1.eq(v2), + (List(_), _) => false, + (LargeList(v1), LargeList(v2)) => v1.eq(v2), + (LargeList(_), _) => false, (Date32(v1), Date32(v2)) => v1.eq(v2), (Date32(_), _) => false, (Date64(v1), Date64(v2)) => v1.eq(v2), @@ -308,22 +360,47 @@ impl PartialOrd for ScalarValue { (FixedSizeBinary(_, _), _) => None, (LargeBinary(v1), LargeBinary(v2)) => v1.partial_cmp(v2), (LargeBinary(_), _) => None, - (Fixedsizelist(v1, t1, l1), Fixedsizelist(v2, t2, l2)) => { - if t1.eq(t2) && l1.eq(l2) { - v1.partial_cmp(v2) - } else { - None + (List(arr1), List(arr2)) + | (FixedSizeList(arr1), FixedSizeList(arr2)) + | (LargeList(arr1), LargeList(arr2)) => { + // ScalarValue::List / ScalarValue::FixedSizeList / ScalarValue::LargeList are ensure to have length 1 + assert_eq!(arr1.len(), 1); + assert_eq!(arr2.len(), 1); + + if arr1.data_type() != arr2.data_type() { + return None; } - } - (Fixedsizelist(_, _, _), _) => None, - (List(v1, t1), List(v2, t2)) => { - if t1.eq(t2) { - v1.partial_cmp(v2) - } else { - None + + fn first_array_for_list(arr: &ArrayRef) -> ArrayRef { + if let Some(arr) = arr.as_list_opt::() { + arr.value(0) + } else if let Some(arr) = arr.as_list_opt::() { + arr.value(0) + } else if let Some(arr) = arr.as_fixed_size_list_opt() { + arr.value(0) + } else { + unreachable!("Since only List / LargeList / FixedSizeList are supported, this should never happen") + } + } + + let arr1 = first_array_for_list(arr1); + let arr2 = first_array_for_list(arr2); + + let lt_res = arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?; + let eq_res = arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?; + + for j in 0..lt_res.len() { + if lt_res.is_valid(j) && lt_res.value(j) { + return Some(Ordering::Less); + } + if eq_res.is_valid(j) && !eq_res.value(j) { + return Some(Ordering::Greater); + } } + + Some(Ordering::Equal) } - (List(_, _), _) => None, + (List(_), _) | (LargeList(_), _) | (FixedSizeList(_), _) => None, (Date32(v1), Date32(v2)) => v1.partial_cmp(v2), (Date32(_), _) => None, (Date64(v1), Date64(v2)) => v1.partial_cmp(v2), @@ -406,6 +483,10 @@ macro_rules! hash_float_value { hash_float_value!((f64, u64), (f32, u32)); // manual implementation of `Hash` +// +// # Panics +// +// Panics if there is an error when creating hash values for rows impl std::hash::Hash for ScalarValue { fn hash(&self, state: &mut H) { use ScalarValue::*; @@ -436,14 +517,14 @@ impl std::hash::Hash for ScalarValue { Binary(v) => v.hash(state), FixedSizeBinary(_, v) => v.hash(state), LargeBinary(v) => v.hash(state), - Fixedsizelist(v, t, l) => { - v.hash(state); - t.hash(state); - l.hash(state); - } - List(v, t) => { - v.hash(state); - t.hash(state); + List(arr) | LargeList(arr) | FixedSizeList(arr) => { + let arrays = vec![arr.to_owned()]; + let hashes_buffer = &mut vec![0; arr.len()]; + let random_state = ahash::RandomState::with_seeds(0, 0, 0, 0); + let hashes = + create_hashes(&arrays, &random_state, hashes_buffer).unwrap(); + // Hash back to std::hash::Hasher + hashes.hash(state); } Date32(v) => v.hash(state), Date64(v) => v.hash(state), @@ -476,15 +557,19 @@ impl std::hash::Hash for ScalarValue { } } -/// return a reference to the values array and the index into it for a +/// Return a reference to the values array and the index into it for a /// dictionary array +/// +/// # Errors +/// +/// Errors if the array cannot be downcasted to DictionaryArray #[inline] pub fn get_dict_value( array: &dyn Array, index: usize, -) -> (&ArrayRef, Option) { - let dict_array = as_dictionary_array::(array).unwrap(); - (dict_array.values(), dict_array.key(index)) +) -> Result<(&ArrayRef, Option)> { + let dict_array = as_dictionary_array::(array)?; + Ok((dict_array.values(), dict_array.key(index))) } /// Create a dictionary array representing `value` repeated `size` @@ -492,9 +577,9 @@ pub fn get_dict_value( fn dict_from_scalar( value: &ScalarValue, size: usize, -) -> ArrayRef { +) -> Result { // values array is one element long (the value) - let values_array = value.to_array_of_size(1); + let values_array = value.to_array_of_size(1)?; // Create a key array with `size` elements, each of 0 let key_array: PrimitiveArray = std::iter::repeat(Some(K::default_value())) @@ -506,11 +591,9 @@ fn dict_from_scalar( // Note: this path could be made faster by using the ArrayData // APIs and skipping validation, if it every comes up in // performance traces. - Arc::new( - DictionaryArray::::try_new(key_array, values_array) - // should always be valid by construction above - .expect("Can not construct dictionary array"), - ) + Ok(Arc::new( + DictionaryArray::::try_new(key_array, values_array)?, // should always be valid by construction above + )) } /// Create a dictionary array representing all the values in values @@ -549,152 +632,44 @@ fn dict_from_values( macro_rules! typed_cast_tz { ($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident, $TZ:expr) => {{ - let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); - ScalarValue::$SCALAR( + use std::any::type_name; + let array = $array + .as_any() + .downcast_ref::<$ARRAYTYPE>() + .ok_or_else(|| { + DataFusionError::Internal(format!( + "could not cast value to {}", + type_name::<$ARRAYTYPE>() + )) + })?; + Ok::(ScalarValue::$SCALAR( match array.is_null($index) { true => None, false => Some(array.value($index).into()), }, $TZ.clone(), - ) + )) }}; } macro_rules! typed_cast { ($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident) => {{ - let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); - ScalarValue::$SCALAR(match array.is_null($index) { - true => None, - false => Some(array.value($index).into()), - }) - }}; -} - -// keep until https://github.com/apache/arrow-rs/issues/2054 is finished -macro_rules! build_list { - ($VALUE_BUILDER_TY:ident, $SCALAR_TY:ident, $VALUES:expr, $SIZE:expr) => {{ - match $VALUES { - // the return on the macro is necessary, to short-circuit and return ArrayRef - None => { - return new_null_array( - &DataType::List(Arc::new(Field::new( - "item", - DataType::$SCALAR_TY, - true, - ))), - $SIZE, - ) - } - Some(values) => { - build_values_list!($VALUE_BUILDER_TY, $SCALAR_TY, values, $SIZE) - } - } - }}; -} - -macro_rules! build_timestamp_list { - ($TIME_UNIT:expr, $TIME_ZONE:expr, $VALUES:expr, $SIZE:expr) => {{ - match $VALUES { - // the return on the macro is necessary, to short-circuit and return ArrayRef - None => { - return new_null_array( - &DataType::List(Arc::new(Field::new( - "item", - DataType::Timestamp($TIME_UNIT, $TIME_ZONE), - true, - ))), - $SIZE, - ) - } - Some(values) => match $TIME_UNIT { - TimeUnit::Second => { - build_values_list_tz!( - TimestampSecondBuilder, - TimestampSecond, - values, - $SIZE - ) - } - TimeUnit::Millisecond => build_values_list_tz!( - TimestampMillisecondBuilder, - TimestampMillisecond, - values, - $SIZE - ), - TimeUnit::Microsecond => build_values_list_tz!( - TimestampMicrosecondBuilder, - TimestampMicrosecond, - values, - $SIZE - ), - TimeUnit::Nanosecond => build_values_list_tz!( - TimestampNanosecondBuilder, - TimestampNanosecond, - values, - $SIZE - ), + use std::any::type_name; + let array = $array + .as_any() + .downcast_ref::<$ARRAYTYPE>() + .ok_or_else(|| { + DataFusionError::Internal(format!( + "could not cast value to {}", + type_name::<$ARRAYTYPE>() + )) + })?; + Ok::(ScalarValue::$SCALAR( + match array.is_null($index) { + true => None, + false => Some(array.value($index).into()), }, - } - }}; -} - -macro_rules! new_builder { - (StringBuilder, $len:expr) => { - StringBuilder::new() - }; - (LargeStringBuilder, $len:expr) => { - LargeStringBuilder::new() - }; - ($el:ident, $len:expr) => {{ - <$el>::with_capacity($len) - }}; -} - -macro_rules! build_values_list { - ($VALUE_BUILDER_TY:ident, $SCALAR_TY:ident, $VALUES:expr, $SIZE:expr) => {{ - let builder = new_builder!($VALUE_BUILDER_TY, $VALUES.len()); - let mut builder = ListBuilder::new(builder); - - for _ in 0..$SIZE { - for scalar_value in $VALUES { - match scalar_value { - ScalarValue::$SCALAR_TY(Some(v)) => { - builder.values().append_value(v.clone()); - } - ScalarValue::$SCALAR_TY(None) => { - builder.values().append_null(); - } - _ => panic!("Incompatible ScalarValue for list"), - }; - } - builder.append(true); - } - - builder.finish() - }}; -} - -macro_rules! build_values_list_tz { - ($VALUE_BUILDER_TY:ident, $SCALAR_TY:ident, $VALUES:expr, $SIZE:expr) => {{ - let mut builder = - ListBuilder::new($VALUE_BUILDER_TY::with_capacity($VALUES.len())); - - for _ in 0..$SIZE { - for scalar_value in $VALUES { - match scalar_value { - ScalarValue::$SCALAR_TY(Some(v), _) => { - builder.values().append_value(v.clone()); - } - ScalarValue::$SCALAR_TY(None, _) => { - builder.values().append_null(); - } - _ => panic!("Incompatible ScalarValue for list"), - }; - } - builder.append(true); - } - - builder.finish() + )) }}; } @@ -726,17 +701,26 @@ macro_rules! build_timestamp_array_from_option { macro_rules! eq_array_primitive { ($array:expr, $index:expr, $ARRAYTYPE:ident, $VALUE:expr) => {{ - let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); + use std::any::type_name; + let array = $array + .as_any() + .downcast_ref::<$ARRAYTYPE>() + .ok_or_else(|| { + DataFusionError::Internal(format!( + "could not cast value to {}", + type_name::<$ARRAYTYPE>() + )) + })?; let is_valid = array.is_valid($index); - match $VALUE { + Ok::(match $VALUE { Some(val) => is_valid && &array.value($index) == val, None => !is_valid, - } + }) }}; } impl ScalarValue { - /// Create a [`ScalarValue`] with the provided value and datatype + /// Create a [`Result`] with the provided value and datatype /// /// # Panics /// @@ -744,13 +728,13 @@ impl ScalarValue { pub fn new_primitive( a: Option, d: &DataType, - ) -> Self { + ) -> Result { match a { - None => d.try_into().unwrap(), + None => d.try_into(), Some(v) => { let array = PrimitiveArray::::new(vec![v].into(), None) .with_data_type(d.clone()); - Self::try_from_array(&array, 0).unwrap() + Self::try_from_array(&array, 0) } } } @@ -768,7 +752,7 @@ impl ScalarValue { /// Returns a [`ScalarValue::Utf8`] representing `val` pub fn new_utf8(val: impl Into) -> Self { - ScalarValue::Utf8(Some(val.into())) + ScalarValue::from(val.into()) } /// Returns a [`ScalarValue::IntervalYearMonth`] representing @@ -792,9 +776,18 @@ impl ScalarValue { ScalarValue::IntervalMonthDayNano(Some(val)) } - /// Create a new nullable ScalarValue::List with the specified child_type - pub fn new_list(scalars: Option>, child_type: DataType) -> Self { - Self::List(scalars, Arc::new(Field::new("item", child_type, true))) + /// Returns a [`ScalarValue`] representing + /// `value` and `tz_opt` timezone + pub fn new_timestamp( + value: Option, + tz_opt: Option>, + ) -> Self { + match T::UNIT { + TimeUnit::Second => ScalarValue::TimestampSecond(value, tz_opt), + TimeUnit::Millisecond => ScalarValue::TimestampMillisecond(value, tz_opt), + TimeUnit::Microsecond => ScalarValue::TimestampMicrosecond(value, tz_opt), + TimeUnit::Nanosecond => ScalarValue::TimestampNanosecond(value, tz_opt), + } } /// Create a zero value in the given type. @@ -833,15 +826,15 @@ impl ScalarValue { DataType::Interval(IntervalUnit::MonthDayNano) => { ScalarValue::IntervalMonthDayNano(Some(0)) } - DataType::Duration(TimeUnit::Second) => ScalarValue::DurationSecond(None), + DataType::Duration(TimeUnit::Second) => ScalarValue::DurationSecond(Some(0)), DataType::Duration(TimeUnit::Millisecond) => { - ScalarValue::DurationMillisecond(None) + ScalarValue::DurationMillisecond(Some(0)) } DataType::Duration(TimeUnit::Microsecond) => { - ScalarValue::DurationMicrosecond(None) + ScalarValue::DurationMicrosecond(Some(0)) } DataType::Duration(TimeUnit::Nanosecond) => { - ScalarValue::DurationNanosecond(None) + ScalarValue::DurationNanosecond(Some(0)) } _ => { return _not_impl_err!( @@ -949,15 +942,9 @@ impl ScalarValue { ScalarValue::Binary(_) => DataType::Binary, ScalarValue::FixedSizeBinary(sz, _) => DataType::FixedSizeBinary(*sz), ScalarValue::LargeBinary(_) => DataType::LargeBinary, - ScalarValue::Fixedsizelist(_, field, length) => DataType::FixedSizeList( - Arc::new(Field::new("item", field.data_type().clone(), true)), - *length, - ), - ScalarValue::List(_, field) => DataType::List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - ))), + ScalarValue::List(arr) + | ScalarValue::LargeList(arr) + | ScalarValue::FixedSizeList(arr) => arr.data_type().to_owned(), ScalarValue::Date32(_) => DataType::Date32, ScalarValue::Date64(_) => DataType::Date64, ScalarValue::Time32Second(_) => DataType::Time32(TimeUnit::Second), @@ -1004,7 +991,8 @@ impl ScalarValue { | ScalarValue::Int16(None) | ScalarValue::Int32(None) | ScalarValue::Int64(None) - | ScalarValue::Float32(None) => Ok(self.clone()), + | ScalarValue::Float32(None) + | ScalarValue::Float64(None) => Ok(self.clone()), ScalarValue::Float64(Some(v)) => Ok(ScalarValue::Float64(Some(-v))), ScalarValue::Float32(Some(v)) => Ok(ScalarValue::Float32(Some(-v))), ScalarValue::Int8(Some(v)) => Ok(ScalarValue::Int8(Some(-v))), @@ -1030,6 +1018,18 @@ impl ScalarValue { ScalarValue::Decimal256(Some(v), precision, scale) => Ok( ScalarValue::Decimal256(Some(v.neg_wrapping()), *precision, *scale), ), + ScalarValue::TimestampSecond(Some(v), tz) => { + Ok(ScalarValue::TimestampSecond(Some(-v), tz.clone())) + } + ScalarValue::TimestampNanosecond(Some(v), tz) => { + Ok(ScalarValue::TimestampNanosecond(Some(-v), tz.clone())) + } + ScalarValue::TimestampMicrosecond(Some(v), tz) => { + Ok(ScalarValue::TimestampMicrosecond(Some(-v), tz.clone())) + } + ScalarValue::TimestampMillisecond(Some(v), tz) => { + Ok(ScalarValue::TimestampMillisecond(Some(-v), tz.clone())) + } value => _internal_err!( "Can not run arithmetic negative on scalar value {value:?}" ), @@ -1041,7 +1041,7 @@ impl ScalarValue { /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code /// should operate on Arrays directly, using vectorized array kernels pub fn add>(&self, other: T) -> Result { - let r = add_wrapping(&self.to_scalar(), &other.borrow().to_scalar())?; + let r = add_wrapping(&self.to_scalar()?, &other.borrow().to_scalar()?)?; Self::try_from_array(r.as_ref(), 0) } /// Checked addition of `ScalarValue` @@ -1049,7 +1049,7 @@ impl ScalarValue { /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code /// should operate on Arrays directly, using vectorized array kernels pub fn add_checked>(&self, other: T) -> Result { - let r = add(&self.to_scalar(), &other.borrow().to_scalar())?; + let r = add(&self.to_scalar()?, &other.borrow().to_scalar()?)?; Self::try_from_array(r.as_ref(), 0) } @@ -1058,7 +1058,7 @@ impl ScalarValue { /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code /// should operate on Arrays directly, using vectorized array kernels pub fn sub>(&self, other: T) -> Result { - let r = sub_wrapping(&self.to_scalar(), &other.borrow().to_scalar())?; + let r = sub_wrapping(&self.to_scalar()?, &other.borrow().to_scalar()?)?; Self::try_from_array(r.as_ref(), 0) } @@ -1067,7 +1067,49 @@ impl ScalarValue { /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code /// should operate on Arrays directly, using vectorized array kernels pub fn sub_checked>(&self, other: T) -> Result { - let r = sub(&self.to_scalar(), &other.borrow().to_scalar())?; + let r = sub(&self.to_scalar()?, &other.borrow().to_scalar()?)?; + Self::try_from_array(r.as_ref(), 0) + } + + /// Wrapping multiplication of `ScalarValue` + /// + /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code + /// should operate on Arrays directly, using vectorized array kernels. + pub fn mul>(&self, other: T) -> Result { + let r = mul_wrapping(&self.to_scalar()?, &other.borrow().to_scalar()?)?; + Self::try_from_array(r.as_ref(), 0) + } + + /// Checked multiplication of `ScalarValue` + /// + /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code + /// should operate on Arrays directly, using vectorized array kernels. + pub fn mul_checked>(&self, other: T) -> Result { + let r = mul(&self.to_scalar()?, &other.borrow().to_scalar()?)?; + Self::try_from_array(r.as_ref(), 0) + } + + /// Performs `lhs / rhs` + /// + /// Overflow or division by zero will result in an error, with exception to + /// floating point numbers, which instead follow the IEEE 754 rules. + /// + /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code + /// should operate on Arrays directly, using vectorized array kernels. + pub fn div>(&self, other: T) -> Result { + let r = div(&self.to_scalar()?, &other.borrow().to_scalar()?)?; + Self::try_from_array(r.as_ref(), 0) + } + + /// Performs `lhs % rhs` + /// + /// Overflow or division by zero will result in an error, with exception to + /// floating point numbers, which instead follow the IEEE 754 rules. + /// + /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code + /// should operate on Arrays directly, using vectorized array kernels. + pub fn rem>(&self, other: T) -> Result { + let r = rem(&self.to_scalar()?, &other.borrow().to_scalar()?)?; Self::try_from_array(r.as_ref(), 0) } @@ -1103,8 +1145,11 @@ impl ScalarValue { ScalarValue::Binary(v) => v.is_none(), ScalarValue::FixedSizeBinary(_, v) => v.is_none(), ScalarValue::LargeBinary(v) => v.is_none(), - ScalarValue::Fixedsizelist(v, ..) => v.is_none(), - ScalarValue::List(v, _) => v.is_none(), + // arr.len() should be 1 for a list scalar, but we don't seem to + // enforce that anywhere, so we still check against array length. + ScalarValue::List(arr) + | ScalarValue::LargeList(arr) + | ScalarValue::FixedSizeList(arr) => arr.len() == arr.null_count(), ScalarValue::Date32(v) => v.is_none(), ScalarValue::Date64(v) => v.is_none(), ScalarValue::Time32Second(v) => v.is_none(), @@ -1156,7 +1201,11 @@ impl ScalarValue { } /// Converts a scalar value into an 1-row array. - pub fn to_array(&self) -> ArrayRef { + /// + /// # Errors + /// + /// Errors if the ScalarValue cannot be converted into a 1-row array + pub fn to_array(&self) -> Result { self.to_array_of_size(1) } @@ -1165,6 +1214,10 @@ impl ScalarValue { /// /// This can be used to call arrow compute kernels such as `lt` /// + /// # Errors + /// + /// Errors if the ScalarValue cannot be converted into a 1-row array + /// /// # Example /// ``` /// use datafusion_common::ScalarValue; @@ -1175,7 +1228,7 @@ impl ScalarValue { /// /// let result = arrow::compute::kernels::cmp::lt( /// &arr, - /// &five.to_scalar(), + /// &five.to_scalar().unwrap(), /// ).unwrap(); /// /// let expected = BooleanArray::from(vec![ @@ -1188,16 +1241,21 @@ impl ScalarValue { /// assert_eq!(&result, &expected); /// ``` /// [`Datum`]: arrow_array::Datum - pub fn to_scalar(&self) -> Scalar { - Scalar::new(self.to_array_of_size(1)) + pub fn to_scalar(&self) -> Result> { + Ok(Scalar::new(self.to_array_of_size(1)?)) } /// Converts an iterator of references [`ScalarValue`] into an [`ArrayRef`] - /// corresponding to those values. For example, + /// corresponding to those values. For example, an iterator of + /// [`ScalarValue::Int32`] would be converted to an [`Int32Array`]. /// /// Returns an error if the iterator is empty or if the /// [`ScalarValue`]s are not all the same type /// + /// # Panics + /// + /// Panics if `self` is a dictionary with invalid key type + /// /// # Example /// ``` /// use datafusion_common::ScalarValue; @@ -1302,70 +1360,36 @@ impl ScalarValue { }}; } - macro_rules! build_array_list_primitive { - ($ARRAY_TY:ident, $SCALAR_TY:ident, $NATIVE_TYPE:ident) => {{ - Arc::new(ListArray::from_iter_primitive::<$ARRAY_TY, _, _>( - scalars.into_iter().map(|x| match x { - ScalarValue::List(xs, _) => xs.map(|x| { - x.iter().map(|x| match x { - ScalarValue::$SCALAR_TY(i) => *i, - sv => panic!( - "Inconsistent types in ScalarValue::iter_to_array. \ - Expected {:?}, got {:?}", - data_type, sv - ), - }) - .collect::>>() - }), - sv => panic!( - "Inconsistent types in ScalarValue::iter_to_array. \ - Expected {:?}, got {:?}", - data_type, sv - ), - }), - )) - }}; - } + fn build_list_array( + scalars: impl IntoIterator, + ) -> Result { + let arrays = scalars + .into_iter() + .map(|s| s.to_array()) + .collect::>>()?; - macro_rules! build_array_list_string { - ($BUILDER:ident, $SCALAR_TY:ident) => {{ - let mut builder = ListBuilder::new($BUILDER::new()); - for scalar in scalars.into_iter() { - match scalar { - ScalarValue::List(Some(xs), _) => { - for s in xs { - match s { - ScalarValue::$SCALAR_TY(Some(val)) => { - builder.values().append_value(val); - } - ScalarValue::$SCALAR_TY(None) => { - builder.values().append_null(); - } - sv => { - return _internal_err!( - "Inconsistent types in ScalarValue::iter_to_array. \ - Expected Utf8, got {:?}", - sv - ) - } - } - } - builder.append(true); - } - ScalarValue::List(None, _) => { - builder.append(false); - } - sv => { - return _internal_err!( - "Inconsistent types in ScalarValue::iter_to_array. \ - Expected List, got {:?}", - sv - ) - } - } + let capacity = Capacities::Array(arrays.iter().map(|arr| arr.len()).sum()); + // ScalarValue::List contains a single element ListArray. + let nulls = arrays + .iter() + .map(|arr| arr.is_null(0)) + .collect::>(); + let arrays_data = arrays.iter().map(|arr| arr.to_data()).collect::>(); + + let arrays_ref = arrays_data.iter().collect::>(); + let mut mutable = + MutableArrayData::with_capacities(arrays_ref, true, capacity); + + // ScalarValue::List contains a single element ListArray. + for (index, is_null) in (0..arrays.len()).zip(nulls.into_iter()) { + if is_null { + mutable.extend_nulls(1) + } else { + mutable.extend(index, 0, 1); } - Arc::new(builder.finish()) - }}; + } + let data = mutable.freeze(); + Ok(arrow_array::make_array(data)) } let array: ArrayRef = match &data_type { @@ -1379,7 +1403,7 @@ impl ScalarValue { ScalarValue::iter_to_decimal256_array(scalars, *precision, *scale)?; Arc::new(decimal_array) } - DataType::Null => ScalarValue::iter_to_null_array(scalars), + DataType::Null => ScalarValue::iter_to_null_array(scalars)?, DataType::Boolean => build_array_primitive!(BooleanArray, Boolean), DataType::Float32 => build_array_primitive!(Float32Array, Float32), DataType::Float64 => build_array_primitive!(Float64Array, Float64), @@ -1442,47 +1466,7 @@ impl ScalarValue { DataType::Interval(IntervalUnit::MonthDayNano) => { build_array_primitive!(IntervalMonthDayNanoArray, IntervalMonthDayNano) } - DataType::List(fields) if fields.data_type() == &DataType::Int8 => { - build_array_list_primitive!(Int8Type, Int8, i8) - } - DataType::List(fields) if fields.data_type() == &DataType::Int16 => { - build_array_list_primitive!(Int16Type, Int16, i16) - } - DataType::List(fields) if fields.data_type() == &DataType::Int32 => { - build_array_list_primitive!(Int32Type, Int32, i32) - } - DataType::List(fields) if fields.data_type() == &DataType::Int64 => { - build_array_list_primitive!(Int64Type, Int64, i64) - } - DataType::List(fields) if fields.data_type() == &DataType::UInt8 => { - build_array_list_primitive!(UInt8Type, UInt8, u8) - } - DataType::List(fields) if fields.data_type() == &DataType::UInt16 => { - build_array_list_primitive!(UInt16Type, UInt16, u16) - } - DataType::List(fields) if fields.data_type() == &DataType::UInt32 => { - build_array_list_primitive!(UInt32Type, UInt32, u32) - } - DataType::List(fields) if fields.data_type() == &DataType::UInt64 => { - build_array_list_primitive!(UInt64Type, UInt64, u64) - } - DataType::List(fields) if fields.data_type() == &DataType::Float32 => { - build_array_list_primitive!(Float32Type, Float32, f32) - } - DataType::List(fields) if fields.data_type() == &DataType::Float64 => { - build_array_list_primitive!(Float64Type, Float64, f64) - } - DataType::List(fields) if fields.data_type() == &DataType::Utf8 => { - build_array_list_string!(StringBuilder, Utf8) - } - DataType::List(fields) if fields.data_type() == &DataType::LargeUtf8 => { - build_array_list_string!(LargeStringBuilder, LargeUtf8) - } - DataType::List(_) => { - // Fallback case handling homogeneous lists with any ScalarValue element type - let list_array = ScalarValue::iter_to_array_list(scalars, &data_type)?; - Arc::new(list_array) - } + DataType::List(_) | DataType::LargeList(_) => build_list_array(scalars)?, DataType::Struct(fields) => { // Initialize a Vector to store the ScalarValues for each column let mut columns: Vec> = @@ -1528,7 +1512,7 @@ impl ScalarValue { .collect::>>()?; let array = StructArray::from(field_values); - nullif(&array, &null_mask_builder.finish())? + arrow::compute::nullif(&array, &null_mask_builder.finish())? } DataType::Dictionary(key_type, value_type) => { // create the values array @@ -1538,7 +1522,7 @@ impl ScalarValue { if &inner_key_type == key_type { Ok(*scalar) } else { - panic!("Expected inner key type of {key_type} but found: {inner_key_type}, value was ({scalar:?})"); + _internal_err!("Expected inner key type of {key_type} but found: {inner_key_type}, value was ({scalar:?})") } } _ => { @@ -1595,7 +1579,6 @@ impl ScalarValue { | DataType::Time64(TimeUnit::Millisecond) | DataType::Duration(_) | DataType::FixedSizeList(_, _) - | DataType::LargeList(_) | DataType::Union(_, _) | DataType::Map(_, _) | DataType::RunEndEncoded(_, _) => { @@ -1610,15 +1593,19 @@ impl ScalarValue { Ok(array) } - fn iter_to_null_array(scalars: impl IntoIterator) -> ArrayRef { - let length = - scalars - .into_iter() - .fold(0usize, |r, element: ScalarValue| match element { - ScalarValue::Null => r + 1, - _ => unreachable!(), - }); - new_null_array(&DataType::Null, length) + fn iter_to_null_array( + scalars: impl IntoIterator, + ) -> Result { + let length = scalars.into_iter().try_fold( + 0usize, + |r, element: ScalarValue| match element { + ScalarValue::Null => Ok::(r + 1), + s => { + _internal_err!("Expected ScalarValue::Null element. Received {s:?}") + } + }, + )?; + Ok(new_null_array(&DataType::Null, length)) } fn iter_to_decimal_array( @@ -1629,10 +1616,12 @@ impl ScalarValue { let array = scalars .into_iter() .map(|element: ScalarValue| match element { - ScalarValue::Decimal128(v1, _, _) => v1, - _ => unreachable!(), + ScalarValue::Decimal128(v1, _, _) => Ok(v1), + s => { + _internal_err!("Expected ScalarValue::Null element. Received {s:?}") + } }) - .collect::() + .collect::>()? .with_precision_and_scale(precision, scale)?; Ok(array) } @@ -1645,94 +1634,34 @@ impl ScalarValue { let array = scalars .into_iter() .map(|element: ScalarValue| match element { - ScalarValue::Decimal256(v1, _, _) => v1, - _ => unreachable!(), + ScalarValue::Decimal256(v1, _, _) => Ok(v1), + s => { + _internal_err!( + "Expected ScalarValue::Decimal256 element. Received {s:?}" + ) + } }) - .collect::() + .collect::>()? .with_precision_and_scale(precision, scale)?; Ok(array) } - fn iter_to_array_list( - scalars: impl IntoIterator, - data_type: &DataType, - ) -> Result> { - let mut offsets = Int32Array::builder(0); - offsets.append_value(0); - - let mut elements: Vec = Vec::new(); - let mut valid = BooleanBufferBuilder::new(0); - let mut flat_len = 0i32; - for scalar in scalars { - if let ScalarValue::List(values, field) = scalar { - match values { - Some(values) => { - let element_array = if !values.is_empty() { - ScalarValue::iter_to_array(values)? - } else { - arrow::array::new_empty_array(field.data_type()) - }; - - // Add new offset index - flat_len += element_array.len() as i32; - offsets.append_value(flat_len); - - elements.push(element_array); - - // Element is valid - valid.append(true); - } - None => { - // Repeat previous offset index - offsets.append_value(flat_len); - - // Element is null - valid.append(false); - } - } - } else { - return _internal_err!( - "Expected ScalarValue::List element. Received {scalar:?}" - ); - } - } - - // Concatenate element arrays to create single flat array - let element_arrays: Vec<&dyn Array> = - elements.iter().map(|a| a.as_ref()).collect(); - let flat_array = match arrow::compute::concat(&element_arrays) { - Ok(flat_array) => flat_array, - Err(err) => return Err(DataFusionError::ArrowError(err)), - }; - - // Build ListArray using ArrayData so we can specify a flat inner array, and offset indices - let offsets_array = offsets.finish(); - let array_data = ArrayDataBuilder::new(data_type.clone()) - .len(offsets_array.len() - 1) - .nulls(Some(NullBuffer::new(valid.finish()))) - .add_buffer(offsets_array.values().inner().clone()) - .add_child_data(flat_array.to_data()); - - let list_array = ListArray::from(array_data.build()?); - Ok(list_array) - } - fn build_decimal_array( value: Option, precision: u8, scale: i8, size: usize, - ) -> Decimal128Array { + ) -> Result { match value { Some(val) => Decimal128Array::from(vec![val; size]) .with_precision_and_scale(precision, scale) - .unwrap(), + .map_err(|e| arrow_datafusion_err!(e)), None => { let mut builder = Decimal128Array::builder(size) .with_precision_and_scale(precision, scale) - .unwrap(); + .map_err(|e| arrow_datafusion_err!(e))?; builder.append_nulls(size); - builder.finish() + Ok(builder.finish()) } } } @@ -1742,22 +1671,100 @@ impl ScalarValue { precision: u8, scale: i8, size: usize, - ) -> Decimal256Array { + ) -> Result { std::iter::repeat(value) .take(size) .collect::() .with_precision_and_scale(precision, scale) - .unwrap() + .map_err(|e| arrow_datafusion_err!(e)) + } + + /// Converts `Vec` where each element has type corresponding to + /// `data_type`, to a [`ListArray`]. + /// + /// Example + /// ``` + /// use datafusion_common::ScalarValue; + /// use arrow::array::{ListArray, Int32Array}; + /// use arrow::datatypes::{DataType, Int32Type}; + /// use datafusion_common::cast::as_list_array; + /// + /// let scalars = vec![ + /// ScalarValue::Int32(Some(1)), + /// ScalarValue::Int32(None), + /// ScalarValue::Int32(Some(2)) + /// ]; + /// + /// let array = ScalarValue::new_list(&scalars, &DataType::Int32); + /// let result = as_list_array(&array).unwrap(); + /// + /// let expected = ListArray::from_iter_primitive::( + /// vec![ + /// Some(vec![Some(1), None, Some(2)]) + /// ]); + /// + /// assert_eq!(result, &expected); + /// ``` + pub fn new_list(values: &[ScalarValue], data_type: &DataType) -> ArrayRef { + let values = if values.is_empty() { + new_empty_array(data_type) + } else { + Self::iter_to_array(values.iter().cloned()).unwrap() + }; + Arc::new(array_into_list_array(values)) + } + + /// Converts `Vec` where each element has type corresponding to + /// `data_type`, to a [`LargeListArray`]. + /// + /// Example + /// ``` + /// use datafusion_common::ScalarValue; + /// use arrow::array::{LargeListArray, Int32Array}; + /// use arrow::datatypes::{DataType, Int32Type}; + /// use datafusion_common::cast::as_large_list_array; + /// + /// let scalars = vec![ + /// ScalarValue::Int32(Some(1)), + /// ScalarValue::Int32(None), + /// ScalarValue::Int32(Some(2)) + /// ]; + /// + /// let array = ScalarValue::new_large_list(&scalars, &DataType::Int32); + /// let result = as_large_list_array(&array).unwrap(); + /// + /// let expected = LargeListArray::from_iter_primitive::( + /// vec![ + /// Some(vec![Some(1), None, Some(2)]) + /// ]); + /// + /// assert_eq!(result, &expected); + /// ``` + pub fn new_large_list(values: &[ScalarValue], data_type: &DataType) -> ArrayRef { + let values = if values.is_empty() { + new_empty_array(data_type) + } else { + Self::iter_to_array(values.iter().cloned()).unwrap() + }; + Arc::new(array_into_large_list_array(values)) } /// Converts a scalar value into an array of `size` rows. - pub fn to_array_of_size(&self, size: usize) -> ArrayRef { - match self { + /// + /// # Errors + /// + /// Errors if `self` is + /// - a decimal that fails be converted to a decimal array of size + /// - a `Fixedsizelist` that fails to be concatenated into an array of size + /// - a `List` that fails to be concatenated into an array of size + /// - a `Dictionary` that fails be converted to a dictionary array of size + pub fn to_array_of_size(&self, size: usize) -> Result { + Ok(match self { ScalarValue::Decimal128(e, precision, scale) => Arc::new( - ScalarValue::build_decimal_array(*e, *precision, *scale, size), + ScalarValue::build_decimal_array(*e, *precision, *scale, size)?, ), ScalarValue::Decimal256(e, precision, scale) => Arc::new( - ScalarValue::build_decimal256_array(*e, *precision, *scale, size), + ScalarValue::build_decimal256_array(*e, *precision, *scale, size)?, ), ScalarValue::Boolean(e) => { Arc::new(BooleanArray::from(vec![*e; size])) as ArrayRef @@ -1869,38 +1876,15 @@ impl ScalarValue { .collect::(), ), }, - ScalarValue::Fixedsizelist(..) => { - unimplemented!("FixedSizeList is not supported yet") - } - ScalarValue::List(values, field) => Arc::new(match field.data_type() { - DataType::Boolean => build_list!(BooleanBuilder, Boolean, values, size), - DataType::Int8 => build_list!(Int8Builder, Int8, values, size), - DataType::Int16 => build_list!(Int16Builder, Int16, values, size), - DataType::Int32 => build_list!(Int32Builder, Int32, values, size), - DataType::Int64 => build_list!(Int64Builder, Int64, values, size), - DataType::UInt8 => build_list!(UInt8Builder, UInt8, values, size), - DataType::UInt16 => build_list!(UInt16Builder, UInt16, values, size), - DataType::UInt32 => build_list!(UInt32Builder, UInt32, values, size), - DataType::UInt64 => build_list!(UInt64Builder, UInt64, values, size), - DataType::Utf8 => build_list!(StringBuilder, Utf8, values, size), - DataType::Float32 => build_list!(Float32Builder, Float32, values, size), - DataType::Float64 => build_list!(Float64Builder, Float64, values, size), - DataType::Timestamp(unit, tz) => { - build_timestamp_list!(unit.clone(), tz.clone(), values, size) - } - &DataType::LargeUtf8 => { - build_list!(LargeStringBuilder, LargeUtf8, values, size) - } - _ => ScalarValue::iter_to_array_list( - repeat(self.clone()).take(size), - &DataType::List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - ))), - ) - .unwrap(), - }), + ScalarValue::List(arr) + | ScalarValue::LargeList(arr) + | ScalarValue::FixedSizeList(arr) => { + let arrays = std::iter::repeat(arr.as_ref()) + .take(size) + .collect::>(); + arrow::compute::concat(arrays.as_slice()) + .map_err(|e| arrow_datafusion_err!(e))? + } ScalarValue::Date32(e) => { build_array_from_option!(Date32, Date32Array, e, size) } @@ -1994,13 +1978,13 @@ impl ScalarValue { ), ScalarValue::Struct(values, fields) => match values { Some(values) => { - let field_values: Vec<_> = fields + let field_values = fields .iter() .zip(values.iter()) .map(|(field, value)| { - (field.clone(), value.to_array_of_size(size)) + Ok((field.clone(), value.to_array_of_size(size)?)) }) - .collect(); + .collect::>>()?; Arc::new(StructArray::from(field_values)) } @@ -2012,19 +1996,19 @@ impl ScalarValue { ScalarValue::Dictionary(key_type, v) => { // values array is one element long (the value) match key_type.as_ref() { - DataType::Int8 => dict_from_scalar::(v, size), - DataType::Int16 => dict_from_scalar::(v, size), - DataType::Int32 => dict_from_scalar::(v, size), - DataType::Int64 => dict_from_scalar::(v, size), - DataType::UInt8 => dict_from_scalar::(v, size), - DataType::UInt16 => dict_from_scalar::(v, size), - DataType::UInt32 => dict_from_scalar::(v, size), - DataType::UInt64 => dict_from_scalar::(v, size), + DataType::Int8 => dict_from_scalar::(v, size)?, + DataType::Int16 => dict_from_scalar::(v, size)?, + DataType::Int32 => dict_from_scalar::(v, size)?, + DataType::Int64 => dict_from_scalar::(v, size)?, + DataType::UInt8 => dict_from_scalar::(v, size)?, + DataType::UInt16 => dict_from_scalar::(v, size)?, + DataType::UInt32 => dict_from_scalar::(v, size)?, + DataType::UInt64 => dict_from_scalar::(v, size)?, _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), } } ScalarValue::Null => new_null_array(&DataType::Null, size), - } + }) } fn get_decimal_value_from_array( @@ -2056,6 +2040,71 @@ impl ScalarValue { } } + /// Retrieve ScalarValue for each row in `array` + /// + /// Example + /// ``` + /// use datafusion_common::ScalarValue; + /// use arrow::array::ListArray; + /// use arrow::datatypes::{DataType, Int32Type}; + /// + /// let list_arr = ListArray::from_iter_primitive::(vec![ + /// Some(vec![Some(1), Some(2), Some(3)]), + /// None, + /// Some(vec![Some(4), Some(5)]) + /// ]); + /// + /// let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&list_arr).unwrap(); + /// + /// let expected = vec![ + /// vec![ + /// ScalarValue::Int32(Some(1)), + /// ScalarValue::Int32(Some(2)), + /// ScalarValue::Int32(Some(3)), + /// ], + /// vec![], + /// vec![ScalarValue::Int32(Some(4)), ScalarValue::Int32(Some(5))] + /// ]; + /// + /// assert_eq!(scalar_vec, expected); + /// ``` + pub fn convert_array_to_scalar_vec(array: &dyn Array) -> Result>> { + let mut scalars = Vec::with_capacity(array.len()); + + for index in 0..array.len() { + let scalar_values = match array.data_type() { + DataType::List(_) => { + let list_array = as_list_array(array); + match list_array.is_null(index) { + true => Vec::new(), + false => { + let nested_array = list_array.value(index); + ScalarValue::convert_array_to_scalar_vec(&nested_array)? + .into_iter() + .flatten() + .collect() + } + } + } + _ => { + let scalar = ScalarValue::try_from_array(array, index)?; + vec![scalar] + } + }; + scalars.push(scalar_values); + } + Ok(scalars) + } + + // TODO: Support more types after other ScalarValue is wrapped with ArrayRef + /// Get raw data (inner array) inside ScalarValue + pub fn raw_data(&self) -> Result { + match self { + ScalarValue::List(arr) => Ok(arr.to_owned()), + _ => _internal_err!("ScalarValue is not a list"), + } + } + /// Converts a value in `array` at `index` into a ScalarValue pub fn try_from_array(array: &dyn Array, index: usize) -> Result { // handle NULL value @@ -2075,101 +2124,102 @@ impl ScalarValue { array, index, *precision, *scale, )? } - DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean), - DataType::Float64 => typed_cast!(array, index, Float64Array, Float64), - DataType::Float32 => typed_cast!(array, index, Float32Array, Float32), - DataType::UInt64 => typed_cast!(array, index, UInt64Array, UInt64), - DataType::UInt32 => typed_cast!(array, index, UInt32Array, UInt32), - DataType::UInt16 => typed_cast!(array, index, UInt16Array, UInt16), - DataType::UInt8 => typed_cast!(array, index, UInt8Array, UInt8), - DataType::Int64 => typed_cast!(array, index, Int64Array, Int64), - DataType::Int32 => typed_cast!(array, index, Int32Array, Int32), - DataType::Int16 => typed_cast!(array, index, Int16Array, Int16), - DataType::Int8 => typed_cast!(array, index, Int8Array, Int8), - DataType::Binary => typed_cast!(array, index, BinaryArray, Binary), + DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean)?, + DataType::Float64 => typed_cast!(array, index, Float64Array, Float64)?, + DataType::Float32 => typed_cast!(array, index, Float32Array, Float32)?, + DataType::UInt64 => typed_cast!(array, index, UInt64Array, UInt64)?, + DataType::UInt32 => typed_cast!(array, index, UInt32Array, UInt32)?, + DataType::UInt16 => typed_cast!(array, index, UInt16Array, UInt16)?, + DataType::UInt8 => typed_cast!(array, index, UInt8Array, UInt8)?, + DataType::Int64 => typed_cast!(array, index, Int64Array, Int64)?, + DataType::Int32 => typed_cast!(array, index, Int32Array, Int32)?, + DataType::Int16 => typed_cast!(array, index, Int16Array, Int16)?, + DataType::Int8 => typed_cast!(array, index, Int8Array, Int8)?, + DataType::Binary => typed_cast!(array, index, BinaryArray, Binary)?, DataType::LargeBinary => { - typed_cast!(array, index, LargeBinaryArray, LargeBinary) - } - DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8), - DataType::LargeUtf8 => typed_cast!(array, index, LargeStringArray, LargeUtf8), - DataType::List(nested_type) => { - let list_array = as_list_array(array)?; - let value = match list_array.is_null(index) { - true => None, - false => { - let nested_array = list_array.value(index); - let scalar_vec = (0..nested_array.len()) - .map(|i| ScalarValue::try_from_array(&nested_array, i)) - .collect::>>()?; - Some(scalar_vec) - } - }; - ScalarValue::new_list(value, nested_type.data_type().clone()) + typed_cast!(array, index, LargeBinaryArray, LargeBinary)? + } + DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8)?, + DataType::LargeUtf8 => { + typed_cast!(array, index, LargeStringArray, LargeUtf8)? + } + DataType::List(_) => { + let list_array = as_list_array(array); + let nested_array = list_array.value(index); + // Produces a single element `ListArray` with the value at `index`. + let arr = Arc::new(array_into_list_array(nested_array)); + + ScalarValue::List(arr) } - DataType::Date32 => { - typed_cast!(array, index, Date32Array, Date32) + DataType::LargeList(_) => { + let list_array = as_large_list_array(array); + let nested_array = list_array.value(index); + // Produces a single element `LargeListArray` with the value at `index`. + let arr = Arc::new(array_into_large_list_array(nested_array)); + + ScalarValue::LargeList(arr) } - DataType::Date64 => { - typed_cast!(array, index, Date64Array, Date64) + // TODO: There is no test for FixedSizeList now, add it later + DataType::FixedSizeList(_, _) => { + let list_array = as_fixed_size_list_array(array)?; + let nested_array = list_array.value(index); + // Produces a single element `ListArray` with the value at `index`. + let arr = Arc::new(array_into_list_array(nested_array)); + + ScalarValue::List(arr) } + DataType::Date32 => typed_cast!(array, index, Date32Array, Date32)?, + DataType::Date64 => typed_cast!(array, index, Date64Array, Date64)?, DataType::Time32(TimeUnit::Second) => { - typed_cast!(array, index, Time32SecondArray, Time32Second) + typed_cast!(array, index, Time32SecondArray, Time32Second)? } DataType::Time32(TimeUnit::Millisecond) => { - typed_cast!(array, index, Time32MillisecondArray, Time32Millisecond) + typed_cast!(array, index, Time32MillisecondArray, Time32Millisecond)? } DataType::Time64(TimeUnit::Microsecond) => { - typed_cast!(array, index, Time64MicrosecondArray, Time64Microsecond) + typed_cast!(array, index, Time64MicrosecondArray, Time64Microsecond)? } DataType::Time64(TimeUnit::Nanosecond) => { - typed_cast!(array, index, Time64NanosecondArray, Time64Nanosecond) - } - DataType::Timestamp(TimeUnit::Second, tz_opt) => { - typed_cast_tz!( - array, - index, - TimestampSecondArray, - TimestampSecond, - tz_opt - ) - } - DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { - typed_cast_tz!( - array, - index, - TimestampMillisecondArray, - TimestampMillisecond, - tz_opt - ) - } - DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { - typed_cast_tz!( - array, - index, - TimestampMicrosecondArray, - TimestampMicrosecond, - tz_opt - ) - } - DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { - typed_cast_tz!( - array, - index, - TimestampNanosecondArray, - TimestampNanosecond, - tz_opt - ) + typed_cast!(array, index, Time64NanosecondArray, Time64Nanosecond)? } + DataType::Timestamp(TimeUnit::Second, tz_opt) => typed_cast_tz!( + array, + index, + TimestampSecondArray, + TimestampSecond, + tz_opt + )?, + DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => typed_cast_tz!( + array, + index, + TimestampMillisecondArray, + TimestampMillisecond, + tz_opt + )?, + DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => typed_cast_tz!( + array, + index, + TimestampMicrosecondArray, + TimestampMicrosecond, + tz_opt + )?, + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => typed_cast_tz!( + array, + index, + TimestampNanosecondArray, + TimestampNanosecond, + tz_opt + )?, DataType::Dictionary(key_type, _) => { let (values_array, values_index) = match key_type.as_ref() { - DataType::Int8 => get_dict_value::(array, index), - DataType::Int16 => get_dict_value::(array, index), - DataType::Int32 => get_dict_value::(array, index), - DataType::Int64 => get_dict_value::(array, index), - DataType::UInt8 => get_dict_value::(array, index), - DataType::UInt16 => get_dict_value::(array, index), - DataType::UInt32 => get_dict_value::(array, index), - DataType::UInt64 => get_dict_value::(array, index), + DataType::Int8 => get_dict_value::(array, index)?, + DataType::Int16 => get_dict_value::(array, index)?, + DataType::Int32 => get_dict_value::(array, index)?, + DataType::Int64 => get_dict_value::(array, index)?, + DataType::UInt8 => get_dict_value::(array, index)?, + DataType::UInt16 => get_dict_value::(array, index)?, + DataType::UInt32 => get_dict_value::(array, index)?, + DataType::UInt64 => get_dict_value::(array, index)?, _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), }; // look up the index in the values dictionary @@ -2193,20 +2243,6 @@ impl ScalarValue { } Self::Struct(Some(field_values), fields.clone()) } - DataType::FixedSizeList(nested_type, _len) => { - let list_array = as_fixed_size_list_array(array)?; - let value = match list_array.is_null(index) { - true => None, - false => { - let nested_array = list_array.value(index); - let scalar_vec = (0..nested_array.len()) - .map(|i| ScalarValue::try_from_array(&nested_array, i)) - .collect::>>()?; - Some(scalar_vec) - } - }; - ScalarValue::new_list(value, nested_type.data_type().clone()) - } DataType::FixedSizeBinary(_) => { let array = as_fixed_size_binary_array(array)?; let size = match array.data_type() { @@ -2222,31 +2258,29 @@ impl ScalarValue { ) } DataType::Interval(IntervalUnit::DayTime) => { - typed_cast!(array, index, IntervalDayTimeArray, IntervalDayTime) + typed_cast!(array, index, IntervalDayTimeArray, IntervalDayTime)? } DataType::Interval(IntervalUnit::YearMonth) => { - typed_cast!(array, index, IntervalYearMonthArray, IntervalYearMonth) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - typed_cast!( - array, - index, - IntervalMonthDayNanoArray, - IntervalMonthDayNano - ) + typed_cast!(array, index, IntervalYearMonthArray, IntervalYearMonth)? } + DataType::Interval(IntervalUnit::MonthDayNano) => typed_cast!( + array, + index, + IntervalMonthDayNanoArray, + IntervalMonthDayNano + )?, DataType::Duration(TimeUnit::Second) => { - typed_cast!(array, index, DurationSecondArray, DurationSecond) + typed_cast!(array, index, DurationSecondArray, DurationSecond)? } DataType::Duration(TimeUnit::Millisecond) => { - typed_cast!(array, index, DurationMillisecondArray, DurationMillisecond) + typed_cast!(array, index, DurationMillisecondArray, DurationMillisecond)? } DataType::Duration(TimeUnit::Microsecond) => { - typed_cast!(array, index, DurationMicrosecondArray, DurationMicrosecond) + typed_cast!(array, index, DurationMicrosecondArray, DurationMicrosecond)? } DataType::Duration(TimeUnit::Nanosecond) => { - typed_cast!(array, index, DurationNanosecondArray, DurationNanosecond) + typed_cast!(array, index, DurationNanosecondArray, DurationNanosecond)? } other => { @@ -2259,12 +2293,12 @@ impl ScalarValue { /// Try to parse `value` into a ScalarValue of type `target_type` pub fn try_from_string(value: String, target_type: &DataType) -> Result { - let value = ScalarValue::Utf8(Some(value)); + let value = ScalarValue::from(value); let cast_options = CastOptions { safe: false, format_options: Default::default(), }; - let cast_arr = cast_with_options(&value.to_array(), target_type, &cast_options)?; + let cast_arr = cast_with_options(&value.to_array()?, target_type, &cast_options)?; ScalarValue::try_from_array(&cast_arr, 0) } @@ -2322,9 +2356,19 @@ impl ScalarValue { /// /// This function has a few narrow usescases such as hash table key /// comparisons where comparing a single row at a time is necessary. + /// + /// # Errors + /// + /// Errors if + /// - it fails to downcast `array` to the data type of `self` + /// - `self` is a `Struct` + /// + /// # Panics + /// + /// Panics if `self` is a dictionary with invalid key type #[inline] - pub fn eq_array(&self, array: &ArrayRef, index: usize) -> bool { - match self { + pub fn eq_array(&self, array: &ArrayRef, index: usize) -> Result { + Ok(match self { ScalarValue::Decimal128(v, precision, scale) => { ScalarValue::eq_array_decimal( array, @@ -2332,8 +2376,7 @@ impl ScalarValue { v.as_ref(), *precision, *scale, - ) - .unwrap() + )? } ScalarValue::Decimal256(v, precision, scale) => { ScalarValue::eq_array_decimal256( @@ -2342,119 +2385,134 @@ impl ScalarValue { v.as_ref(), *precision, *scale, - ) - .unwrap() + )? } ScalarValue::Boolean(val) => { - eq_array_primitive!(array, index, BooleanArray, val) + eq_array_primitive!(array, index, BooleanArray, val)? } ScalarValue::Float32(val) => { - eq_array_primitive!(array, index, Float32Array, val) + eq_array_primitive!(array, index, Float32Array, val)? } ScalarValue::Float64(val) => { - eq_array_primitive!(array, index, Float64Array, val) + eq_array_primitive!(array, index, Float64Array, val)? + } + ScalarValue::Int8(val) => eq_array_primitive!(array, index, Int8Array, val)?, + ScalarValue::Int16(val) => { + eq_array_primitive!(array, index, Int16Array, val)? + } + ScalarValue::Int32(val) => { + eq_array_primitive!(array, index, Int32Array, val)? + } + ScalarValue::Int64(val) => { + eq_array_primitive!(array, index, Int64Array, val)? + } + ScalarValue::UInt8(val) => { + eq_array_primitive!(array, index, UInt8Array, val)? } - ScalarValue::Int8(val) => eq_array_primitive!(array, index, Int8Array, val), - ScalarValue::Int16(val) => eq_array_primitive!(array, index, Int16Array, val), - ScalarValue::Int32(val) => eq_array_primitive!(array, index, Int32Array, val), - ScalarValue::Int64(val) => eq_array_primitive!(array, index, Int64Array, val), - ScalarValue::UInt8(val) => eq_array_primitive!(array, index, UInt8Array, val), ScalarValue::UInt16(val) => { - eq_array_primitive!(array, index, UInt16Array, val) + eq_array_primitive!(array, index, UInt16Array, val)? } ScalarValue::UInt32(val) => { - eq_array_primitive!(array, index, UInt32Array, val) + eq_array_primitive!(array, index, UInt32Array, val)? } ScalarValue::UInt64(val) => { - eq_array_primitive!(array, index, UInt64Array, val) + eq_array_primitive!(array, index, UInt64Array, val)? + } + ScalarValue::Utf8(val) => { + eq_array_primitive!(array, index, StringArray, val)? } - ScalarValue::Utf8(val) => eq_array_primitive!(array, index, StringArray, val), ScalarValue::LargeUtf8(val) => { - eq_array_primitive!(array, index, LargeStringArray, val) + eq_array_primitive!(array, index, LargeStringArray, val)? } ScalarValue::Binary(val) => { - eq_array_primitive!(array, index, BinaryArray, val) + eq_array_primitive!(array, index, BinaryArray, val)? } ScalarValue::FixedSizeBinary(_, val) => { - eq_array_primitive!(array, index, FixedSizeBinaryArray, val) + eq_array_primitive!(array, index, FixedSizeBinaryArray, val)? } ScalarValue::LargeBinary(val) => { - eq_array_primitive!(array, index, LargeBinaryArray, val) + eq_array_primitive!(array, index, LargeBinaryArray, val)? + } + ScalarValue::List(arr) + | ScalarValue::LargeList(arr) + | ScalarValue::FixedSizeList(arr) => { + let right = array.slice(index, 1); + arr == &right } - ScalarValue::Fixedsizelist(..) => unimplemented!(), - ScalarValue::List(_, _) => unimplemented!(), ScalarValue::Date32(val) => { - eq_array_primitive!(array, index, Date32Array, val) + eq_array_primitive!(array, index, Date32Array, val)? } ScalarValue::Date64(val) => { - eq_array_primitive!(array, index, Date64Array, val) + eq_array_primitive!(array, index, Date64Array, val)? } ScalarValue::Time32Second(val) => { - eq_array_primitive!(array, index, Time32SecondArray, val) + eq_array_primitive!(array, index, Time32SecondArray, val)? } ScalarValue::Time32Millisecond(val) => { - eq_array_primitive!(array, index, Time32MillisecondArray, val) + eq_array_primitive!(array, index, Time32MillisecondArray, val)? } ScalarValue::Time64Microsecond(val) => { - eq_array_primitive!(array, index, Time64MicrosecondArray, val) + eq_array_primitive!(array, index, Time64MicrosecondArray, val)? } ScalarValue::Time64Nanosecond(val) => { - eq_array_primitive!(array, index, Time64NanosecondArray, val) + eq_array_primitive!(array, index, Time64NanosecondArray, val)? } ScalarValue::TimestampSecond(val, _) => { - eq_array_primitive!(array, index, TimestampSecondArray, val) + eq_array_primitive!(array, index, TimestampSecondArray, val)? } ScalarValue::TimestampMillisecond(val, _) => { - eq_array_primitive!(array, index, TimestampMillisecondArray, val) + eq_array_primitive!(array, index, TimestampMillisecondArray, val)? } ScalarValue::TimestampMicrosecond(val, _) => { - eq_array_primitive!(array, index, TimestampMicrosecondArray, val) + eq_array_primitive!(array, index, TimestampMicrosecondArray, val)? } ScalarValue::TimestampNanosecond(val, _) => { - eq_array_primitive!(array, index, TimestampNanosecondArray, val) + eq_array_primitive!(array, index, TimestampNanosecondArray, val)? } ScalarValue::IntervalYearMonth(val) => { - eq_array_primitive!(array, index, IntervalYearMonthArray, val) + eq_array_primitive!(array, index, IntervalYearMonthArray, val)? } ScalarValue::IntervalDayTime(val) => { - eq_array_primitive!(array, index, IntervalDayTimeArray, val) + eq_array_primitive!(array, index, IntervalDayTimeArray, val)? } ScalarValue::IntervalMonthDayNano(val) => { - eq_array_primitive!(array, index, IntervalMonthDayNanoArray, val) + eq_array_primitive!(array, index, IntervalMonthDayNanoArray, val)? } ScalarValue::DurationSecond(val) => { - eq_array_primitive!(array, index, DurationSecondArray, val) + eq_array_primitive!(array, index, DurationSecondArray, val)? } ScalarValue::DurationMillisecond(val) => { - eq_array_primitive!(array, index, DurationMillisecondArray, val) + eq_array_primitive!(array, index, DurationMillisecondArray, val)? } ScalarValue::DurationMicrosecond(val) => { - eq_array_primitive!(array, index, DurationMicrosecondArray, val) + eq_array_primitive!(array, index, DurationMicrosecondArray, val)? } ScalarValue::DurationNanosecond(val) => { - eq_array_primitive!(array, index, DurationNanosecondArray, val) + eq_array_primitive!(array, index, DurationNanosecondArray, val)? + } + ScalarValue::Struct(_, _) => { + return _not_impl_err!("Struct is not supported yet") } - ScalarValue::Struct(_, _) => unimplemented!(), ScalarValue::Dictionary(key_type, v) => { let (values_array, values_index) = match key_type.as_ref() { - DataType::Int8 => get_dict_value::(array, index), - DataType::Int16 => get_dict_value::(array, index), - DataType::Int32 => get_dict_value::(array, index), - DataType::Int64 => get_dict_value::(array, index), - DataType::UInt8 => get_dict_value::(array, index), - DataType::UInt16 => get_dict_value::(array, index), - DataType::UInt32 => get_dict_value::(array, index), - DataType::UInt64 => get_dict_value::(array, index), + DataType::Int8 => get_dict_value::(array, index)?, + DataType::Int16 => get_dict_value::(array, index)?, + DataType::Int32 => get_dict_value::(array, index)?, + DataType::Int64 => get_dict_value::(array, index)?, + DataType::UInt8 => get_dict_value::(array, index)?, + DataType::UInt16 => get_dict_value::(array, index)?, + DataType::UInt32 => get_dict_value::(array, index)?, + DataType::UInt64 => get_dict_value::(array, index)?, _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), }; // was the value in the array non null? match values_index { - Some(values_index) => v.eq_array(values_array, values_index), + Some(values_index) => v.eq_array(values_array, values_index)?, None => v.is_null(), } } ScalarValue::Null => array.is_null(index), - } + }) } /// Estimate size if bytes including `Self`. For values with internal containers such as `String` @@ -2503,14 +2561,9 @@ impl ScalarValue { | ScalarValue::LargeBinary(b) => { b.as_ref().map(|b| b.capacity()).unwrap_or_default() } - ScalarValue::Fixedsizelist(vals, field, _) - | ScalarValue::List(vals, field) => { - vals.as_ref() - .map(|vals| Self::size_of_vec(vals) - std::mem::size_of_val(vals)) - .unwrap_or_default() - // `field` is boxed, so it is NOT already included in `self` - + field.size() - } + ScalarValue::List(arr) + | ScalarValue::LargeList(arr) + | ScalarValue::FixedSizeList(arr) => arr.get_array_memory_size(), ScalarValue::Struct(vals, fields) => { vals.as_ref() .map(|vals| { @@ -2606,6 +2659,12 @@ impl FromStr for ScalarValue { } } +impl From for ScalarValue { + fn from(value: String) -> Self { + ScalarValue::Utf8(Some(value)) + } +} + impl From> for ScalarValue { fn from(value: Vec<(&str, ScalarValue)>) -> Self { let (fields, scalars): (SchemaBuilder, Vec<_>) = value @@ -2734,8 +2793,8 @@ impl TryFrom<&DataType> for ScalarValue { type Error = DataFusionError; /// Create a Null instance of ScalarValue for this datatype - fn try_from(datatype: &DataType) -> Result { - Ok(match datatype { + fn try_from(data_type: &DataType) -> Result { + Ok(match data_type { DataType::Boolean => ScalarValue::Boolean(None), DataType::Float64 => ScalarValue::Float64(None), DataType::Float32 => ScalarValue::Float32(None), @@ -2805,14 +2864,20 @@ impl TryFrom<&DataType> for ScalarValue { index_type.clone(), Box::new(value_type.as_ref().try_into()?), ), - DataType::List(ref nested_type) => { - ScalarValue::new_list(None, nested_type.data_type().clone()) - } + // `ScalaValue::List` contains single element `ListArray`. + DataType::List(field) => ScalarValue::List(new_null_array( + &DataType::List(Arc::new(Field::new( + "item", + field.data_type().clone(), + true, + ))), + 1, + )), DataType::Struct(fields) => ScalarValue::Struct(None, fields.clone()), DataType::Null => ScalarValue::Null, _ => { return _not_impl_err!( - "Can't create a scalar from data_type \"{datatype:?}\"" + "Can't create a scalar from data_type \"{data_type:?}\"" ); } }) @@ -2828,6 +2893,11 @@ macro_rules! format_option { }}; } +// Implement Display trait for ScalarValue +// +// # Panics +// +// Panics if there is an error when creating a visual representation of columns via `arrow::util::pretty` impl fmt::Display for ScalarValue { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { @@ -2867,17 +2937,16 @@ impl fmt::Display for ScalarValue { )?, None => write!(f, "NULL")?, }, - ScalarValue::Fixedsizelist(e, ..) | ScalarValue::List(e, _) => match e { - Some(l) => write!( - f, - "{}", - l.iter() - .map(|v| format!("{v}")) - .collect::>() - .join(",") - )?, - None => write!(f, "NULL")?, - }, + ScalarValue::List(arr) + | ScalarValue::LargeList(arr) + | ScalarValue::FixedSizeList(arr) => { + // ScalarValue List should always have a single element + assert_eq!(arr.len(), 1); + let options = FormatOptions::default().with_display_error(true); + let formatter = ArrayFormatter::try_new(arr, &options).unwrap(); + let value_formatter = formatter.value(0); + write!(f, "{value_formatter}")? + } ScalarValue::Date32(e) => format_option!(f, e)?, ScalarValue::Date64(e) => format_option!(f, e)?, ScalarValue::Time32Second(e) => format_option!(f, e)?, @@ -2952,8 +3021,9 @@ impl fmt::Debug for ScalarValue { } ScalarValue::LargeBinary(None) => write!(f, "LargeBinary({self})"), ScalarValue::LargeBinary(Some(_)) => write!(f, "LargeBinary(\"{self}\")"), - ScalarValue::Fixedsizelist(..) => write!(f, "FixedSizeList([{self}])"), - ScalarValue::List(_, _) => write!(f, "List([{self}])"), + ScalarValue::FixedSizeList(_) => write!(f, "FixedSizeList({self})"), + ScalarValue::List(_) => write!(f, "List({self})"), + ScalarValue::LargeList(_) => write!(f, "LargeList({self})"), ScalarValue::Date32(_) => write!(f, "Date32(\"{self}\")"), ScalarValue::Date64(_) => write!(f, "Date64(\"{self}\")"), ScalarValue::Time32Second(_) => write!(f, "Time32Second(\"{self}\")"), @@ -3044,20 +3114,219 @@ impl ScalarType for TimestampNanosecondType { #[cfg(test)] mod tests { + use super::*; + use std::cmp::Ordering; use std::sync::Arc; + use chrono::NaiveDate; + use rand::Rng; + + use arrow::buffer::OffsetBuffer; use arrow::compute::kernels; use arrow::compute::{concat, is_null}; use arrow::datatypes::ArrowPrimitiveType; use arrow::util::pretty::pretty_format_columns; use arrow_array::ArrowNumericType; - use chrono::NaiveDate; - use rand::Rng; use crate::cast::{as_string_array, as_uint32_array, as_uint64_array}; - use super::*; + #[test] + fn test_to_array_of_size_for_list() { + let arr = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + None, + Some(2), + ])]); + + let sv = ScalarValue::List(Arc::new(arr)); + let actual_arr = sv + .to_array_of_size(2) + .expect("Failed to convert to array of size"); + let actual_list_arr = as_list_array(&actual_arr); + + let arr = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), None, Some(2)]), + Some(vec![Some(1), None, Some(2)]), + ]); + + assert_eq!(&arr, actual_list_arr); + } + + #[test] + fn test_to_array_of_size_for_fsl() { + let values = Int32Array::from_iter([Some(1), None, Some(2)]); + let field = Arc::new(Field::new("item", DataType::Int32, true)); + let arr = FixedSizeListArray::new(field.clone(), 3, Arc::new(values), None); + let sv = ScalarValue::FixedSizeList(Arc::new(arr)); + let actual_arr = sv + .to_array_of_size(2) + .expect("Failed to convert to array of size"); + + let expected_values = + Int32Array::from_iter([Some(1), None, Some(2), Some(1), None, Some(2)]); + let expected_arr = + FixedSizeListArray::new(field, 3, Arc::new(expected_values), None); + + assert_eq!( + &expected_arr, + as_fixed_size_list_array(actual_arr.as_ref()).unwrap() + ); + } + + #[test] + fn test_list_to_array_string() { + let scalars = vec![ + ScalarValue::from("rust"), + ScalarValue::from("arrow"), + ScalarValue::from("data-fusion"), + ]; + + let array = ScalarValue::new_list(scalars.as_slice(), &DataType::Utf8); + + let expected = array_into_list_array(Arc::new(StringArray::from(vec![ + "rust", + "arrow", + "data-fusion", + ]))); + let result = as_list_array(&array); + assert_eq!(result, &expected); + } + + fn build_list( + values: Vec>>>, + ) -> Vec { + values + .into_iter() + .map(|v| { + let arr = if v.is_some() { + Arc::new( + GenericListArray::::from_iter_primitive::( + vec![v], + ), + ) + } else if O::IS_LARGE { + new_null_array( + &DataType::LargeList(Arc::new(Field::new( + "item", + DataType::Int64, + true, + ))), + 1, + ) + } else { + new_null_array( + &DataType::List(Arc::new(Field::new( + "item", + DataType::Int64, + true, + ))), + 1, + ) + }; + + if O::IS_LARGE { + ScalarValue::LargeList(arr) + } else { + ScalarValue::List(arr) + } + }) + .collect() + } + + #[test] + fn iter_to_array_primitive_test() { + // List[[1,2,3]], List[null], List[[4,5]] + let scalars = build_list::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + None, + Some(vec![Some(4), Some(5)]), + ]); + + let array = ScalarValue::iter_to_array(scalars).unwrap(); + let list_array = as_list_array(&array); + // List[[1,2,3], null, [4,5]] + let expected = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + None, + Some(vec![Some(4), Some(5)]), + ]); + assert_eq!(list_array, &expected); + + let scalars = build_list::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + None, + Some(vec![Some(4), Some(5)]), + ]); + + let array = ScalarValue::iter_to_array(scalars).unwrap(); + let list_array = as_large_list_array(&array); + let expected = LargeListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + None, + Some(vec![Some(4), Some(5)]), + ]); + assert_eq!(list_array, &expected); + } + + #[test] + fn iter_to_array_string_test() { + let arr1 = + array_into_list_array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"]))); + let arr2 = + array_into_list_array(Arc::new(StringArray::from(vec!["rust", "world"]))); + + let scalars = vec![ + ScalarValue::List(Arc::new(arr1)), + ScalarValue::List(Arc::new(arr2)), + ]; + + let array = ScalarValue::iter_to_array(scalars).unwrap(); + let result = as_list_array(&array); + + // build expected array + let string_builder = StringBuilder::with_capacity(5, 25); + let mut list_of_string_builder = ListBuilder::new(string_builder); + + list_of_string_builder.values().append_value("foo"); + list_of_string_builder.values().append_value("bar"); + list_of_string_builder.values().append_value("baz"); + list_of_string_builder.append(true); + + list_of_string_builder.values().append_value("rust"); + list_of_string_builder.values().append_value("world"); + list_of_string_builder.append(true); + let expected = list_of_string_builder.finish(); + + assert_eq!(result, &expected); + } + + #[test] + fn test_list_scalar_eq_to_array() { + let list_array: ArrayRef = + Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(0), Some(1), Some(2)]), + None, + Some(vec![None, Some(5)]), + ])); + + let fsl_array: ArrayRef = + Arc::new(FixedSizeListArray::from_iter_primitive::( + vec![ + Some(vec![Some(0), Some(1), Some(2)]), + None, + Some(vec![Some(3), None, Some(5)]), + ], + 3, + )); + + for arr in [list_array, fsl_array] { + for i in 0..arr.len() { + let scalar = ScalarValue::List(arr.slice(i, 1)); + assert!(scalar.eq_array(&arr, i).unwrap()); + } + } + } #[test] fn scalar_add_trait_test() -> Result<()> { @@ -3177,8 +3446,8 @@ mod tests { { let scalar_result = left.add_checked(&right); - let left_array = left.to_array(); - let right_array = right.to_array(); + let left_array = left.to_array().expect("Failed to convert to array"); + let right_array = right.to_array().expect("Failed to convert to array"); let arrow_left_array = left_array.as_primitive::(); let arrow_right_array = right_array.as_primitive::(); let arrow_result = kernels::numeric::add(arrow_left_array, arrow_right_array); @@ -3226,22 +3495,30 @@ mod tests { } // decimal scalar to array - let array = decimal_value.to_array(); + let array = decimal_value + .to_array() + .expect("Failed to convert to array"); let array = as_decimal128_array(&array)?; assert_eq!(1, array.len()); assert_eq!(DataType::Decimal128(10, 1), array.data_type().clone()); assert_eq!(123i128, array.value(0)); // decimal scalar to array with size - let array = decimal_value.to_array_of_size(10); + let array = decimal_value + .to_array_of_size(10) + .expect("Failed to convert to array of size"); let array_decimal = as_decimal128_array(&array)?; assert_eq!(10, array.len()); assert_eq!(DataType::Decimal128(10, 1), array.data_type().clone()); assert_eq!(123i128, array_decimal.value(0)); assert_eq!(123i128, array_decimal.value(9)); // test eq array - assert!(decimal_value.eq_array(&array, 1)); - assert!(decimal_value.eq_array(&array, 5)); + assert!(decimal_value + .eq_array(&array, 1) + .expect("Failed to compare arrays")); + assert!(decimal_value + .eq_array(&array, 5) + .expect("Failed to compare arrays")); // test try from array assert_eq!( decimal_value, @@ -3288,13 +3565,16 @@ mod tests { assert!(ScalarValue::try_new_decimal128(1, 10, 2) .unwrap() - .eq_array(&array, 0)); + .eq_array(&array, 0) + .expect("Failed to compare arrays")); assert!(ScalarValue::try_new_decimal128(2, 10, 2) .unwrap() - .eq_array(&array, 1)); + .eq_array(&array, 1) + .expect("Failed to compare arrays")); assert!(ScalarValue::try_new_decimal128(3, 10, 2) .unwrap() - .eq_array(&array, 2)); + .eq_array(&array, 2) + .expect("Failed to compare arrays")); assert_eq!( ScalarValue::Decimal128(None, 10, 2), ScalarValue::try_from_array(&array, 3).unwrap() @@ -3303,17 +3583,74 @@ mod tests { Ok(()) } + #[test] + fn test_list_partial_cmp() { + let a = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + ])]), + )); + let b = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + ])]), + )); + assert_eq!(a.partial_cmp(&b), Some(Ordering::Equal)); + + let a = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(10), + Some(2), + Some(3), + ])]), + )); + let b = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(30), + ])]), + )); + assert_eq!(a.partial_cmp(&b), Some(Ordering::Greater)); + + let a = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(10), + Some(2), + Some(3), + ])]), + )); + let b = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(10), + Some(2), + Some(30), + ])]), + )); + assert_eq!(a.partial_cmp(&b), Some(Ordering::Less)); + } + #[test] fn scalar_value_to_array_u64() -> Result<()> { let value = ScalarValue::UInt64(Some(13u64)); - let array = value.to_array(); + let array = value.to_array().expect("Failed to convert to array"); let array = as_uint64_array(&array)?; assert_eq!(array.len(), 1); assert!(!array.is_null(0)); assert_eq!(array.value(0), 13); let value = ScalarValue::UInt64(None); - let array = value.to_array(); + let array = value.to_array().expect("Failed to convert to array"); let array = as_uint64_array(&array)?; assert_eq!(array.len(), 1); assert!(array.is_null(0)); @@ -3323,14 +3660,14 @@ mod tests { #[test] fn scalar_value_to_array_u32() -> Result<()> { let value = ScalarValue::UInt32(Some(13u32)); - let array = value.to_array(); + let array = value.to_array().expect("Failed to convert to array"); let array = as_uint32_array(&array)?; assert_eq!(array.len(), 1); assert!(!array.is_null(0)); assert_eq!(array.value(0), 13); let value = ScalarValue::UInt32(None); - let array = value.to_array(); + let array = value.to_array().expect("Failed to convert to array"); let array = as_uint32_array(&array)?; assert_eq!(array.len(), 1); assert!(array.is_null(0)); @@ -3339,31 +3676,52 @@ mod tests { #[test] fn scalar_list_null_to_array() { - let list_array_ref = ScalarValue::List( - None, - Arc::new(Field::new("item", DataType::UInt64, false)), - ) - .to_array(); - let list_array = as_list_array(&list_array_ref).unwrap(); + let list_array_ref = ScalarValue::new_list(&[], &DataType::UInt64); + let list_array = as_list_array(&list_array_ref); + + assert_eq!(list_array.len(), 1); + assert_eq!(list_array.values().len(), 0); + } + + #[test] + fn scalar_large_list_null_to_array() { + let list_array_ref = ScalarValue::new_large_list(&[], &DataType::UInt64); + let list_array = as_large_list_array(&list_array_ref); - assert!(list_array.is_null(0)); assert_eq!(list_array.len(), 1); assert_eq!(list_array.values().len(), 0); } #[test] fn scalar_list_to_array() -> Result<()> { - let list_array_ref = ScalarValue::List( - Some(vec![ - ScalarValue::UInt64(Some(100)), - ScalarValue::UInt64(None), - ScalarValue::UInt64(Some(101)), - ]), - Arc::new(Field::new("item", DataType::UInt64, false)), - ) - .to_array(); + let values = vec![ + ScalarValue::UInt64(Some(100)), + ScalarValue::UInt64(None), + ScalarValue::UInt64(Some(101)), + ]; + let list_array_ref = ScalarValue::new_list(&values, &DataType::UInt64); + let list_array = as_list_array(&list_array_ref); + assert_eq!(list_array.len(), 1); + assert_eq!(list_array.values().len(), 3); + + let prim_array_ref = list_array.value(0); + let prim_array = as_uint64_array(&prim_array_ref)?; + assert_eq!(prim_array.len(), 3); + assert_eq!(prim_array.value(0), 100); + assert!(prim_array.is_null(1)); + assert_eq!(prim_array.value(2), 101); + Ok(()) + } - let list_array = as_list_array(&list_array_ref)?; + #[test] + fn scalar_large_list_to_array() -> Result<()> { + let values = vec![ + ScalarValue::UInt64(Some(100)), + ScalarValue::UInt64(None), + ScalarValue::UInt64(Some(101)), + ]; + let list_array_ref = ScalarValue::new_large_list(&values, &DataType::UInt64); + let list_array = as_large_list_array(&list_array_ref); assert_eq!(list_array.len(), 1); assert_eq!(list_array.values().len(), 3); @@ -3577,6 +3935,78 @@ mod tests { ); } + #[test] + fn scalar_try_from_array_list_array_null() { + let list = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2)]), + None, + ]); + + let non_null_list_scalar = ScalarValue::try_from_array(&list, 0).unwrap(); + let null_list_scalar = ScalarValue::try_from_array(&list, 1).unwrap(); + + let data_type = + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + + assert_eq!(non_null_list_scalar.data_type(), data_type.clone()); + assert_eq!(null_list_scalar.data_type(), data_type); + } + + #[test] + fn scalar_try_from_list() { + let data_type = + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + let data_type = &data_type; + let scalar: ScalarValue = data_type.try_into().unwrap(); + + let expected = ScalarValue::List(new_null_array( + &DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + 1, + )); + + assert_eq!(expected, scalar) + } + + #[test] + fn scalar_try_from_list_of_list() { + let data_type = DataType::List(Arc::new(Field::new( + "item", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + ))); + let data_type = &data_type; + let scalar: ScalarValue = data_type.try_into().unwrap(); + + let expected = ScalarValue::List(new_null_array( + &DataType::List(Arc::new(Field::new( + "item", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + ))), + 1, + )); + + assert_eq!(expected, scalar) + } + + #[test] + fn scalar_try_from_not_equal_list_nested_list() { + let list_data_type = + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + let data_type = &list_data_type; + let list_scalar: ScalarValue = data_type.try_into().unwrap(); + + let nested_list_data_type = DataType::List(Arc::new(Field::new( + "item", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + ))); + let data_type = &nested_list_data_type; + let nested_list_scalar: ScalarValue = data_type.try_into().unwrap(); + + assert_ne!(list_scalar, nested_list_scalar); + } + #[test] fn scalar_try_from_dict_datatype() { let data_type = @@ -3826,7 +4256,9 @@ mod tests { for (index, scalar) in scalars.into_iter().enumerate() { assert!( - scalar.eq_array(&array, index), + scalar + .eq_array(&array, index) + .expect("Failed to compare arrays"), "Expected {scalar:?} to be equal to {array:?} at index {index}" ); @@ -3834,7 +4266,7 @@ mod tests { for other_index in 0..array.len() { if index != other_index { assert!( - !scalar.eq_array(&array, other_index), + !scalar.eq_array(&array, other_index).expect("Failed to compare arrays"), "Expected {scalar:?} to be NOT equal to {array:?} at index {other_index}" ); } @@ -3863,55 +4295,6 @@ mod tests { assert_eq!(Int64(Some(33)).partial_cmp(&Int32(Some(33))), None); assert_eq!(Int32(Some(33)).partial_cmp(&Int64(Some(33))), None); - assert_eq!( - List( - Some(vec![Int32(Some(1)), Int32(Some(5))]), - Arc::new(Field::new("item", DataType::Int32, false)), - ) - .partial_cmp(&List( - Some(vec![Int32(Some(1)), Int32(Some(5))]), - Arc::new(Field::new("item", DataType::Int32, false)), - )), - Some(Ordering::Equal) - ); - - assert_eq!( - List( - Some(vec![Int32(Some(10)), Int32(Some(5))]), - Arc::new(Field::new("item", DataType::Int32, false)), - ) - .partial_cmp(&List( - Some(vec![Int32(Some(1)), Int32(Some(5))]), - Arc::new(Field::new("item", DataType::Int32, false)), - )), - Some(Ordering::Greater) - ); - - assert_eq!( - List( - Some(vec![Int32(Some(1)), Int32(Some(5))]), - Arc::new(Field::new("item", DataType::Int32, false)), - ) - .partial_cmp(&List( - Some(vec![Int32(Some(10)), Int32(Some(5))]), - Arc::new(Field::new("item", DataType::Int32, false)), - )), - Some(Ordering::Less) - ); - - // For different data type, `partial_cmp` returns None. - assert_eq!( - List( - Some(vec![Int64(Some(1)), Int64(Some(5))]), - Arc::new(Field::new("item", DataType::Int64, false)), - ) - .partial_cmp(&List( - Some(vec![Int32(Some(1)), Int32(Some(5))]), - Arc::new(Field::new("item", DataType::Int32, false)), - )), - None - ); - assert_eq!( ScalarValue::from(vec![ ("A", ScalarValue::from(1.0)), @@ -3938,6 +4321,16 @@ mod tests { ); } + #[test] + fn test_scalar_value_from_string() { + let scalar = ScalarValue::from("foo"); + assert_eq!(scalar, ScalarValue::Utf8(Some("foo".to_string()))); + let scalar = ScalarValue::from("foo".to_string()); + assert_eq!(scalar, ScalarValue::Utf8(Some("foo".to_string()))); + let scalar = ScalarValue::from_str("foo").unwrap(); + assert_eq!(scalar, ScalarValue::Utf8(Some("foo".to_string()))); + } + #[test] fn test_scalar_struct() { let field_a = Arc::new(Field::new("A", DataType::Int32, false)); @@ -3956,7 +4349,7 @@ mod tests { Some(vec![ ScalarValue::Int32(Some(23)), ScalarValue::Boolean(Some(false)), - ScalarValue::Utf8(Some("Hello".to_string())), + ScalarValue::from("Hello"), ScalarValue::from(vec![ ("e", ScalarValue::from(2i16)), ("f", ScalarValue::from(3i64)), @@ -3986,7 +4379,9 @@ mod tests { ); // Convert to length-2 array - let array = scalar.to_array_of_size(2); + let array = scalar + .to_array_of_size(2) + .expect("Failed to convert to array of size"); let expected = Arc::new(StructArray::from(vec![ ( @@ -4124,38 +4519,40 @@ mod tests { )); // Define primitive list scalars - let l0 = ScalarValue::List( - Some(vec![ - ScalarValue::from(1i32), - ScalarValue::from(2i32), - ScalarValue::from(3i32), - ]), - Arc::new(Field::new("item", DataType::Int32, false)), - ); - - let l1 = ScalarValue::List( - Some(vec![ScalarValue::from(4i32), ScalarValue::from(5i32)]), - Arc::new(Field::new("item", DataType::Int32, false)), - ); - - let l2 = ScalarValue::List( - Some(vec![ScalarValue::from(6i32)]), - Arc::new(Field::new("item", DataType::Int32, false)), - ); + let l0 = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + ])]), + )); + let l1 = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(4), + Some(5), + ])]), + )); + let l2 = ScalarValue::List(Arc::new(ListArray::from_iter_primitive::< + Int32Type, + _, + _, + >(vec![Some(vec![Some(6)])]))); // Define struct scalars let s0 = ScalarValue::from(vec![ - ("A", ScalarValue::Utf8(Some(String::from("First")))), + ("A", ScalarValue::from("First")), ("primitive_list", l0), ]); let s1 = ScalarValue::from(vec![ - ("A", ScalarValue::Utf8(Some(String::from("Second")))), + ("A", ScalarValue::from("Second")), ("primitive_list", l1), ]); let s2 = ScalarValue::from(vec![ - ("A", ScalarValue::Utf8(Some(String::from("Third")))), + ("A", ScalarValue::from("Third")), ("primitive_list", l2), ]); @@ -4181,15 +4578,19 @@ mod tests { assert_eq!(array, &expected); // Define list-of-structs scalars - let nl0 = - ScalarValue::new_list(Some(vec![s0.clone(), s1.clone()]), s0.data_type()); - let nl1 = ScalarValue::new_list(Some(vec![s2]), s0.data_type()); + let nl0_array = ScalarValue::iter_to_array(vec![s0.clone(), s1.clone()]).unwrap(); + let nl0 = ScalarValue::List(Arc::new(array_into_list_array(nl0_array))); + + let nl1_array = ScalarValue::iter_to_array(vec![s2.clone()]).unwrap(); + let nl1 = ScalarValue::List(Arc::new(array_into_list_array(nl1_array))); + + let nl2_array = ScalarValue::iter_to_array(vec![s1.clone()]).unwrap(); + let nl2 = ScalarValue::List(Arc::new(array_into_list_array(nl2_array))); - let nl2 = ScalarValue::new_list(Some(vec![s1]), s0.data_type()); // iter_to_array for list-of-struct let array = ScalarValue::iter_to_array(vec![nl0, nl1, nl2]).unwrap(); - let array = as_list_array(&array).unwrap(); + let array = as_list_array(&array); // Construct expected array with array builders let field_a_builder = StringBuilder::with_capacity(4, 1024); @@ -4309,54 +4710,37 @@ mod tests { assert_eq!(array, &expected); } + fn build_2d_list(data: Vec>) -> ListArray { + let a1 = ListArray::from_iter_primitive::(vec![Some(data)]); + ListArray::new( + Arc::new(Field::new( + "item", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + )), + OffsetBuffer::::from_lengths([1]), + Arc::new(a1), + None, + ) + } + #[test] fn test_nested_lists() { // Define inner list scalars - let l1 = ScalarValue::new_list( - Some(vec![ - ScalarValue::new_list( - Some(vec![ - ScalarValue::from(1i32), - ScalarValue::from(2i32), - ScalarValue::from(3i32), - ]), - DataType::Int32, - ), - ScalarValue::new_list( - Some(vec![ScalarValue::from(4i32), ScalarValue::from(5i32)]), - DataType::Int32, - ), - ]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), - ); - - let l2 = ScalarValue::new_list( - Some(vec![ - ScalarValue::new_list( - Some(vec![ScalarValue::from(6i32)]), - DataType::Int32, - ), - ScalarValue::new_list( - Some(vec![ScalarValue::from(7i32), ScalarValue::from(8i32)]), - DataType::Int32, - ), - ]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), - ); - - let l3 = ScalarValue::new_list( - Some(vec![ScalarValue::new_list( - Some(vec![ScalarValue::from(9i32)]), - DataType::Int32, - )]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), - ); - - let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap(); - let array = as_list_array(&array).unwrap(); + let arr1 = build_2d_list(vec![Some(1), Some(2), Some(3)]); + let arr2 = build_2d_list(vec![Some(4), Some(5)]); + let arr3 = build_2d_list(vec![Some(6)]); + + let array = ScalarValue::iter_to_array(vec![ + ScalarValue::List(Arc::new(arr1)), + ScalarValue::List(Arc::new(arr2)), + ScalarValue::List(Arc::new(arr3)), + ]) + .unwrap(); + let array = as_list_array(&array); // Construct expected array with array builders - let inner_builder = Int32Array::builder(8); + let inner_builder = Int32Array::builder(6); let middle_builder = ListBuilder::new(inner_builder); let mut outer_builder = ListBuilder::new(middle_builder); @@ -4364,6 +4748,7 @@ mod tests { outer_builder.values().values().append_value(2); outer_builder.values().values().append_value(3); outer_builder.values().append(true); + outer_builder.append(true); outer_builder.values().values().append_value(4); outer_builder.values().values().append_value(5); @@ -4372,14 +4757,6 @@ mod tests { outer_builder.values().values().append_value(6); outer_builder.values().append(true); - - outer_builder.values().values().append_value(7); - outer_builder.values().values().append_value(8); - outer_builder.values().append(true); - outer_builder.append(true); - - outer_builder.values().values().append_value(9); - outer_builder.values().append(true); outer_builder.append(true); let expected = outer_builder.finish(); @@ -4399,7 +4776,7 @@ mod tests { DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())) ); - let array = scalar.to_array(); + let array = scalar.to_array().expect("Failed to convert to array"); assert_eq!(array.len(), 1); assert_eq!( array.data_type(), @@ -4423,7 +4800,7 @@ mod tests { check_scalar_cast(ScalarValue::Float64(None), DataType::Int16); check_scalar_cast( - ScalarValue::Utf8(Some("foo".to_string())), + ScalarValue::from("foo"), DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), ); @@ -4436,7 +4813,7 @@ mod tests { // mimics how casting work on scalar values by `casting` `scalar` to `desired_type` fn check_scalar_cast(scalar: ScalarValue, desired_type: DataType) { // convert from scalar --> Array to call cast - let scalar_array = scalar.to_array(); + let scalar_array = scalar.to_array().expect("Failed to convert to array"); // cast the actual value let cast_array = kernels::cast::cast(&scalar_array, &desired_type).unwrap(); @@ -4445,7 +4822,9 @@ mod tests { assert_eq!(cast_scalar.data_type(), desired_type); // Some time later the "cast" scalar is turned back into an array: - let array = cast_scalar.to_array_of_size(10); + let array = cast_scalar + .to_array_of_size(10) + .expect("Failed to convert to array of size"); // The datatype should be "Dictionary" but is actually Utf8!!! assert_eq!(array.data_type(), &desired_type) @@ -4702,10 +5081,7 @@ mod tests { (ScalarValue::Int8(None), ScalarValue::Int16(Some(1))), (ScalarValue::Int8(Some(1)), ScalarValue::Int16(None)), // Unsupported types - ( - ScalarValue::Utf8(Some("foo".to_string())), - ScalarValue::Utf8(Some("bar".to_string())), - ), + (ScalarValue::from("foo"), ScalarValue::from("bar")), ( ScalarValue::Boolean(Some(true)), ScalarValue::Boolean(Some(false)), @@ -4894,7 +5270,8 @@ mod tests { let arrays = scalars .iter() .map(ScalarValue::to_array) - .collect::>(); + .collect::>>() + .expect("Failed to convert to array"); let arrays = arrays.iter().map(|a| a.as_ref()).collect::>(); let array = concat(&arrays).unwrap(); check_array(array); @@ -4903,12 +5280,30 @@ mod tests { #[test] fn test_build_timestamp_millisecond_list() { let values = vec![ScalarValue::TimestampMillisecond(Some(1), None)]; - let ts_list = ScalarValue::new_list( - Some(values), - DataType::Timestamp(TimeUnit::Millisecond, None), + let arr = ScalarValue::new_list( + &values, + &DataType::Timestamp(TimeUnit::Millisecond, None), + ); + assert_eq!(1, arr.len()); + } + + #[test] + fn test_newlist_timestamp_zone() { + let s: &'static str = "UTC"; + let values = vec![ScalarValue::TimestampMillisecond(Some(1), Some(s.into()))]; + let arr = ScalarValue::new_list( + &values, + &DataType::Timestamp(TimeUnit::Millisecond, Some(s.into())), + ); + assert_eq!(1, arr.len()); + assert_eq!( + arr.data_type(), + &DataType::List(Arc::new(Field::new( + "item", + DataType::Timestamp(TimeUnit::Millisecond, Some(s.into())), + true + ))) ); - let list = ts_list.to_array_of_size(1); - assert_eq!(1, list.len()); } fn get_random_timestamps(sample_size: u64) -> Vec { diff --git a/datafusion/common/src/stats.rs b/datafusion/common/src/stats.rs index ca76e14cb8ab..7ad8992ca9ae 100644 --- a/datafusion/common/src/stats.rs +++ b/datafusion/common/src/stats.rs @@ -17,60 +17,309 @@ //! This module provides data structures to represent statistics -use std::fmt::Display; - -use arrow::datatypes::DataType; +use std::fmt::{self, Debug, Display}; use crate::ScalarValue; +use arrow_schema::Schema; + +/// Represents a value with a degree of certainty. `Precision` is used to +/// propagate information the precision of statistical values. +#[derive(Clone, PartialEq, Eq, Default)] +pub enum Precision { + /// The exact value is known + Exact(T), + /// The value is not known exactly, but is likely close to this value + Inexact(T), + /// Nothing is known about the value + #[default] + Absent, +} + +impl Precision { + /// If we have some value (exact or inexact), it returns that value. + /// Otherwise, it returns `None`. + pub fn get_value(&self) -> Option<&T> { + match self { + Precision::Exact(value) | Precision::Inexact(value) => Some(value), + Precision::Absent => None, + } + } + + /// Transform the value in this [`Precision`] object, if one exists, using + /// the given function. Preserves the exactness state. + pub fn map(self, f: F) -> Precision + where + F: Fn(T) -> T, + { + match self { + Precision::Exact(val) => Precision::Exact(f(val)), + Precision::Inexact(val) => Precision::Inexact(f(val)), + _ => self, + } + } + + /// Returns `Some(true)` if we have an exact value, `Some(false)` if we + /// have an inexact value, and `None` if there is no value. + pub fn is_exact(&self) -> Option { + match self { + Precision::Exact(_) => Some(true), + Precision::Inexact(_) => Some(false), + _ => None, + } + } + + /// Returns the maximum of two (possibly inexact) values, conservatively + /// propagating exactness information. If one of the input values is + /// [`Precision::Absent`], the result is `Absent` too. + pub fn max(&self, other: &Precision) -> Precision { + match (self, other) { + (Precision::Exact(a), Precision::Exact(b)) => { + Precision::Exact(if a >= b { a.clone() } else { b.clone() }) + } + (Precision::Inexact(a), Precision::Exact(b)) + | (Precision::Exact(a), Precision::Inexact(b)) + | (Precision::Inexact(a), Precision::Inexact(b)) => { + Precision::Inexact(if a >= b { a.clone() } else { b.clone() }) + } + (_, _) => Precision::Absent, + } + } + + /// Returns the minimum of two (possibly inexact) values, conservatively + /// propagating exactness information. If one of the input values is + /// [`Precision::Absent`], the result is `Absent` too. + pub fn min(&self, other: &Precision) -> Precision { + match (self, other) { + (Precision::Exact(a), Precision::Exact(b)) => { + Precision::Exact(if a >= b { b.clone() } else { a.clone() }) + } + (Precision::Inexact(a), Precision::Exact(b)) + | (Precision::Exact(a), Precision::Inexact(b)) + | (Precision::Inexact(a), Precision::Inexact(b)) => { + Precision::Inexact(if a >= b { b.clone() } else { a.clone() }) + } + (_, _) => Precision::Absent, + } + } + + /// Demotes the precision state from exact to inexact (if present). + pub fn to_inexact(self) -> Self { + match self { + Precision::Exact(value) => Precision::Inexact(value), + _ => self, + } + } +} + +impl Precision { + /// Calculates the sum of two (possibly inexact) [`usize`] values, + /// conservatively propagating exactness information. If one of the input + /// values is [`Precision::Absent`], the result is `Absent` too. + pub fn add(&self, other: &Precision) -> Precision { + match (self, other) { + (Precision::Exact(a), Precision::Exact(b)) => Precision::Exact(a + b), + (Precision::Inexact(a), Precision::Exact(b)) + | (Precision::Exact(a), Precision::Inexact(b)) + | (Precision::Inexact(a), Precision::Inexact(b)) => Precision::Inexact(a + b), + (_, _) => Precision::Absent, + } + } + + /// Calculates the difference of two (possibly inexact) [`usize`] values, + /// conservatively propagating exactness information. If one of the input + /// values is [`Precision::Absent`], the result is `Absent` too. + pub fn sub(&self, other: &Precision) -> Precision { + match (self, other) { + (Precision::Exact(a), Precision::Exact(b)) => Precision::Exact(a - b), + (Precision::Inexact(a), Precision::Exact(b)) + | (Precision::Exact(a), Precision::Inexact(b)) + | (Precision::Inexact(a), Precision::Inexact(b)) => Precision::Inexact(a - b), + (_, _) => Precision::Absent, + } + } + + /// Calculates the multiplication of two (possibly inexact) [`usize`] values, + /// conservatively propagating exactness information. If one of the input + /// values is [`Precision::Absent`], the result is `Absent` too. + pub fn multiply(&self, other: &Precision) -> Precision { + match (self, other) { + (Precision::Exact(a), Precision::Exact(b)) => Precision::Exact(a * b), + (Precision::Inexact(a), Precision::Exact(b)) + | (Precision::Exact(a), Precision::Inexact(b)) + | (Precision::Inexact(a), Precision::Inexact(b)) => Precision::Inexact(a * b), + (_, _) => Precision::Absent, + } + } + + /// Return the estimate of applying a filter with estimated selectivity + /// `selectivity` to this Precision. A selectivity of `1.0` means that all + /// rows are selected. A selectivity of `0.5` means half the rows are + /// selected. Will always return inexact statistics. + pub fn with_estimated_selectivity(self, selectivity: f64) -> Self { + self.map(|v| ((v as f64 * selectivity).ceil()) as usize) + .to_inexact() + } +} + +impl Precision { + /// Calculates the sum of two (possibly inexact) [`ScalarValue`] values, + /// conservatively propagating exactness information. If one of the input + /// values is [`Precision::Absent`], the result is `Absent` too. + pub fn add(&self, other: &Precision) -> Precision { + match (self, other) { + (Precision::Exact(a), Precision::Exact(b)) => { + if let Ok(result) = a.add(b) { + Precision::Exact(result) + } else { + Precision::Absent + } + } + (Precision::Inexact(a), Precision::Exact(b)) + | (Precision::Exact(a), Precision::Inexact(b)) + | (Precision::Inexact(a), Precision::Inexact(b)) => { + if let Ok(result) = a.add(b) { + Precision::Inexact(result) + } else { + Precision::Absent + } + } + (_, _) => Precision::Absent, + } + } +} + +impl Debug for Precision { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Precision::Exact(inner) => write!(f, "Exact({:?})", inner), + Precision::Inexact(inner) => write!(f, "Inexact({:?})", inner), + Precision::Absent => write!(f, "Absent"), + } + } +} + +impl Display for Precision { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Precision::Exact(inner) => write!(f, "Exact({:?})", inner), + Precision::Inexact(inner) => write!(f, "Inexact({:?})", inner), + Precision::Absent => write!(f, "Absent"), + } + } +} + /// Statistics for a relation /// Fields are optional and can be inexact because the sources /// sometimes provide approximate estimates for performance reasons /// and the transformations output are not always predictable. -#[derive(Debug, Clone, Default, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct Statistics { - /// The number of table rows - pub num_rows: Option, - /// total bytes of the table rows - pub total_byte_size: Option, - /// Statistics on a column level - pub column_statistics: Option>, - /// If true, any field that is `Some(..)` is the actual value in the data provided by the operator (it is not - /// an estimate). Any or all other fields might still be None, in which case no information is known. - /// if false, any field that is `Some(..)` may contain an inexact estimate and may not be the actual value. - pub is_exact: bool, + /// The number of table rows. + pub num_rows: Precision, + /// Total bytes of the table rows. + pub total_byte_size: Precision, + /// Statistics on a column level. It contains a [`ColumnStatistics`] for + /// each field in the schema of the the table to which the [`Statistics`] refer. + pub column_statistics: Vec, +} + +impl Statistics { + /// Returns a [`Statistics`] instance for the given schema by assigning + /// unknown statistics to each column in the schema. + pub fn new_unknown(schema: &Schema) -> Self { + Self { + num_rows: Precision::Absent, + total_byte_size: Precision::Absent, + column_statistics: Statistics::unknown_column(schema), + } + } + + /// Returns an unbounded `ColumnStatistics` for each field in the schema. + pub fn unknown_column(schema: &Schema) -> Vec { + schema + .fields() + .iter() + .map(|_| ColumnStatistics::new_unknown()) + .collect() + } + + /// If the exactness of a [`Statistics`] instance is lost, this function relaxes + /// the exactness of all information by converting them [`Precision::Inexact`]. + pub fn into_inexact(self) -> Self { + Statistics { + num_rows: self.num_rows.to_inexact(), + total_byte_size: self.total_byte_size.to_inexact(), + column_statistics: self + .column_statistics + .into_iter() + .map(|cs| ColumnStatistics { + null_count: cs.null_count.to_inexact(), + max_value: cs.max_value.to_inexact(), + min_value: cs.min_value.to_inexact(), + distinct_count: cs.distinct_count.to_inexact(), + }) + .collect::>(), + } + } } impl Display for Statistics { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if self.num_rows.is_none() && self.total_byte_size.is_none() && !self.is_exact { - return Ok(()); - } + // string of column statistics + let column_stats = self + .column_statistics + .iter() + .enumerate() + .map(|(i, cs)| { + let s = format!("(Col[{}]:", i); + let s = if cs.min_value != Precision::Absent { + format!("{} Min={}", s, cs.min_value) + } else { + s + }; + let s = if cs.max_value != Precision::Absent { + format!("{} Max={}", s, cs.max_value) + } else { + s + }; + let s = if cs.null_count != Precision::Absent { + format!("{} Null={}", s, cs.null_count) + } else { + s + }; + let s = if cs.distinct_count != Precision::Absent { + format!("{} Distinct={}", s, cs.distinct_count) + } else { + s + }; - let rows = self - .num_rows - .map_or_else(|| "None".to_string(), |v| v.to_string()); - let bytes = self - .total_byte_size - .map_or_else(|| "None".to_string(), |v| v.to_string()); + s + ")" + }) + .collect::>() + .join(","); - write!(f, "rows={}, bytes={}, exact={}", rows, bytes, self.is_exact)?; + write!( + f, + "Rows={}, Bytes={}, [{}]", + self.num_rows, self.total_byte_size, column_stats + )?; Ok(()) } } /// Statistics for a column within a relation -#[derive(Clone, Debug, Default, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, Default)] pub struct ColumnStatistics { /// Number of null values on column - pub null_count: Option, + pub null_count: Precision, /// Maximum value of column - pub max_value: Option, + pub max_value: Precision, /// Minimum value of column - pub min_value: Option, + pub min_value: Precision, /// Number of distinct values - pub distinct_count: Option, + pub distinct_count: Precision, } impl ColumnStatistics { @@ -78,19 +327,135 @@ impl ColumnStatistics { pub fn is_singleton(&self) -> bool { match (&self.min_value, &self.max_value) { // Min and max values are the same and not infinity. - (Some(min), Some(max)) => !min.is_null() && !max.is_null() && (min == max), + (Precision::Exact(min), Precision::Exact(max)) => { + !min.is_null() && !max.is_null() && (min == max) + } (_, _) => false, } } - /// Returns the [`ColumnStatistics`] corresponding to the given datatype by assigning infinite bounds. - pub fn new_with_unbounded_column(dt: &DataType) -> ColumnStatistics { - let null = ScalarValue::try_from(dt.clone()).ok(); + /// Returns a [`ColumnStatistics`] instance having all [`Precision::Absent`] parameters. + pub fn new_unknown() -> ColumnStatistics { ColumnStatistics { - null_count: None, - max_value: null.clone(), - min_value: null, - distinct_count: None, + null_count: Precision::Absent, + max_value: Precision::Absent, + min_value: Precision::Absent, + distinct_count: Precision::Absent, } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_get_value() { + let exact_precision = Precision::Exact(42); + let inexact_precision = Precision::Inexact(23); + let absent_precision = Precision::::Absent; + + assert_eq!(*exact_precision.get_value().unwrap(), 42); + assert_eq!(*inexact_precision.get_value().unwrap(), 23); + assert_eq!(absent_precision.get_value(), None); + } + + #[test] + fn test_map() { + let exact_precision = Precision::Exact(42); + let inexact_precision = Precision::Inexact(23); + let absent_precision = Precision::Absent; + + let squared = |x| x * x; + + assert_eq!(exact_precision.map(squared), Precision::Exact(1764)); + assert_eq!(inexact_precision.map(squared), Precision::Inexact(529)); + assert_eq!(absent_precision.map(squared), Precision::Absent); + } + + #[test] + fn test_is_exact() { + let exact_precision = Precision::Exact(42); + let inexact_precision = Precision::Inexact(23); + let absent_precision = Precision::::Absent; + + assert_eq!(exact_precision.is_exact(), Some(true)); + assert_eq!(inexact_precision.is_exact(), Some(false)); + assert_eq!(absent_precision.is_exact(), None); + } + + #[test] + fn test_max() { + let precision1 = Precision::Exact(42); + let precision2 = Precision::Inexact(23); + let precision3 = Precision::Exact(30); + let absent_precision = Precision::Absent; + + assert_eq!(precision1.max(&precision2), Precision::Inexact(42)); + assert_eq!(precision1.max(&precision3), Precision::Exact(42)); + assert_eq!(precision2.max(&precision3), Precision::Inexact(30)); + assert_eq!(precision1.max(&absent_precision), Precision::Absent); + } + + #[test] + fn test_min() { + let precision1 = Precision::Exact(42); + let precision2 = Precision::Inexact(23); + let precision3 = Precision::Exact(30); + let absent_precision = Precision::Absent; + + assert_eq!(precision1.min(&precision2), Precision::Inexact(23)); + assert_eq!(precision1.min(&precision3), Precision::Exact(30)); + assert_eq!(precision2.min(&precision3), Precision::Inexact(23)); + assert_eq!(precision1.min(&absent_precision), Precision::Absent); + } + + #[test] + fn test_to_inexact() { + let exact_precision = Precision::Exact(42); + let inexact_precision = Precision::Inexact(42); + let absent_precision = Precision::::Absent; + + assert_eq!(exact_precision.clone().to_inexact(), inexact_precision); + assert_eq!(inexact_precision.clone().to_inexact(), inexact_precision); + assert_eq!(absent_precision.clone().to_inexact(), absent_precision); + } + + #[test] + fn test_add() { + let precision1 = Precision::Exact(42); + let precision2 = Precision::Inexact(23); + let precision3 = Precision::Exact(30); + let absent_precision = Precision::Absent; + + assert_eq!(precision1.add(&precision2), Precision::Inexact(65)); + assert_eq!(precision1.add(&precision3), Precision::Exact(72)); + assert_eq!(precision2.add(&precision3), Precision::Inexact(53)); + assert_eq!(precision1.add(&absent_precision), Precision::Absent); + } + + #[test] + fn test_sub() { + let precision1 = Precision::Exact(42); + let precision2 = Precision::Inexact(23); + let precision3 = Precision::Exact(30); + let absent_precision = Precision::Absent; + + assert_eq!(precision1.sub(&precision2), Precision::Inexact(19)); + assert_eq!(precision1.sub(&precision3), Precision::Exact(12)); + assert_eq!(precision1.sub(&absent_precision), Precision::Absent); + } + + #[test] + fn test_multiply() { + let precision1 = Precision::Exact(6); + let precision2 = Precision::Inexact(3); + let precision3 = Precision::Exact(5); + let absent_precision = Precision::Absent; + + assert_eq!(precision1.multiply(&precision2), Precision::Inexact(18)); + assert_eq!(precision1.multiply(&precision3), Precision::Exact(30)); + assert_eq!(precision2.multiply(&precision3), Precision::Inexact(15)); + assert_eq!(precision1.multiply(&absent_precision), Precision::Absent); + } +} diff --git a/datafusion/common/src/test_util.rs b/datafusion/common/src/test_util.rs index 60f1df7fd11a..eeace97eebfa 100644 --- a/datafusion/common/src/test_util.rs +++ b/datafusion/common/src/test_util.rs @@ -180,6 +180,7 @@ pub fn arrow_test_data() -> String { /// let filename = format!("{}/binary.parquet", testdata); /// assert!(std::path::PathBuf::from(filename).exists()); /// ``` +#[cfg(feature = "parquet")] pub fn parquet_test_data() -> String { match get_data_dir("PARQUET_TEST_DATA", "../../parquet-testing/data") { Ok(pb) => pb.display().to_string(), @@ -284,6 +285,7 @@ mod tests { } #[test] + #[cfg(feature = "parquet")] fn test_happy() { let res = arrow_test_data(); assert!(PathBuf::from(res).is_dir()); diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 2919d9a39c9c..5f11c8cc1d11 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -18,6 +18,7 @@ //! This module provides common traits for visiting or rewriting tree //! data structures easily. +use std::borrow::Cow; use std::sync::Arc; use crate::Result; @@ -32,7 +33,10 @@ use crate::Result; /// [`PhysicalExpr`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.PhysicalExpr.html /// [`LogicalPlan`]: https://docs.rs/datafusion-expr/latest/datafusion_expr/logical_plan/enum.LogicalPlan.html /// [`Expr`]: https://docs.rs/datafusion-expr/latest/datafusion_expr/expr/enum.Expr.html -pub trait TreeNode: Sized { +pub trait TreeNode: Sized + Clone { + /// Returns all children of the TreeNode + fn children_nodes(&self) -> Vec>; + /// Use preorder to iterate the node on the tree so that we can /// stop fast for some cases. /// @@ -125,6 +129,17 @@ pub trait TreeNode: Sized { after_op.map_children(|node| node.transform_down(op)) } + /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its + /// children(Preorder Traversal) using a mutable function, `F`. + /// When the `op` does not apply to a given node, it is left unchanged. + fn transform_down_mut(self, op: &mut F) -> Result + where + F: FnMut(Self) -> Result>, + { + let after_op = op(self)?.into(); + after_op.map_children(|node| node.transform_down_mut(op)) + } + /// Convenience utils for writing optimizers rule: recursively apply the given 'op' first to all of its /// children and then itself(Postorder Traversal). /// When the `op` does not apply to a given node, it is left unchanged. @@ -138,6 +153,19 @@ pub trait TreeNode: Sized { Ok(new_node) } + /// Convenience utils for writing optimizers rule: recursively apply the given 'op' first to all of its + /// children and then itself(Postorder Traversal) using a mutable function, `F`. + /// When the `op` does not apply to a given node, it is left unchanged. + fn transform_up_mut(self, op: &mut F) -> Result + where + F: FnMut(Self) -> Result>, + { + let after_op_children = self.map_children(|node| node.transform_up_mut(op))?; + + let new_node = op(after_op_children)?.into(); + Ok(new_node) + } + /// Transform the tree node using the given [TreeNodeRewriter] /// It performs a depth first walk of an node and its children. /// @@ -187,7 +215,17 @@ pub trait TreeNode: Sized { /// Apply the closure `F` to the node's children fn apply_children(&self, op: &mut F) -> Result where - F: FnMut(&Self) -> Result; + F: FnMut(&Self) -> Result, + { + for child in self.children_nodes() { + match op(&child)? { + VisitRecursion::Continue => {} + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), + } + } + Ok(VisitRecursion::Continue) + } /// Apply transform `F` to the node's children, the transform `F` might have a direction(Preorder or Postorder) fn map_children(self, transform: F) -> Result @@ -318,19 +356,8 @@ pub trait DynTreeNode { /// Blanket implementation for Arc for any tye that implements /// [`DynTreeNode`] (such as [`Arc`]) impl TreeNode for Arc { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in self.arc_children() { - match op(&child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.arc_children().into_iter().map(Cow::Owned).collect() } fn map_children(self, transform: F) -> Result diff --git a/datafusion/common/src/utils.rs b/datafusion/common/src/utils.rs index b7c80aa9ac44..0a61fce15482 100644 --- a/datafusion/common/src/utils.rs +++ b/datafusion/common/src/utils.rs @@ -17,12 +17,16 @@ //! This module provides the bisect function, which implements binary search. -use crate::{DataFusionError, Result, ScalarValue}; +use crate::error::{_internal_datafusion_err, _internal_err}; +use crate::{arrow_datafusion_err, DataFusionError, Result, ScalarValue}; use arrow::array::{ArrayRef, PrimitiveArray}; +use arrow::buffer::OffsetBuffer; use arrow::compute; use arrow::compute::{partition, SortColumn, SortOptions}; -use arrow::datatypes::{SchemaRef, UInt32Type}; +use arrow::datatypes::{Field, SchemaRef, UInt32Type}; use arrow::record_batch::RecordBatch; +use arrow_array::{Array, LargeListArray, ListArray, RecordBatchOptions}; +use arrow_schema::DataType; use sqlparser::ast::Ident; use sqlparser::dialect::GenericDialect; use sqlparser::parser::Parser; @@ -86,8 +90,12 @@ pub fn get_record_batch_at_indices( indices: &PrimitiveArray, ) -> Result { let new_columns = get_arrayref_at_indices(record_batch.columns(), indices)?; - RecordBatch::try_new(record_batch.schema(), new_columns) - .map_err(DataFusionError::ArrowError) + RecordBatch::try_new_with_options( + record_batch.schema(), + new_columns, + &RecordBatchOptions::new().with_row_count(Some(indices.len())), + ) + .map_err(|e| arrow_datafusion_err!(e)) } /// This function compares two tuples depending on the given sort options. @@ -109,7 +117,7 @@ pub fn compare_rows( lhs.partial_cmp(rhs) } .ok_or_else(|| { - DataFusionError::Internal("Column array shouldn't be empty".to_string()) + _internal_datafusion_err!("Column array shouldn't be empty") })?, (true, true, _) => continue, }; @@ -131,7 +139,7 @@ pub fn bisect( ) -> Result { let low: usize = 0; let high: usize = item_columns - .get(0) + .first() .ok_or_else(|| { DataFusionError::Internal("Column array shouldn't be empty".to_string()) })? @@ -182,7 +190,7 @@ pub fn linear_search( ) -> Result { let low: usize = 0; let high: usize = item_columns - .get(0) + .first() .ok_or_else(|| { DataFusionError::Internal("Column array shouldn't be empty".to_string()) })? @@ -283,7 +291,7 @@ pub fn get_arrayref_at_indices( indices, None, // None: no index check ) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) }) .collect() } @@ -334,6 +342,156 @@ pub fn longest_consecutive_prefix>( count } +/// Array Utils + +/// Wrap an array into a single element `ListArray`. +/// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]` +pub fn array_into_list_array(arr: ArrayRef) -> ListArray { + let offsets = OffsetBuffer::from_lengths([arr.len()]); + ListArray::new( + Arc::new(Field::new("item", arr.data_type().to_owned(), true)), + offsets, + arr, + None, + ) +} + +/// Wrap an array into a single element `LargeListArray`. +/// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]` +pub fn array_into_large_list_array(arr: ArrayRef) -> LargeListArray { + let offsets = OffsetBuffer::from_lengths([arr.len()]); + LargeListArray::new( + Arc::new(Field::new("item", arr.data_type().to_owned(), true)), + offsets, + arr, + None, + ) +} + +/// Wrap arrays into a single element `ListArray`. +/// +/// Example: +/// ``` +/// use arrow::array::{Int32Array, ListArray, ArrayRef}; +/// use arrow::datatypes::{Int32Type, Field}; +/// use std::sync::Arc; +/// +/// let arr1 = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef; +/// let arr2 = Arc::new(Int32Array::from(vec![4, 5, 6])) as ArrayRef; +/// +/// let list_arr = datafusion_common::utils::arrays_into_list_array([arr1, arr2]).unwrap(); +/// +/// let expected = ListArray::from_iter_primitive::( +/// vec![ +/// Some(vec![Some(1), Some(2), Some(3)]), +/// Some(vec![Some(4), Some(5), Some(6)]), +/// ] +/// ); +/// +/// assert_eq!(list_arr, expected); +pub fn arrays_into_list_array( + arr: impl IntoIterator, +) -> Result { + let arr = arr.into_iter().collect::>(); + if arr.is_empty() { + return _internal_err!("Cannot wrap empty array into list array"); + } + + let lens = arr.iter().map(|x| x.len()).collect::>(); + // Assume data type is consistent + let data_type = arr[0].data_type().to_owned(); + let values = arr.iter().map(|x| x.as_ref()).collect::>(); + Ok(ListArray::new( + Arc::new(Field::new("item", data_type, true)), + OffsetBuffer::from_lengths(lens), + arrow::compute::concat(values.as_slice())?, + None, + )) +} + +/// Get the base type of a data type. +/// +/// Example +/// ``` +/// use arrow::datatypes::{DataType, Field}; +/// use datafusion_common::utils::base_type; +/// use std::sync::Arc; +/// +/// let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); +/// assert_eq!(base_type(&data_type), DataType::Int32); +/// +/// let data_type = DataType::Int32; +/// assert_eq!(base_type(&data_type), DataType::Int32); +/// ``` +pub fn base_type(data_type: &DataType) -> DataType { + match data_type { + DataType::List(field) | DataType::LargeList(field) => { + base_type(field.data_type()) + } + _ => data_type.to_owned(), + } +} + +/// A helper function to coerce base type in List. +/// +/// Example +/// ``` +/// use arrow::datatypes::{DataType, Field}; +/// use datafusion_common::utils::coerced_type_with_base_type_only; +/// use std::sync::Arc; +/// +/// let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); +/// let base_type = DataType::Float64; +/// let coerced_type = coerced_type_with_base_type_only(&data_type, &base_type); +/// assert_eq!(coerced_type, DataType::List(Arc::new(Field::new("item", DataType::Float64, true)))); +pub fn coerced_type_with_base_type_only( + data_type: &DataType, + base_type: &DataType, +) -> DataType { + match data_type { + DataType::List(field) => { + let data_type = match field.data_type() { + DataType::List(_) => { + coerced_type_with_base_type_only(field.data_type(), base_type) + } + _ => base_type.to_owned(), + }; + + DataType::List(Arc::new(Field::new( + field.name(), + data_type, + field.is_nullable(), + ))) + } + DataType::LargeList(field) => { + let data_type = match field.data_type() { + DataType::LargeList(_) => { + coerced_type_with_base_type_only(field.data_type(), base_type) + } + _ => base_type.to_owned(), + }; + + DataType::LargeList(Arc::new(Field::new( + field.name(), + data_type, + field.is_nullable(), + ))) + } + + _ => base_type.clone(), + } +} + +/// Compute the number of dimensions in a list data type. +pub fn list_ndims(data_type: &DataType) -> u64 { + match data_type { + DataType::List(field) | DataType::LargeList(field) => { + 1 + list_ndims(field.data_type()) + } + _ => 0, + } +} + /// An extension trait for smart pointers. Provides an interface to get a /// raw pointer to the data (with metadata stripped away). /// diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 266ff855752b..9de6a7f7d6a0 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -20,9 +20,9 @@ name = "datafusion" description = "DataFusion is an in-memory query engine that uses Apache Arrow as the memory model" keywords = ["arrow", "query", "sql"] include = ["benches/*.rs", "src/**/*.rs", "Cargo.toml"] +readme = "README.md" version = { workspace = true } edition = { workspace = true } -readme = { workspace = true } homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } @@ -39,12 +39,14 @@ avro = ["apache-avro", "num-traits", "datafusion-common/avro"] backtrace = ["datafusion-common/backtrace"] compression = ["xz2", "bzip2", "flate2", "zstd", "async-compression"] crypto_expressions = ["datafusion-physical-expr/crypto_expressions", "datafusion-optimizer/crypto_expressions"] -default = ["crypto_expressions", "encoding__expressions", "regex_expressions", "unicode_expressions", "compression"] -encoding__expressions = ["datafusion-physical-expr/encoding_expressions"] +default = ["crypto_expressions", "encoding_expressions", "regex_expressions", "unicode_expressions", "compression", "parquet"] +encoding_expressions = ["datafusion-physical-expr/encoding_expressions"] # Used for testing ONLY: causes all values to hash to the same value (test for collisions) force_hash_collisions = [] -pyarrow = ["datafusion-common/pyarrow"] +parquet = ["datafusion-common/parquet", "dep:parquet"] +pyarrow = ["datafusion-common/pyarrow", "parquet"] regex_expressions = ["datafusion-physical-expr/regex_expressions", "datafusion-optimizer/regex_expressions"] +serde = ["arrow-schema/serde"] simd = ["arrow/simd"] unicode_expressions = ["datafusion-physical-expr/unicode_expressions", "datafusion-optimizer/unicode_expressions", "datafusion-sql/unicode_expressions"] @@ -53,64 +55,64 @@ ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] apache-avro = { version = "0.16", optional = true } arrow = { workspace = true } arrow-array = { workspace = true } +arrow-ipc = { workspace = true } arrow-schema = { workspace = true } async-compression = { version = "0.4.0", features = ["bzip2", "gzip", "xz", "zstd", "futures-io", "tokio"], optional = true } -async-trait = "0.1.73" -bytes = "1.4" +async-trait = { workspace = true } +bytes = { workspace = true } bzip2 = { version = "0.4.3", optional = true } chrono = { workspace = true } -dashmap = "5.4.0" -datafusion-common = { path = "../common", version = "31.0.0", features = ["parquet", "object_store"] } -datafusion-execution = { path = "../execution", version = "31.0.0" } -datafusion-expr = { path = "../expr", version = "31.0.0" } -datafusion-optimizer = { path = "../optimizer", version = "31.0.0", default-features = false } -datafusion-physical-expr = { path = "../physical-expr", version = "31.0.0", default-features = false } -datafusion-physical-plan = { path = "../physical-plan", version = "31.0.0", default-features = false } -datafusion-sql = { path = "../sql", version = "31.0.0" } +dashmap = { workspace = true } +datafusion-common = { path = "../common", version = "34.0.0", features = ["object_store"], default-features = false } +datafusion-execution = { workspace = true } +datafusion-expr = { workspace = true } +datafusion-optimizer = { path = "../optimizer", version = "34.0.0", default-features = false } +datafusion-physical-expr = { path = "../physical-expr", version = "34.0.0", default-features = false } +datafusion-physical-plan = { workspace = true } +datafusion-sql = { workspace = true } flate2 = { version = "1.0.24", optional = true } -futures = "0.3" +futures = { workspace = true } glob = "0.3.0" half = { version = "2.1", default-features = false } hashbrown = { version = "0.14", features = ["raw"] } -indexmap = "2.0.0" -itertools = "0.11" -log = "^0.4" +indexmap = { workspace = true } +itertools = { workspace = true } +log = { workspace = true } num-traits = { version = "0.2", optional = true } -num_cpus = "1.13.0" -object_store = "0.7.0" -parking_lot = "0.12" -parquet = { workspace = true } -percent-encoding = "2.2.0" +num_cpus = { workspace = true } +object_store = { workspace = true } +parking_lot = { workspace = true } +parquet = { workspace = true, optional = true, default-features = true } pin-project-lite = "^0.2.7" -rand = "0.8" +rand = { workspace = true } sqlparser = { workspace = true } -tempfile = "3" +tempfile = { workspace = true } tokio = { version = "1.28", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] } tokio-util = { version = "0.7.4", features = ["io"] } -url = "2.2" +url = { workspace = true } uuid = { version = "1.0", features = ["v4"] } xz2 = { version = "0.1", optional = true } -zstd = { version = "0.12", optional = true, default-features = false } - +zstd = { version = "0.13", optional = true, default-features = false } [dev-dependencies] -async-trait = "0.1.53" -bigdecimal = "0.4.1" +async-trait = { workspace = true } +bigdecimal = { workspace = true } criterion = { version = "0.5", features = ["async_tokio"] } csv = "1.1.6" -ctor = "0.2.0" -doc-comment = "0.3" -env_logger = "0.10" -half = "2.2.1" +ctor = { workspace = true } +doc-comment = { workspace = true } +env_logger = { workspace = true } +half = { workspace = true } postgres-protocol = "0.6.4" postgres-types = { version = "0.2.4", features = ["derive", "with-chrono-0_4"] } rand = { version = "0.8", features = ["small_rng"] } rand_distr = "0.4.3" regex = "1.5.4" -rstest = "0.18.0" +rstest = { workspace = true } rust_decimal = { version = "1.27.0", features = ["tokio-pg"] } +serde_json = { workspace = true } test-utils = { path = "../../test-utils" } -thiserror = "1.0.37" +thiserror = { workspace = true } tokio-postgres = "0.7.7" [target.'cfg(not(target_os = "windows"))'.dev-dependencies] nix = { version = "0.27.1", features = ["fs"] } @@ -119,6 +121,10 @@ nix = { version = "0.27.1", features = ["fs"] } harness = false name = "aggregate_query_sql" +[[bench]] +harness = false +name = "distinct_query_sql" + [[bench]] harness = false name = "sort_limit_query_sql" @@ -162,3 +168,7 @@ name = "sort" [[bench]] harness = false name = "topk_aggregate" + +[[bench]] +harness = false +name = "array_expression" diff --git a/datafusion/core/README.md b/datafusion/core/README.md new file mode 100644 index 000000000000..5a9493d086cd --- /dev/null +++ b/datafusion/core/README.md @@ -0,0 +1,26 @@ + + +# DataFusion Common + +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. + +This crate contains the main entrypoints and high level DataFusion APIs such as SessionContext, and DataFrame and ListingTable. + +[df]: https://crates.io/crates/datafusion diff --git a/datafusion/core/benches/array_expression.rs b/datafusion/core/benches/array_expression.rs new file mode 100644 index 000000000000..95bc93e0e353 --- /dev/null +++ b/datafusion/core/benches/array_expression.rs @@ -0,0 +1,73 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[macro_use] +extern crate criterion; +extern crate arrow; +extern crate datafusion; + +mod data_utils; +use crate::criterion::Criterion; +use arrow_array::cast::AsArray; +use arrow_array::types::Int64Type; +use arrow_array::{ArrayRef, Int64Array, ListArray}; +use datafusion_physical_expr::array_expressions; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + // Construct large arrays for benchmarking + + let array_len = 100000000; + + let array = (0..array_len).map(|_| Some(2_i64)).collect::>(); + let list_array = ListArray::from_iter_primitive::(vec![ + Some(array.clone()), + Some(array.clone()), + Some(array), + ]); + let from_array = Int64Array::from_value(2, 3); + let to_array = Int64Array::from_value(-2, 3); + + let args = vec![ + Arc::new(list_array) as ArrayRef, + Arc::new(from_array) as ArrayRef, + Arc::new(to_array) as ArrayRef, + ]; + + let array = (0..array_len).map(|_| Some(-2_i64)).collect::>(); + let expected_array = ListArray::from_iter_primitive::(vec![ + Some(array.clone()), + Some(array.clone()), + Some(array), + ]); + + // Benchmark array functions + + c.bench_function("array_replace", |b| { + b.iter(|| { + assert_eq!( + array_expressions::array_replace_all(args.as_slice()) + .unwrap() + .as_list::(), + criterion::black_box(&expected_array) + ) + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/core/benches/data_utils/mod.rs b/datafusion/core/benches/data_utils/mod.rs index 64c0e4b100a1..9d2864919225 100644 --- a/datafusion/core/benches/data_utils/mod.rs +++ b/datafusion/core/benches/data_utils/mod.rs @@ -25,11 +25,16 @@ use arrow::{ datatypes::{DataType, Field, Schema, SchemaRef}, record_batch::RecordBatch, }; +use arrow_array::builder::{Int64Builder, StringBuilder}; use datafusion::datasource::MemTable; use datafusion::error::Result; +use datafusion_common::DataFusionError; use rand::rngs::StdRng; use rand::seq::SliceRandom; use rand::{Rng, SeedableRng}; +use rand_distr::Distribution; +use rand_distr::{Normal, Pareto}; +use std::fmt::Write; use std::sync::Arc; /// create an in-memory table given the partition len, array len, and batch size, @@ -156,3 +161,83 @@ pub fn create_record_batches( }) .collect::>() } + +/// Create time series data with `partition_cnt` partitions and `sample_cnt` rows per partition +/// in ascending order, if `asc` is true, otherwise randomly sampled using a Pareto distribution +#[allow(dead_code)] +pub(crate) fn make_data( + partition_cnt: i32, + sample_cnt: i32, + asc: bool, +) -> Result<(Arc, Vec>), DataFusionError> { + // constants observed from trace data + let simultaneous_group_cnt = 2000; + let fitted_shape = 12f64; + let fitted_scale = 5f64; + let mean = 0.1; + let stddev = 1.1; + let pareto = Pareto::new(fitted_scale, fitted_shape).unwrap(); + let normal = Normal::new(mean, stddev).unwrap(); + let mut rng = rand::rngs::SmallRng::from_seed([0; 32]); + + // populate data + let schema = test_schema(); + let mut partitions = vec![]; + let mut cur_time = 16909000000000i64; + for _ in 0..partition_cnt { + let mut id_builder = StringBuilder::new(); + let mut ts_builder = Int64Builder::new(); + let gen_id = |rng: &mut rand::rngs::SmallRng| { + rng.gen::<[u8; 16]>() + .iter() + .fold(String::new(), |mut output, b| { + let _ = write!(output, "{b:02X}"); + output + }) + }; + let gen_sample_cnt = + |mut rng: &mut rand::rngs::SmallRng| pareto.sample(&mut rng).ceil() as u32; + let mut group_ids = (0..simultaneous_group_cnt) + .map(|_| gen_id(&mut rng)) + .collect::>(); + let mut group_sample_cnts = (0..simultaneous_group_cnt) + .map(|_| gen_sample_cnt(&mut rng)) + .collect::>(); + for _ in 0..sample_cnt { + let random_index = rng.gen_range(0..simultaneous_group_cnt); + let trace_id = &mut group_ids[random_index]; + let sample_cnt = &mut group_sample_cnts[random_index]; + *sample_cnt -= 1; + if *sample_cnt == 0 { + *trace_id = gen_id(&mut rng); + *sample_cnt = gen_sample_cnt(&mut rng); + } + + id_builder.append_value(trace_id); + ts_builder.append_value(cur_time); + + if asc { + cur_time += 1; + } else { + let samp: f64 = normal.sample(&mut rng); + let samp = samp.round(); + cur_time += samp as i64; + } + } + + // convert to MemTable + let id_col = Arc::new(id_builder.finish()); + let ts_col = Arc::new(ts_builder.finish()); + let batch = RecordBatch::try_new(schema.clone(), vec![id_col, ts_col])?; + partitions.push(vec![batch]); + } + Ok((schema, partitions)) +} + +/// The Schema used by make_data +fn test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("trace_id", DataType::Utf8, false), + Field::new("timestamp_ms", DataType::Int64, false), + ])) +} diff --git a/datafusion/core/benches/distinct_query_sql.rs b/datafusion/core/benches/distinct_query_sql.rs new file mode 100644 index 000000000000..c242798a56f0 --- /dev/null +++ b/datafusion/core/benches/distinct_query_sql.rs @@ -0,0 +1,208 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[macro_use] +extern crate criterion; +extern crate arrow; +extern crate datafusion; + +mod data_utils; +use crate::criterion::Criterion; +use data_utils::{create_table_provider, make_data}; +use datafusion::execution::context::SessionContext; +use datafusion::physical_plan::{collect, ExecutionPlan}; +use datafusion::{datasource::MemTable, error::Result}; +use datafusion_execution::config::SessionConfig; +use datafusion_execution::TaskContext; + +use parking_lot::Mutex; +use std::{sync::Arc, time::Duration}; +use tokio::runtime::Runtime; + +fn query(ctx: Arc>, sql: &str) { + let rt = Runtime::new().unwrap(); + let df = rt.block_on(ctx.lock().sql(sql)).unwrap(); + criterion::black_box(rt.block_on(df.collect()).unwrap()); +} + +fn create_context( + partitions_len: usize, + array_len: usize, + batch_size: usize, +) -> Result>> { + let ctx = SessionContext::new(); + let provider = create_table_provider(partitions_len, array_len, batch_size)?; + ctx.register_table("t", provider)?; + Ok(Arc::new(Mutex::new(ctx))) +} + +fn criterion_benchmark_limited_distinct(c: &mut Criterion) { + let partitions_len = 10; + let array_len = 1 << 26; // 64 M + let batch_size = 8192; + let ctx = create_context(partitions_len, array_len, batch_size).unwrap(); + + let mut group = c.benchmark_group("custom-measurement-time"); + group.measurement_time(Duration::from_secs(40)); + + group.bench_function("distinct_group_by_u64_narrow_limit_10", |b| { + b.iter(|| { + query( + ctx.clone(), + "SELECT DISTINCT u64_narrow FROM t GROUP BY u64_narrow LIMIT 10", + ) + }) + }); + + group.bench_function("distinct_group_by_u64_narrow_limit_100", |b| { + b.iter(|| { + query( + ctx.clone(), + "SELECT DISTINCT u64_narrow FROM t GROUP BY u64_narrow LIMIT 100", + ) + }) + }); + + group.bench_function("distinct_group_by_u64_narrow_limit_1000", |b| { + b.iter(|| { + query( + ctx.clone(), + "SELECT DISTINCT u64_narrow FROM t GROUP BY u64_narrow LIMIT 1000", + ) + }) + }); + + group.bench_function("distinct_group_by_u64_narrow_limit_10000", |b| { + b.iter(|| { + query( + ctx.clone(), + "SELECT DISTINCT u64_narrow FROM t GROUP BY u64_narrow LIMIT 10000", + ) + }) + }); + + group.bench_function("group_by_multiple_columns_limit_10", |b| { + b.iter(|| { + query( + ctx.clone(), + "SELECT u64_narrow, u64_wide, utf8, f64 FROM t GROUP BY 1, 2, 3, 4 LIMIT 10", + ) + }) + }); + group.finish(); +} + +async fn distinct_with_limit( + plan: Arc, + ctx: Arc, +) -> Result<()> { + let batches = collect(plan, ctx).await?; + assert_eq!(batches.len(), 1); + let batch = batches.first().unwrap(); + assert_eq!(batch.num_rows(), 10); + + Ok(()) +} + +fn run(plan: Arc, ctx: Arc) { + let rt = Runtime::new().unwrap(); + criterion::black_box( + rt.block_on(async { distinct_with_limit(plan.clone(), ctx.clone()).await }), + ) + .unwrap(); +} + +pub async fn create_context_sampled_data( + sql: &str, + partition_cnt: i32, + sample_cnt: i32, +) -> Result<(Arc, Arc)> { + let (schema, parts) = make_data(partition_cnt, sample_cnt, false /* asc */).unwrap(); + let mem_table = Arc::new(MemTable::try_new(schema, parts).unwrap()); + + // Create the DataFrame + let cfg = SessionConfig::new(); + let ctx = SessionContext::new_with_config(cfg); + let _ = ctx.register_table("traces", mem_table)?; + let df = ctx.sql(sql).await?; + let physical_plan = df.create_physical_plan().await?; + Ok((physical_plan, ctx.task_ctx())) +} + +fn criterion_benchmark_limited_distinct_sampled(c: &mut Criterion) { + let rt = Runtime::new().unwrap(); + + let limit = 10; + let partitions = 100; + let samples = 100_000; + let sql = + format!("select DISTINCT trace_id from traces group by trace_id limit {limit};"); + + let distinct_trace_id_100_partitions_100_000_samples_limit_100 = rt.block_on(async { + create_context_sampled_data(sql.as_str(), partitions, samples) + .await + .unwrap() + }); + + c.bench_function( + format!("distinct query with {} partitions and {} samples per partition with limit {}", partitions, samples, limit).as_str(), + |b| b.iter(|| run(distinct_trace_id_100_partitions_100_000_samples_limit_100.0.clone(), + distinct_trace_id_100_partitions_100_000_samples_limit_100.1.clone())), + ); + + let partitions = 10; + let samples = 1_000_000; + let sql = + format!("select DISTINCT trace_id from traces group by trace_id limit {limit};"); + + let distinct_trace_id_10_partitions_1_000_000_samples_limit_10 = rt.block_on(async { + create_context_sampled_data(sql.as_str(), partitions, samples) + .await + .unwrap() + }); + + c.bench_function( + format!("distinct query with {} partitions and {} samples per partition with limit {}", partitions, samples, limit).as_str(), + |b| b.iter(|| run(distinct_trace_id_10_partitions_1_000_000_samples_limit_10.0.clone(), + distinct_trace_id_10_partitions_1_000_000_samples_limit_10.1.clone())), + ); + + let partitions = 1; + let samples = 10_000_000; + let sql = + format!("select DISTINCT trace_id from traces group by trace_id limit {limit};"); + + let rt = Runtime::new().unwrap(); + let distinct_trace_id_1_partition_10_000_000_samples_limit_10 = rt.block_on(async { + create_context_sampled_data(sql.as_str(), partitions, samples) + .await + .unwrap() + }); + + c.bench_function( + format!("distinct query with {} partitions and {} samples per partition with limit {}", partitions, samples, limit).as_str(), + |b| b.iter(|| run(distinct_trace_id_1_partition_10_000_000_samples_limit_10.0.clone(), + distinct_trace_id_1_partition_10_000_000_samples_limit_10.1.clone())), + ); +} + +criterion_group!( + benches, + criterion_benchmark_limited_distinct, + criterion_benchmark_limited_distinct_sampled +); +criterion_main!(benches); diff --git a/datafusion/core/benches/parquet_query_sql.rs b/datafusion/core/benches/parquet_query_sql.rs index 876b1fe7e198..6c9ab315761e 100644 --- a/datafusion/core/benches/parquet_query_sql.rs +++ b/datafusion/core/benches/parquet_query_sql.rs @@ -193,7 +193,7 @@ fn criterion_benchmark(c: &mut Criterion) { let partitions = 4; let config = SessionConfig::new().with_target_partitions(partitions); - let context = SessionContext::with_config(config); + let context = SessionContext::new_with_config(config); let local_rt = tokio::runtime::Builder::new_current_thread() .build() diff --git a/datafusion/core/benches/scalar.rs b/datafusion/core/benches/scalar.rs index 30f21a964d5f..540f7212e96e 100644 --- a/datafusion/core/benches/scalar.rs +++ b/datafusion/core/benches/scalar.rs @@ -22,7 +22,15 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("to_array_of_size 100000", |b| { let scalar = ScalarValue::Int32(Some(100)); - b.iter(|| assert_eq!(scalar.to_array_of_size(100000).null_count(), 0)) + b.iter(|| { + assert_eq!( + scalar + .to_array_of_size(100000) + .expect("Failed to convert to array of size") + .null_count(), + 0 + ) + }) }); } diff --git a/datafusion/core/benches/sort_limit_query_sql.rs b/datafusion/core/benches/sort_limit_query_sql.rs index 62160067143e..cfd4b8bc4bba 100644 --- a/datafusion/core/benches/sort_limit_query_sql.rs +++ b/datafusion/core/benches/sort_limit_query_sql.rs @@ -86,8 +86,9 @@ fn create_context() -> Arc> { rt.block_on(async { // create local session context - let ctx = - SessionContext::with_config(SessionConfig::new().with_target_partitions(1)); + let ctx = SessionContext::new_with_config( + SessionConfig::new().with_target_partitions(1), + ); let table_provider = Arc::new(csv.await); let mem_table = MemTable::load(table_provider, Some(partitions), &ctx.state()) @@ -98,7 +99,7 @@ fn create_context() -> Arc> { ctx_holder.lock().push(Arc::new(Mutex::new(ctx))) }); - let ctx = ctx_holder.lock().get(0).unwrap().clone(); + let ctx = ctx_holder.lock().first().unwrap().clone(); ctx } diff --git a/datafusion/core/benches/sql_planner.rs b/datafusion/core/benches/sql_planner.rs index 7a41b6bec6f5..1754129a768f 100644 --- a/datafusion/core/benches/sql_planner.rs +++ b/datafusion/core/benches/sql_planner.rs @@ -60,6 +60,104 @@ pub fn create_table_provider(column_prefix: &str, num_columns: usize) -> Arc [(String, Schema); 8] { + let lineitem_schema = Schema::new(vec![ + Field::new("l_orderkey", DataType::Int64, false), + Field::new("l_partkey", DataType::Int64, false), + Field::new("l_suppkey", DataType::Int64, false), + Field::new("l_linenumber", DataType::Int32, false), + Field::new("l_quantity", DataType::Decimal128(15, 2), false), + Field::new("l_extendedprice", DataType::Decimal128(15, 2), false), + Field::new("l_discount", DataType::Decimal128(15, 2), false), + Field::new("l_tax", DataType::Decimal128(15, 2), false), + Field::new("l_returnflag", DataType::Utf8, false), + Field::new("l_linestatus", DataType::Utf8, false), + Field::new("l_shipdate", DataType::Date32, false), + Field::new("l_commitdate", DataType::Date32, false), + Field::new("l_receiptdate", DataType::Date32, false), + Field::new("l_shipinstruct", DataType::Utf8, false), + Field::new("l_shipmode", DataType::Utf8, false), + Field::new("l_comment", DataType::Utf8, false), + ]); + + let orders_schema = Schema::new(vec![ + Field::new("o_orderkey", DataType::Int64, false), + Field::new("o_custkey", DataType::Int64, false), + Field::new("o_orderstatus", DataType::Utf8, false), + Field::new("o_totalprice", DataType::Decimal128(15, 2), false), + Field::new("o_orderdate", DataType::Date32, false), + Field::new("o_orderpriority", DataType::Utf8, false), + Field::new("o_clerk", DataType::Utf8, false), + Field::new("o_shippriority", DataType::Int32, false), + Field::new("o_comment", DataType::Utf8, false), + ]); + + let part_schema = Schema::new(vec![ + Field::new("p_partkey", DataType::Int64, false), + Field::new("p_name", DataType::Utf8, false), + Field::new("p_mfgr", DataType::Utf8, false), + Field::new("p_brand", DataType::Utf8, false), + Field::new("p_type", DataType::Utf8, false), + Field::new("p_size", DataType::Int32, false), + Field::new("p_container", DataType::Utf8, false), + Field::new("p_retailprice", DataType::Decimal128(15, 2), false), + Field::new("p_comment", DataType::Utf8, false), + ]); + + let supplier_schema = Schema::new(vec![ + Field::new("s_suppkey", DataType::Int64, false), + Field::new("s_name", DataType::Utf8, false), + Field::new("s_address", DataType::Utf8, false), + Field::new("s_nationkey", DataType::Int64, false), + Field::new("s_phone", DataType::Utf8, false), + Field::new("s_acctbal", DataType::Decimal128(15, 2), false), + Field::new("s_comment", DataType::Utf8, false), + ]); + + let partsupp_schema = Schema::new(vec![ + Field::new("ps_partkey", DataType::Int64, false), + Field::new("ps_suppkey", DataType::Int64, false), + Field::new("ps_availqty", DataType::Int32, false), + Field::new("ps_supplycost", DataType::Decimal128(15, 2), false), + Field::new("ps_comment", DataType::Utf8, false), + ]); + + let customer_schema = Schema::new(vec![ + Field::new("c_custkey", DataType::Int64, false), + Field::new("c_name", DataType::Utf8, false), + Field::new("c_address", DataType::Utf8, false), + Field::new("c_nationkey", DataType::Int64, false), + Field::new("c_phone", DataType::Utf8, false), + Field::new("c_acctbal", DataType::Decimal128(15, 2), false), + Field::new("c_mktsegment", DataType::Utf8, false), + Field::new("c_comment", DataType::Utf8, false), + ]); + + let nation_schema = Schema::new(vec![ + Field::new("n_nationkey", DataType::Int64, false), + Field::new("n_name", DataType::Utf8, false), + Field::new("n_regionkey", DataType::Int64, false), + Field::new("n_comment", DataType::Utf8, false), + ]); + + let region_schema = Schema::new(vec![ + Field::new("r_regionkey", DataType::Int64, false), + Field::new("r_name", DataType::Utf8, false), + Field::new("r_comment", DataType::Utf8, false), + ]); + + [ + ("lineitem".to_string(), lineitem_schema), + ("orders".to_string(), orders_schema), + ("part".to_string(), part_schema), + ("supplier".to_string(), supplier_schema), + ("partsupp".to_string(), partsupp_schema), + ("customer".to_string(), customer_schema), + ("nation".to_string(), nation_schema), + ("region".to_string(), region_schema), + ] +} + fn create_context() -> SessionContext { let ctx = SessionContext::new(); ctx.register_table("t1", create_table_provider("a", 200)) @@ -68,6 +166,16 @@ fn create_context() -> SessionContext { .unwrap(); ctx.register_table("t700", create_table_provider("c", 700)) .unwrap(); + + let tpch_schemas = create_tpch_schemas(); + tpch_schemas.iter().for_each(|(name, schema)| { + ctx.register_table( + name, + Arc::new(MemTable::try_new(Arc::new(schema.clone()), vec![]).unwrap()), + ) + .unwrap(); + }); + ctx } @@ -115,6 +223,54 @@ fn criterion_benchmark(c: &mut Criterion) { ) }) }); + + let q1_sql = std::fs::read_to_string("../../benchmarks/queries/q1.sql").unwrap(); + let q2_sql = std::fs::read_to_string("../../benchmarks/queries/q2.sql").unwrap(); + let q3_sql = std::fs::read_to_string("../../benchmarks/queries/q3.sql").unwrap(); + let q4_sql = std::fs::read_to_string("../../benchmarks/queries/q4.sql").unwrap(); + let q5_sql = std::fs::read_to_string("../../benchmarks/queries/q5.sql").unwrap(); + let q6_sql = std::fs::read_to_string("../../benchmarks/queries/q6.sql").unwrap(); + let q7_sql = std::fs::read_to_string("../../benchmarks/queries/q7.sql").unwrap(); + let q8_sql = std::fs::read_to_string("../../benchmarks/queries/q8.sql").unwrap(); + let q9_sql = std::fs::read_to_string("../../benchmarks/queries/q9.sql").unwrap(); + let q10_sql = std::fs::read_to_string("../../benchmarks/queries/q10.sql").unwrap(); + let q11_sql = std::fs::read_to_string("../../benchmarks/queries/q11.sql").unwrap(); + let q12_sql = std::fs::read_to_string("../../benchmarks/queries/q12.sql").unwrap(); + let q13_sql = std::fs::read_to_string("../../benchmarks/queries/q13.sql").unwrap(); + let q14_sql = std::fs::read_to_string("../../benchmarks/queries/q14.sql").unwrap(); + // let q15_sql = std::fs::read_to_string("../../benchmarks/queries/q15.sql").unwrap(); + let q16_sql = std::fs::read_to_string("../../benchmarks/queries/q16.sql").unwrap(); + let q17_sql = std::fs::read_to_string("../../benchmarks/queries/q17.sql").unwrap(); + let q18_sql = std::fs::read_to_string("../../benchmarks/queries/q18.sql").unwrap(); + let q19_sql = std::fs::read_to_string("../../benchmarks/queries/q19.sql").unwrap(); + let q20_sql = std::fs::read_to_string("../../benchmarks/queries/q20.sql").unwrap(); + let q21_sql = std::fs::read_to_string("../../benchmarks/queries/q21.sql").unwrap(); + let q22_sql = std::fs::read_to_string("../../benchmarks/queries/q22.sql").unwrap(); + + c.bench_function("physical_plan_tpch", |b| { + b.iter(|| physical_plan(&ctx, &q1_sql)); + b.iter(|| physical_plan(&ctx, &q2_sql)); + b.iter(|| physical_plan(&ctx, &q3_sql)); + b.iter(|| physical_plan(&ctx, &q4_sql)); + b.iter(|| physical_plan(&ctx, &q5_sql)); + b.iter(|| physical_plan(&ctx, &q6_sql)); + b.iter(|| physical_plan(&ctx, &q7_sql)); + b.iter(|| physical_plan(&ctx, &q8_sql)); + b.iter(|| physical_plan(&ctx, &q9_sql)); + b.iter(|| physical_plan(&ctx, &q10_sql)); + b.iter(|| physical_plan(&ctx, &q11_sql)); + b.iter(|| physical_plan(&ctx, &q12_sql)); + b.iter(|| physical_plan(&ctx, &q13_sql)); + b.iter(|| physical_plan(&ctx, &q14_sql)); + // b.iter(|| physical_plan(&ctx, &q15_sql)); + b.iter(|| physical_plan(&ctx, &q16_sql)); + b.iter(|| physical_plan(&ctx, &q17_sql)); + b.iter(|| physical_plan(&ctx, &q18_sql)); + b.iter(|| physical_plan(&ctx, &q19_sql)); + b.iter(|| physical_plan(&ctx, &q20_sql)); + b.iter(|| physical_plan(&ctx, &q21_sql)); + b.iter(|| physical_plan(&ctx, &q22_sql)); + }); } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/core/benches/sql_query_with_io.rs b/datafusion/core/benches/sql_query_with_io.rs index 1d96df0cecaa..c7a838385bd6 100644 --- a/datafusion/core/benches/sql_query_with_io.rs +++ b/datafusion/core/benches/sql_query_with_io.rs @@ -93,10 +93,9 @@ async fn setup_files(store: Arc) { for partition in 0..TABLE_PARTITIONS { for file in 0..PARTITION_FILES { let data = create_parquet_file(&mut rng, file * FILE_ROWS); - let location = Path::try_from(format!( + let location = Path::from(format!( "{table_name}/partition={partition}/{file}.parquet" - )) - .unwrap(); + )); store.put(&location, data).await.unwrap(); } } @@ -120,7 +119,7 @@ async fn setup_context(object_store: Arc) -> SessionContext { let config = SessionConfig::new().with_target_partitions(THREADS); let rt = Arc::new(RuntimeEnv::default()); rt.register_object_store(&Url::parse("data://my_store").unwrap(), object_store); - let context = SessionContext::with_config_rt(config, rt); + let context = SessionContext::new_with_config_rt(config, rt); for table_id in 0..TABLES { let table_name = table_name(table_id); diff --git a/datafusion/core/benches/topk_aggregate.rs b/datafusion/core/benches/topk_aggregate.rs index f50a8ec047da..922cbd2b4229 100644 --- a/datafusion/core/benches/topk_aggregate.rs +++ b/datafusion/core/benches/topk_aggregate.rs @@ -15,19 +15,15 @@ // specific language governing permissions and limitations // under the License. +mod data_utils; use arrow::util::pretty::pretty_format_batches; -use arrow::{datatypes::Schema, record_batch::RecordBatch}; -use arrow_array::builder::{Int64Builder, StringBuilder}; -use arrow_schema::{DataType, Field, SchemaRef}; use criterion::{criterion_group, criterion_main, Criterion}; +use data_utils::make_data; use datafusion::physical_plan::{collect, displayable, ExecutionPlan}; use datafusion::prelude::SessionContext; use datafusion::{datasource::MemTable, error::Result}; -use datafusion_common::DataFusionError; use datafusion_execution::config::SessionConfig; use datafusion_execution::TaskContext; -use rand_distr::Distribution; -use rand_distr::{Normal, Pareto}; use std::sync::Arc; use tokio::runtime::Runtime; @@ -45,7 +41,7 @@ async fn create_context( let mut cfg = SessionConfig::new(); let opts = cfg.options_mut(); opts.optimizer.enable_topk_aggregation = use_topk; - let ctx = SessionContext::with_config(cfg); + let ctx = SessionContext::new_with_config(cfg); let _ = ctx.register_table("traces", mem_table)?; let sql = format!("select trace_id, max(timestamp_ms) from traces group by trace_id order by max(timestamp_ms) desc limit {limit};"); let df = ctx.sql(sql.as_str()).await?; @@ -77,10 +73,10 @@ async fn aggregate( let batch = batches.first().unwrap(); assert_eq!(batch.num_rows(), 10); - let actual = format!("{}", pretty_format_batches(&batches)?); + let actual = format!("{}", pretty_format_batches(&batches)?).to_lowercase(); let expected_asc = r#" +----------------------------------+--------------------------+ -| trace_id | MAX(traces.timestamp_ms) | +| trace_id | max(traces.timestamp_ms) | +----------------------------------+--------------------------+ | 5868861a23ed31355efc5200eb80fe74 | 16909009999999 | | 4040e64656804c3d77320d7a0e7eb1f0 | 16909009999998 | @@ -102,83 +98,6 @@ async fn aggregate( Ok(()) } -fn make_data( - partition_cnt: i32, - sample_cnt: i32, - asc: bool, -) -> Result<(Arc, Vec>), DataFusionError> { - use rand::Rng; - use rand::SeedableRng; - - // constants observed from trace data - let simultaneous_group_cnt = 2000; - let fitted_shape = 12f64; - let fitted_scale = 5f64; - let mean = 0.1; - let stddev = 1.1; - let pareto = Pareto::new(fitted_scale, fitted_shape).unwrap(); - let normal = Normal::new(mean, stddev).unwrap(); - let mut rng = rand::rngs::SmallRng::from_seed([0; 32]); - - // populate data - let schema = test_schema(); - let mut partitions = vec![]; - let mut cur_time = 16909000000000i64; - for _ in 0..partition_cnt { - let mut id_builder = StringBuilder::new(); - let mut ts_builder = Int64Builder::new(); - let gen_id = |rng: &mut rand::rngs::SmallRng| { - rng.gen::<[u8; 16]>() - .iter() - .map(|b| format!("{:02x}", b)) - .collect::() - }; - let gen_sample_cnt = - |mut rng: &mut rand::rngs::SmallRng| pareto.sample(&mut rng).ceil() as u32; - let mut group_ids = (0..simultaneous_group_cnt) - .map(|_| gen_id(&mut rng)) - .collect::>(); - let mut group_sample_cnts = (0..simultaneous_group_cnt) - .map(|_| gen_sample_cnt(&mut rng)) - .collect::>(); - for _ in 0..sample_cnt { - let random_index = rng.gen_range(0..simultaneous_group_cnt); - let trace_id = &mut group_ids[random_index]; - let sample_cnt = &mut group_sample_cnts[random_index]; - *sample_cnt -= 1; - if *sample_cnt == 0 { - *trace_id = gen_id(&mut rng); - *sample_cnt = gen_sample_cnt(&mut rng); - } - - id_builder.append_value(trace_id); - ts_builder.append_value(cur_time); - - if asc { - cur_time += 1; - } else { - let samp: f64 = normal.sample(&mut rng); - let samp = samp.round(); - cur_time += samp as i64; - } - } - - // convert to MemTable - let id_col = Arc::new(id_builder.finish()); - let ts_col = Arc::new(ts_builder.finish()); - let batch = RecordBatch::try_new(schema.clone(), vec![id_col, ts_col])?; - partitions.push(vec![batch]); - } - Ok((schema, partitions)) -} - -fn test_schema() -> SchemaRef { - Arc::new(Schema::new(vec![ - Field::new("trace_id", DataType::Utf8, false), - Field::new("timestamp_ms", DataType::Int64, false), - ])) -} - fn criterion_benchmark(c: &mut Criterion) { let limit = 10; let partitions = 10; diff --git a/datafusion/core/src/catalog/information_schema.rs b/datafusion/core/src/catalog/information_schema.rs index b30683a3ea13..3a8fef2d25ab 100644 --- a/datafusion/core/src/catalog/information_schema.rs +++ b/datafusion/core/src/catalog/information_schema.rs @@ -626,7 +626,8 @@ impl InformationSchemaDfSettings { fn new(config: InformationSchemaConfig) -> Self { let schema = Arc::new(Schema::new(vec![ Field::new("name", DataType::Utf8, false), - Field::new("setting", DataType::Utf8, true), + Field::new("value", DataType::Utf8, true), + Field::new("description", DataType::Utf8, true), ])); Self { schema, config } @@ -635,7 +636,8 @@ impl InformationSchemaDfSettings { fn builder(&self) -> InformationSchemaDfSettingsBuilder { InformationSchemaDfSettingsBuilder { names: StringBuilder::new(), - settings: StringBuilder::new(), + values: StringBuilder::new(), + descriptions: StringBuilder::new(), schema: self.schema.clone(), } } @@ -664,13 +666,15 @@ impl PartitionStream for InformationSchemaDfSettings { struct InformationSchemaDfSettingsBuilder { schema: SchemaRef, names: StringBuilder, - settings: StringBuilder, + values: StringBuilder, + descriptions: StringBuilder, } impl InformationSchemaDfSettingsBuilder { fn add_setting(&mut self, entry: ConfigEntry) { self.names.append_value(entry.key); - self.settings.append_option(entry.value); + self.values.append_option(entry.value); + self.descriptions.append_value(entry.description); } fn finish(&mut self) -> RecordBatch { @@ -678,7 +682,8 @@ impl InformationSchemaDfSettingsBuilder { self.schema.clone(), vec![ Arc::new(self.names.finish()), - Arc::new(self.settings.finish()), + Arc::new(self.values.finish()), + Arc::new(self.descriptions.finish()), ], ) .unwrap() diff --git a/datafusion/core/src/catalog/listing_schema.rs b/datafusion/core/src/catalog/listing_schema.rs index e7b4d8dec03c..c3c682689542 100644 --- a/datafusion/core/src/catalog/listing_schema.rs +++ b/datafusion/core/src/catalog/listing_schema.rs @@ -16,21 +16,25 @@ // under the License. //! listing_schema contains a SchemaProvider that scans ObjectStores for tables automatically + +use std::any::Any; +use std::collections::{HashMap, HashSet}; +use std::path::Path; +use std::sync::{Arc, Mutex}; + use crate::catalog::schema::SchemaProvider; use crate::datasource::provider::TableProviderFactory; use crate::datasource::TableProvider; use crate::execution::context::SessionState; -use async_trait::async_trait; + use datafusion_common::parsers::CompressionTypeVariant; -use datafusion_common::{DFSchema, DataFusionError, OwnedTableReference}; +use datafusion_common::{Constraints, DFSchema, DataFusionError, OwnedTableReference}; use datafusion_expr::CreateExternalTable; + +use async_trait::async_trait; use futures::TryStreamExt; use itertools::Itertools; use object_store::ObjectStore; -use std::any::Any; -use std::collections::{HashMap, HashSet}; -use std::path::Path; -use std::sync::{Arc, Mutex}; /// A [`SchemaProvider`] that scans an [`ObjectStore`] to automatically discover tables /// @@ -88,12 +92,7 @@ impl ListingSchemaProvider { /// Reload table information from ObjectStore pub async fn refresh(&self, state: &SessionState) -> datafusion_common::Result<()> { - let entries: Vec<_> = self - .store - .list(Some(&self.path)) - .await? - .try_collect() - .await?; + let entries: Vec<_> = self.store.list(Some(&self.path)).try_collect().await?; let base = Path::new(self.path.as_ref()); let mut tables = HashSet::new(); for file in entries.iter() { @@ -149,6 +148,8 @@ impl ListingSchemaProvider { order_exprs: vec![], unbounded: false, options: Default::default(), + constraints: Constraints::empty(), + column_defaults: Default::default(), }, ) .await?; diff --git a/datafusion/core/src/catalog/mod.rs b/datafusion/core/src/catalog/mod.rs index 6751edbd3a84..ce27d57da00d 100644 --- a/datafusion/core/src/catalog/mod.rs +++ b/datafusion/core/src/catalog/mod.rs @@ -31,7 +31,7 @@ use std::sync::Arc; /// Represent a list of named catalogs pub trait CatalogList: Sync + Send { - /// Returns the catalog list as [`Any`](std::any::Any) + /// Returns the catalog list as [`Any`] /// so that it can be downcast to a specific implementation. fn as_any(&self) -> &dyn Any; @@ -93,15 +93,9 @@ impl CatalogList for MemoryCatalogList { } } -impl Default for MemoryCatalogProvider { - fn default() -> Self { - Self::new() - } -} - /// Represents a catalog, comprising a number of named schemas. pub trait CatalogProvider: Sync + Send { - /// Returns the catalog provider as [`Any`](std::any::Any) + /// Returns the catalog provider as [`Any`] /// so that it can be downcast to a specific implementation. fn as_any(&self) -> &dyn Any; @@ -161,6 +155,12 @@ impl MemoryCatalogProvider { } } +impl Default for MemoryCatalogProvider { + fn default() -> Self { + Self::new() + } +} + impl CatalogProvider for MemoryCatalogProvider { fn as_any(&self) -> &dyn Any { self diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe/mod.rs similarity index 79% rename from datafusion/core/src/dataframe.rs rename to datafusion/core/src/dataframe/mod.rs index 640f57f3d5fc..f15f1e9ba6fb 100644 --- a/datafusion/core/src/dataframe.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -17,50 +17,49 @@ //! [`DataFrame`] API for building and executing query plans. +#[cfg(feature = "parquet")] +mod parquet; + use std::any::Any; use std::sync::Arc; +use crate::arrow::datatypes::{Schema, SchemaRef}; +use crate::arrow::record_batch::RecordBatch; +use crate::arrow::util::pretty; +use crate::datasource::{provider_as_source, MemTable, TableProvider}; +use crate::error::Result; +use crate::execution::{ + context::{SessionState, TaskContext}, + FunctionRegistry, +}; +use crate::logical_expr::utils::find_window_exprs; +use crate::logical_expr::{ + col, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, Partitioning, TableType, +}; +use crate::physical_plan::{ + collect, collect_partitioned, execute_stream, execute_stream_partitioned, + ExecutionPlan, SendableRecordBatchStream, +}; +use crate::prelude::SessionContext; + use arrow::array::{Array, ArrayRef, Int64Array, StringArray}; use arrow::compute::{cast, concat}; use arrow::csv::WriterBuilder; use arrow::datatypes::{DataType, Field}; -use async_trait::async_trait; use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::json_writer::JsonWriterOptions; -use datafusion_common::file_options::parquet_writer::{ - default_builder, ParquetWriterOptions, -}; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{ - DataFusionError, FileType, FileTypeWriterOptions, SchemaError, UnnestOptions, + Column, DFSchema, DataFusionError, FileType, FileTypeWriterOptions, ParamValues, + SchemaError, UnnestOptions, }; use datafusion_expr::dml::CopyOptions; -use parquet::file::properties::WriterProperties; - -use datafusion_common::{Column, DFSchema, ScalarValue}; use datafusion_expr::{ avg, count, is_null, max, median, min, stddev, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE, }; -use crate::arrow::datatypes::Schema; -use crate::arrow::datatypes::SchemaRef; -use crate::arrow::record_batch::RecordBatch; -use crate::arrow::util::pretty; -use crate::datasource::{provider_as_source, MemTable, TableProvider}; -use crate::error::Result; -use crate::execution::{ - context::{SessionState, TaskContext}, - FunctionRegistry, -}; -use crate::logical_expr::{ - col, utils::find_window_exprs, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, - Partitioning, TableType, -}; -use crate::physical_plan::SendableRecordBatchStream; -use crate::physical_plan::{collect, collect_partitioned}; -use crate::physical_plan::{execute_stream, execute_stream_partitioned, ExecutionPlan}; -use crate::prelude::SessionContext; +use async_trait::async_trait; /// Contains options that control how data is /// written out from a DataFrame @@ -582,12 +581,21 @@ impl DataFrame { Ok(DataFrame::new(self.session_state, plan)) } - /// Join this DataFrame with another DataFrame using the specified columns as join keys. + /// Join this `DataFrame` with another `DataFrame` using explicitly specified + /// columns and an optional filter expression. + /// + /// See [`join_on`](Self::join_on) for a more concise way to specify the + /// join condition. Since DataFusion will automatically identify and + /// optimize equality predicates there is no performance difference between + /// this function and `join_on` + /// + /// `left_cols` and `right_cols` are used to form "equijoin" predicates (see + /// example below), which are then combined with the optional `filter` + /// expression. /// - /// Filter expression expected to contain non-equality predicates that can not be pushed - /// down to any of join inputs. - /// In case of outer join, filter applied to only matched rows. + /// Note that in case of outer join, the `filter` is applied to only matched rows. /// + /// # Example /// ``` /// # use datafusion::prelude::*; /// # use datafusion::error::Result; @@ -600,11 +608,14 @@ impl DataFrame { /// col("a").alias("a2"), /// col("b").alias("b2"), /// col("c").alias("c2")])?; + /// // Perform the equivalent of `left INNER JOIN right ON (a = a2 AND b = b2)` + /// // finding all pairs of rows from `left` and `right` where `a = a2` and `b = b2`. /// let join = left.join(right, JoinType::Inner, &["a", "b"], &["a2", "b2"], None)?; /// let batches = join.collect().await?; /// # Ok(()) /// # } /// ``` + /// pub fn join( self, right: DataFrame, @@ -624,10 +635,13 @@ impl DataFrame { Ok(DataFrame::new(self.session_state, plan)) } - /// Join this DataFrame with another DataFrame using the specified expressions. + /// Join this `DataFrame` with another `DataFrame` using the specified + /// expressions. + /// + /// Note that DataFusion automatically optimizes joins, including + /// identifying and optimizing equality predicates. /// - /// Simply a thin wrapper over [`join`](Self::join) where the join keys are not provided, - /// and the provided expressions are AND'ed together to form the filter expression. + /// # Example /// /// ``` /// # use datafusion::prelude::*; @@ -646,6 +660,10 @@ impl DataFrame { /// col("b").alias("b2"), /// col("c").alias("c2"), /// ])?; + /// + /// // Perform the equivalent of `left INNER JOIN right ON (a != a2 AND b != b2)` + /// // finding all pairs of rows from `left` and `right` where + /// // where `a != a2` and `b != b2`. /// let join_on = left.join_on( /// right, /// JoinType::Inner, @@ -663,12 +681,7 @@ impl DataFrame { ) -> Result { let expr = on_exprs.into_iter().reduce(Expr::and); let plan = LogicalPlanBuilder::from(self.plan) - .join( - right.plan, - join_type, - (Vec::::new(), Vec::::new()), - expr, - )? + .join_on(right.plan, join_type, expr)? .build()?; Ok(DataFrame::new(self.session_state, plan)) } @@ -789,6 +802,7 @@ impl DataFrame { /// Executes this DataFrame and returns a stream over a single partition /// + /// # Example /// ``` /// # use datafusion::prelude::*; /// # use datafusion::error::Result; @@ -800,6 +814,11 @@ impl DataFrame { /// # Ok(()) /// # } /// ``` + /// + /// # Aborting Execution + /// + /// Dropping the stream will abort the execution of the query, and free up + /// any allocated resources pub async fn execute_stream(self) -> Result { let task_ctx = Arc::new(self.task_ctx()); let plan = self.create_physical_plan().await?; @@ -828,6 +847,7 @@ impl DataFrame { /// Executes this DataFrame and returns one stream per partition. /// + /// # Example /// ``` /// # use datafusion::prelude::*; /// # use datafusion::error::Result; @@ -839,6 +859,10 @@ impl DataFrame { /// # Ok(()) /// # } /// ``` + /// # Aborting Execution + /// + /// Dropping the stream will abort the execution of the query, and free up + /// any allocated resources pub async fn execute_stream_partitioned( self, ) -> Result> { @@ -1000,11 +1024,16 @@ impl DataFrame { )) } - /// Write this DataFrame to the referenced table + /// Write this DataFrame to the referenced table by name. /// This method uses on the same underlying implementation - /// as the SQL Insert Into statement. - /// Unlike most other DataFrame methods, this method executes - /// eagerly, writing data, and returning the count of rows written. + /// as the SQL Insert Into statement. Unlike most other DataFrame methods, + /// this method executes eagerly. Data is written to the table using an + /// execution plan returned by the [TableProvider]'s insert_into method. + /// Refer to the documentation of the specific [TableProvider] to determine + /// the expected data returned by the insert_into plan via this method. + /// For the built in ListingTable provider, a single [RecordBatch] containing + /// a single column and row representing the count of total rows written + /// is returned. pub async fn write_table( self, table_name: &str, @@ -1053,40 +1082,6 @@ impl DataFrame { DataFrame::new(self.session_state, plan).collect().await } - /// Write a `DataFrame` to a Parquet file. - pub async fn write_parquet( - self, - path: &str, - options: DataFrameWriteOptions, - writer_properties: Option, - ) -> Result, DataFusionError> { - if options.overwrite { - return Err(DataFusionError::NotImplemented( - "Overwrites are not implemented for DataFrame::write_parquet.".to_owned(), - )); - } - match options.compression{ - CompressionTypeVariant::UNCOMPRESSED => (), - _ => return Err(DataFusionError::Configuration("DataFrame::write_parquet method does not support compression set via DataFrameWriteOptions. Set parquet compression via writer_properties instead.".to_owned())) - } - let props = match writer_properties { - Some(props) => props, - None => default_builder(self.session_state.config_options())?.build(), - }; - let file_type_writer_options = - FileTypeWriterOptions::Parquet(ParquetWriterOptions::new(props)); - let copy_options = CopyOptions::WriterOptions(Box::new(file_type_writer_options)); - let plan = LogicalPlanBuilder::copy_to( - self.plan, - path.into(), - FileType::PARQUET, - options.single_file_output, - copy_options, - )? - .build()?; - DataFrame::new(self.session_state, plan).collect().await - } - /// Executes a query and writes the results to a partitioned JSON file. pub async fn write_json( self, @@ -1144,10 +1139,7 @@ impl DataFrame { col_exists = true; new_column.clone() } else { - Expr::Column(Column { - relation: None, - name: f.name().into(), - }) + col(f.qualified_column()) } }) .collect(); @@ -1194,7 +1186,7 @@ impl DataFrame { let field_to_rename = match self.plan.schema().field_from_column(&old_column) { Ok(field) => field, // no-op if field not found - Err(DataFusionError::SchemaError(SchemaError::FieldNotFound { .. })) => { + Err(DataFusionError::SchemaError(SchemaError::FieldNotFound { .. }, _)) => { return Ok(self) } Err(err) => return Err(err), @@ -1218,9 +1210,65 @@ impl DataFrame { Ok(DataFrame::new(self.session_state, project_plan)) } - /// Convert a prepare logical plan into its inner logical plan with all params replaced with their corresponding values - pub fn with_param_values(self, param_values: Vec) -> Result { - let plan = self.plan.with_param_values(param_values)?; + /// Replace all parameters in logical plan with the specified + /// values, in preparation for execution. + /// + /// # Example + /// + /// ``` + /// use datafusion::prelude::*; + /// # use datafusion::{error::Result, assert_batches_eq}; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// # use datafusion_common::ScalarValue; + /// let mut ctx = SessionContext::new(); + /// # ctx.register_csv("example", "tests/data/example.csv", CsvReadOptions::new()).await?; + /// let results = ctx + /// .sql("SELECT a FROM example WHERE b = $1") + /// .await? + /// // replace $1 with value 2 + /// .with_param_values(vec![ + /// // value at index 0 --> $1 + /// ScalarValue::from(2i64) + /// ])? + /// .collect() + /// .await?; + /// assert_batches_eq!( + /// &[ + /// "+---+", + /// "| a |", + /// "+---+", + /// "| 1 |", + /// "+---+", + /// ], + /// &results + /// ); + /// // Note you can also provide named parameters + /// let results = ctx + /// .sql("SELECT a FROM example WHERE b = $my_param") + /// .await? + /// // replace $my_param with value 2 + /// // Note you can also use a HashMap as well + /// .with_param_values(vec![ + /// ("my_param", ScalarValue::from(2i64)) + /// ])? + /// .collect() + /// .await?; + /// assert_batches_eq!( + /// &[ + /// "+---+", + /// "| a |", + /// "+---+", + /// "| 1 |", + /// "+---+", + /// ], + /// &results + /// ); + /// # Ok(()) + /// # } + /// ``` + pub fn with_param_values(self, query_values: impl Into) -> Result { + let plan = self.plan.with_param_values(query_values)?; Ok(Self::new(self.session_state, plan)) } @@ -1238,12 +1286,13 @@ impl DataFrame { /// # } /// ``` pub async fn cache(self) -> Result { - let context = SessionContext::with_state(self.session_state.clone()); - let mem_table = MemTable::try_new( - SchemaRef::from(self.schema().clone()), - self.collect_partitioned().await?, - )?; - + let context = SessionContext::new_with_state(self.session_state.clone()); + // The schema is consistent with the output + let plan = self.clone().create_physical_plan().await?; + let schema = plan.schema(); + let task_ctx = Arc::new(self.task_ctx()); + let partitions = collect_partitioned(plan, task_ctx).await?; + let mem_table = MemTable::try_new(schema, partitions)?; context.read_table(Arc::new(mem_table)) } } @@ -1310,31 +1359,144 @@ impl TableProvider for DataFrameTableProvider { mod tests { use std::vec; - use arrow::array::Int32Array; - use arrow::datatypes::DataType; + use super::*; + use crate::execution::context::SessionConfig; + use crate::physical_plan::{ColumnarValue, Partitioning, PhysicalExpr}; + use crate::test_util::{register_aggregate_csv, test_table, test_table_with_name}; + use crate::{assert_batches_sorted_eq, execution::context::SessionContext}; + use arrow::array::{self, Int32Array}; + use arrow::datatypes::DataType; + use datafusion_common::{Constraint, Constraints}; use datafusion_expr::{ avg, cast, count, count_distinct, create_udf, expr, lit, max, min, sum, BuiltInWindowFunction, ScalarFunctionImplementation, Volatility, WindowFrame, - WindowFunction, + WindowFunctionDefinition, }; use datafusion_physical_expr::expressions::Column; - use object_store::local::LocalFileSystem; - use parquet::basic::{BrotliLevel, GzipLevel, ZstdLevel}; - use parquet::file::reader::FileReader; - use tempfile::TempDir; - use url::Url; + use datafusion_physical_plan::get_plan_string; - use crate::execution::context::SessionConfig; - use crate::execution::options::{CsvReadOptions, ParquetReadOptions}; - use crate::physical_plan::ColumnarValue; - use crate::physical_plan::Partitioning; - use crate::physical_plan::PhysicalExpr; - use crate::test_util; - use crate::test_util::parquet_test_data; - use crate::{assert_batches_sorted_eq, execution::context::SessionContext}; + // Get string representation of the plan + async fn assert_physical_plan(df: &DataFrame, expected: Vec<&str>) { + let physical_plan = df + .clone() + .create_physical_plan() + .await + .expect("Error creating physical plan"); - use super::*; + let actual = get_plan_string(&physical_plan); + assert_eq!( + expected, actual, + "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + } + + pub fn table_with_constraints() -> Arc { + let dual_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, false), + ])); + let batch = RecordBatch::try_new( + dual_schema.clone(), + vec![ + Arc::new(array::Int32Array::from(vec![1])), + Arc::new(array::StringArray::from(vec!["a"])), + ], + ) + .unwrap(); + let provider = MemTable::try_new(dual_schema, vec![vec![batch]]) + .unwrap() + .with_constraints(Constraints::new_unverified(vec![Constraint::PrimaryKey( + vec![0], + )])); + Arc::new(provider) + } + + async fn assert_logical_expr_schema_eq_physical_expr_schema( + df: DataFrame, + ) -> Result<()> { + let logical_expr_dfschema = df.schema(); + let logical_expr_schema = SchemaRef::from(logical_expr_dfschema.to_owned()); + let batches = df.collect().await?; + let physical_expr_schema = batches[0].schema(); + assert_eq!(logical_expr_schema, physical_expr_schema); + Ok(()) + } + + #[tokio::test] + async fn test_array_agg_ord_schema() -> Result<()> { + let ctx = SessionContext::new(); + + let create_table_query = r#" + CREATE TABLE test_table ( + "double_field" DOUBLE, + "string_field" VARCHAR + ) AS VALUES + (1.0, 'a'), + (2.0, 'b'), + (3.0, 'c') + "#; + ctx.sql(create_table_query).await?; + + let query = r#"SELECT + array_agg("double_field" ORDER BY "string_field") as "double_field", + array_agg("string_field" ORDER BY "string_field") as "string_field" + FROM test_table"#; + + let result = ctx.sql(query).await?; + assert_logical_expr_schema_eq_physical_expr_schema(result).await?; + Ok(()) + } + + #[tokio::test] + async fn test_array_agg_schema() -> Result<()> { + let ctx = SessionContext::new(); + + let create_table_query = r#" + CREATE TABLE test_table ( + "double_field" DOUBLE, + "string_field" VARCHAR + ) AS VALUES + (1.0, 'a'), + (2.0, 'b'), + (3.0, 'c') + "#; + ctx.sql(create_table_query).await?; + + let query = r#"SELECT + array_agg("double_field") as "double_field", + array_agg("string_field") as "string_field" + FROM test_table"#; + + let result = ctx.sql(query).await?; + assert_logical_expr_schema_eq_physical_expr_schema(result).await?; + Ok(()) + } + + #[tokio::test] + async fn test_array_agg_distinct_schema() -> Result<()> { + let ctx = SessionContext::new(); + + let create_table_query = r#" + CREATE TABLE test_table ( + "double_field" DOUBLE, + "string_field" VARCHAR + ) AS VALUES + (1.0, 'a'), + (2.0, 'b'), + (2.0, 'a') + "#; + ctx.sql(create_table_query).await?; + + let query = r#"SELECT + array_agg(distinct "double_field") as "double_field", + array_agg(distinct "string_field") as "string_field" + FROM test_table"#; + + let result = ctx.sql(query).await?; + assert_logical_expr_schema_eq_physical_expr_schema(result).await?; + Ok(()) + } #[tokio::test] async fn select_columns() -> Result<()> { @@ -1374,7 +1536,9 @@ mod tests { // build plan using Table API let t = test_table().await?; let first_row = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::FirstValue), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::FirstValue, + ), vec![col("aggregate_test_100.c1")], vec![col("aggregate_test_100.c2")], vec![], @@ -1445,6 +1609,223 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_aggregate_with_pk() -> Result<()> { + // create the dataframe + let config = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::new_with_config(config); + + let df = ctx.read_table(table_with_constraints())?; + + // GROUP BY id + let group_expr = vec![col("id")]; + let aggr_expr = vec![]; + let df = df.aggregate(group_expr, aggr_expr)?; + + // Since id and name are functionally dependant, we can use name among + // expression even if it is not part of the group by expression and can + // select "name" column even though it wasn't explicitly grouped + let df = df.select(vec![col("id"), col("name")])?; + assert_physical_plan( + &df, + vec![ + "AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[]", + " MemoryExec: partitions=1, partition_sizes=[1]", + ], + ) + .await; + + let df_results = df.collect().await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!([ + "+----+------+", + "| id | name |", + "+----+------+", + "| 1 | a |", + "+----+------+" + ], + &df_results + ); + + Ok(()) + } + + #[tokio::test] + async fn test_aggregate_with_pk2() -> Result<()> { + // create the dataframe + let config = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::new_with_config(config); + + let df = ctx.read_table(table_with_constraints())?; + + // GROUP BY id + let group_expr = vec![col("id")]; + let aggr_expr = vec![]; + let df = df.aggregate(group_expr, aggr_expr)?; + + // Predicate refers to id, and name fields: + // id = 1 AND name = 'a' + let predicate = col("id").eq(lit(1i32)).and(col("name").eq(lit("a"))); + let df = df.filter(predicate)?; + assert_physical_plan( + &df, + vec![ + "CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: id@0 = 1 AND name@1 = a", + " AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[]", + " MemoryExec: partitions=1, partition_sizes=[1]", + ], + ) + .await; + + // Since id and name are functionally dependant, we can use name among expression + // even if it is not part of the group by expression. + let df_results = df.collect().await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!( + ["+----+------+", + "| id | name |", + "+----+------+", + "| 1 | a |", + "+----+------+",], + &df_results + ); + + Ok(()) + } + + #[tokio::test] + async fn test_aggregate_with_pk3() -> Result<()> { + // create the dataframe + let config = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::new_with_config(config); + + let df = ctx.read_table(table_with_constraints())?; + + // GROUP BY id + let group_expr = vec![col("id")]; + let aggr_expr = vec![]; + // group by id, + let df = df.aggregate(group_expr, aggr_expr)?; + + // Predicate refers to id field + // id = 1 + let predicate = col("id").eq(lit(1i32)); + let df = df.filter(predicate)?; + // Select expression refers to id, and name columns. + // id, name + let df = df.select(vec![col("id"), col("name")])?; + assert_physical_plan( + &df, + vec![ + "CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: id@0 = 1", + " AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[]", + " MemoryExec: partitions=1, partition_sizes=[1]", + ], + ) + .await; + + // Since id and name are functionally dependant, we can use name among expression + // even if it is not part of the group by expression. + let df_results = df.collect().await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!( + ["+----+------+", + "| id | name |", + "+----+------+", + "| 1 | a |", + "+----+------+",], + &df_results + ); + + Ok(()) + } + + #[tokio::test] + async fn test_aggregate_with_pk4() -> Result<()> { + // create the dataframe + let config = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::new_with_config(config); + + let df = ctx.read_table(table_with_constraints())?; + + // GROUP BY id + let group_expr = vec![col("id")]; + let aggr_expr = vec![]; + let df = df.aggregate(group_expr, aggr_expr)?; + + // Predicate refers to id field + // id = 1 + let predicate = col("id").eq(lit(1i32)); + let df = df.filter(predicate)?; + // Select expression refers to id column. + // id + let df = df.select(vec![col("id")])?; + + // In this case aggregate shouldn't be expanded, since these + // columns are not used. + assert_physical_plan( + &df, + vec![ + "CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: id@0 = 1", + " AggregateExec: mode=Single, gby=[id@0 as id], aggr=[]", + " MemoryExec: partitions=1, partition_sizes=[1]", + ], + ) + .await; + + let df_results = df.collect().await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!([ + "+----+", + "| id |", + "+----+", + "| 1 |", + "+----+",], + &df_results + ); + + Ok(()) + } + + #[tokio::test] + async fn test_aggregate_alias() -> Result<()> { + let df = test_table().await?; + + let df = df + // GROUP BY `c2 + 1` + .aggregate(vec![col("c2") + lit(1)], vec![])? + // SELECT `c2 + 1` as c2 + .select(vec![(col("c2") + lit(1)).alias("c2")])? + // GROUP BY c2 as "c2" (alias in expr is not supported by SQL) + .aggregate(vec![col("c2").alias("c2")], vec![])?; + + let df_results = df.collect().await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!([ + "+----+", + "| c2 |", + "+----+", + "| 2 |", + "| 3 |", + "| 4 |", + "| 5 |", + "| 6 |", + "+----+", + ], + &df_results + ); + + Ok(()) + } + #[tokio::test] async fn test_distinct() -> Result<()> { let t = test_table().await?; @@ -1752,31 +2133,6 @@ mod tests { Ok(ctx.sql(sql).await?.into_unoptimized_plan()) } - async fn test_table_with_name(name: &str) -> Result { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx, name).await?; - ctx.table(name).await - } - - async fn test_table() -> Result { - test_table_with_name("aggregate_test_100").await - } - - async fn register_aggregate_csv( - ctx: &mut SessionContext, - table_name: &str, - ) -> Result<()> { - let schema = test_util::aggr_test_schema(); - let testdata = test_util::arrow_test_data(); - ctx.register_csv( - table_name, - &format!("{testdata}/csv/aggregate_test_100.csv"), - CsvReadOptions::new().schema(schema.as_ref()), - ) - .await?; - Ok(()) - } - #[tokio::test] async fn with_column() -> Result<()> { let df = test_table().await?.select_columns(&["c1", "c2", "c3"])?; @@ -1855,6 +2211,131 @@ mod tests { Ok(()) } + // Test issue: https://github.com/apache/arrow-datafusion/issues/7790 + // The join operation outputs two identical column names, but they belong to different relations. + #[tokio::test] + async fn with_column_join_same_columns() -> Result<()> { + let df = test_table().await?.select_columns(&["c1"])?; + let ctx = SessionContext::new(); + + let table = df.into_view(); + ctx.register_table("t1", table.clone())?; + ctx.register_table("t2", table)?; + let df = ctx + .table("t1") + .await? + .join( + ctx.table("t2").await?, + JoinType::Inner, + &["c1"], + &["c1"], + None, + )? + .sort(vec![ + // make the test deterministic + col("t1.c1").sort(true, true), + ])? + .limit(0, Some(1))?; + + let df_results = df.clone().collect().await?; + assert_batches_sorted_eq!( + [ + "+----+----+", + "| c1 | c1 |", + "+----+----+", + "| a | a |", + "+----+----+", + ], + &df_results + ); + + let df_with_column = df.clone().with_column("new_column", lit(true))?; + + assert_eq!( + "\ + Projection: t1.c1, t2.c1, Boolean(true) AS new_column\ + \n Limit: skip=0, fetch=1\ + \n Sort: t1.c1 ASC NULLS FIRST\ + \n Inner Join: t1.c1 = t2.c1\ + \n TableScan: t1\ + \n TableScan: t2", + format!("{:?}", df_with_column.logical_plan()) + ); + + assert_eq!( + "\ + Projection: t1.c1, t2.c1, Boolean(true) AS new_column\ + \n Limit: skip=0, fetch=1\ + \n Sort: t1.c1 ASC NULLS FIRST, fetch=1\ + \n Inner Join: t1.c1 = t2.c1\ + \n SubqueryAlias: t1\ + \n TableScan: aggregate_test_100 projection=[c1]\ + \n SubqueryAlias: t2\ + \n TableScan: aggregate_test_100 projection=[c1]", + format!("{:?}", df_with_column.clone().into_optimized_plan()?) + ); + + let df_results = df_with_column.collect().await?; + + assert_batches_sorted_eq!( + [ + "+----+----+------------+", + "| c1 | c1 | new_column |", + "+----+----+------------+", + "| a | a | true |", + "+----+----+------------+", + ], + &df_results + ); + Ok(()) + } + + // Table 't1' self join + // Supplementary test of issue: https://github.com/apache/arrow-datafusion/issues/7790 + #[tokio::test] + async fn with_column_self_join() -> Result<()> { + let df = test_table().await?.select_columns(&["c1"])?; + let ctx = SessionContext::new(); + + ctx.register_table("t1", df.into_view())?; + + let df = ctx + .table("t1") + .await? + .join( + ctx.table("t1").await?, + JoinType::Inner, + &["c1"], + &["c1"], + None, + )? + .sort(vec![ + // make the test deterministic + col("t1.c1").sort(true, true), + ])? + .limit(0, Some(1))?; + + let df_results = df.clone().collect().await?; + assert_batches_sorted_eq!( + [ + "+----+----+", + "| c1 | c1 |", + "+----+----+", + "| a | a |", + "+----+----+", + ], + &df_results + ); + + let actual_err = df.clone().with_column("new_column", lit(true)).unwrap_err(); + let expected_err = "Error during planning: Projections require unique expression names \ + but the expression \"t1.c1\" at position 0 and \"t1.c1\" at position 1 have the same name. \ + Consider aliasing (\"AS\") one of them."; + assert_eq!(actual_err.strip_backtrace(), expected_err); + + Ok(()) + } + #[tokio::test] async fn with_column_renamed() -> Result<()> { let df = test_table() @@ -2011,7 +2492,7 @@ mod tests { "datafusion.sql_parser.enable_ident_normalization".to_owned(), "false".to_owned(), )]))?; - let mut ctx = SessionContext::with_config(config); + let mut ctx = SessionContext::new_with_config(config); let name = "aggregate_test_100"; register_aggregate_csv(&mut ctx, name).await?; let df = ctx.table(name); @@ -2056,33 +2537,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn filter_pushdown_dataframe() -> Result<()> { - let ctx = SessionContext::new(); - - ctx.register_parquet( - "test", - &format!("{}/alltypes_plain.snappy.parquet", parquet_test_data()), - ParquetReadOptions::default(), - ) - .await?; - - ctx.register_table("t1", ctx.table("test").await?.into_view())?; - - let df = ctx - .table("t1") - .await? - .filter(col("id").eq(lit(1)))? - .select_columns(&["bool_col", "int_col"])?; - - let plan = df.explain(false, false)?.collect().await?; - // Filters all the way to Parquet - let formatted = pretty::pretty_format_batches(&plan)?.to_string(); - assert!(formatted.contains("FilterExec: id@0 = 1")); - - Ok(()) - } - #[tokio::test] async fn cast_expr_test() -> Result<()> { let df = test_table() @@ -2174,6 +2628,17 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_cache_mismatch() -> Result<()> { + let ctx = SessionContext::new(); + let df = ctx + .sql("SELECT CASE WHEN true THEN NULL ELSE 1 END") + .await?; + let cache_df = df.cache().await; + assert!(cache_df.is_ok()); + Ok(()) + } + #[tokio::test] async fn cache_test() -> Result<()> { let df = test_table() @@ -2367,53 +2832,4 @@ mod tests { Ok(()) } - - #[tokio::test] - async fn write_parquet_with_compression() -> Result<()> { - let test_df = test_table().await?; - - let output_path = "file://local/test.parquet"; - let test_compressions = vec![ - parquet::basic::Compression::SNAPPY, - parquet::basic::Compression::LZ4, - parquet::basic::Compression::LZ4_RAW, - parquet::basic::Compression::GZIP(GzipLevel::default()), - parquet::basic::Compression::BROTLI(BrotliLevel::default()), - parquet::basic::Compression::ZSTD(ZstdLevel::default()), - ]; - for compression in test_compressions.into_iter() { - let df = test_df.clone(); - let tmp_dir = TempDir::new()?; - let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?); - let local_url = Url::parse("file://local").unwrap(); - let ctx = &test_df.session_state; - ctx.runtime_env().register_object_store(&local_url, local); - df.write_parquet( - output_path, - DataFrameWriteOptions::new().with_single_file_output(true), - Some( - WriterProperties::builder() - .set_compression(compression) - .build(), - ), - ) - .await?; - - // Check that file actually used the specified compression - let file = std::fs::File::open(tmp_dir.into_path().join("test.parquet"))?; - - let reader = - parquet::file::serialized_reader::SerializedFileReader::new(file) - .unwrap(); - - let parquet_metadata = reader.metadata(); - - let written_compression = - parquet_metadata.row_group(0).column(0).compression(); - - assert_eq!(written_compression, compression); - } - - Ok(()) - } } diff --git a/datafusion/core/src/dataframe/parquet.rs b/datafusion/core/src/dataframe/parquet.rs new file mode 100644 index 000000000000..36ef90c987e3 --- /dev/null +++ b/datafusion/core/src/dataframe/parquet.rs @@ -0,0 +1,162 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::file_options::parquet_writer::{ + default_builder, ParquetWriterOptions, +}; +use parquet::file::properties::WriterProperties; + +use super::{ + CompressionTypeVariant, CopyOptions, DataFrame, DataFrameWriteOptions, + DataFusionError, FileType, FileTypeWriterOptions, LogicalPlanBuilder, RecordBatch, +}; + +impl DataFrame { + /// Write a `DataFrame` to a Parquet file. + pub async fn write_parquet( + self, + path: &str, + options: DataFrameWriteOptions, + writer_properties: Option, + ) -> Result, DataFusionError> { + if options.overwrite { + return Err(DataFusionError::NotImplemented( + "Overwrites are not implemented for DataFrame::write_parquet.".to_owned(), + )); + } + match options.compression{ + CompressionTypeVariant::UNCOMPRESSED => (), + _ => return Err(DataFusionError::Configuration("DataFrame::write_parquet method does not support compression set via DataFrameWriteOptions. Set parquet compression via writer_properties instead.".to_owned())) + } + let props = match writer_properties { + Some(props) => props, + None => default_builder(self.session_state.config_options())?.build(), + }; + let file_type_writer_options = + FileTypeWriterOptions::Parquet(ParquetWriterOptions::new(props)); + let copy_options = CopyOptions::WriterOptions(Box::new(file_type_writer_options)); + let plan = LogicalPlanBuilder::copy_to( + self.plan, + path.into(), + FileType::PARQUET, + options.single_file_output, + copy_options, + )? + .build()?; + DataFrame::new(self.session_state, plan).collect().await + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use object_store::local::LocalFileSystem; + use parquet::basic::{BrotliLevel, GzipLevel, ZstdLevel}; + use parquet::file::reader::FileReader; + use tempfile::TempDir; + use url::Url; + + use datafusion_expr::{col, lit}; + + use crate::arrow::util::pretty; + use crate::execution::context::SessionContext; + use crate::execution::options::ParquetReadOptions; + use crate::test_util; + + use super::super::Result; + use super::*; + + #[tokio::test] + async fn filter_pushdown_dataframe() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_parquet( + "test", + &format!( + "{}/alltypes_plain.snappy.parquet", + test_util::parquet_test_data() + ), + ParquetReadOptions::default(), + ) + .await?; + + ctx.register_table("t1", ctx.table("test").await?.into_view())?; + + let df = ctx + .table("t1") + .await? + .filter(col("id").eq(lit(1)))? + .select_columns(&["bool_col", "int_col"])?; + + let plan = df.explain(false, false)?.collect().await?; + // Filters all the way to Parquet + let formatted = pretty::pretty_format_batches(&plan)?.to_string(); + assert!(formatted.contains("FilterExec: id@0 = 1")); + + Ok(()) + } + + #[tokio::test] + async fn write_parquet_with_compression() -> Result<()> { + let test_df = test_util::test_table().await?; + + let output_path = "file://local/test.parquet"; + let test_compressions = vec![ + parquet::basic::Compression::SNAPPY, + parquet::basic::Compression::LZ4, + parquet::basic::Compression::LZ4_RAW, + parquet::basic::Compression::GZIP(GzipLevel::default()), + parquet::basic::Compression::BROTLI(BrotliLevel::default()), + parquet::basic::Compression::ZSTD(ZstdLevel::default()), + ]; + for compression in test_compressions.into_iter() { + let df = test_df.clone(); + let tmp_dir = TempDir::new()?; + let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?); + let local_url = Url::parse("file://local").unwrap(); + let ctx = &test_df.session_state; + ctx.runtime_env().register_object_store(&local_url, local); + df.write_parquet( + output_path, + DataFrameWriteOptions::new().with_single_file_output(true), + Some( + WriterProperties::builder() + .set_compression(compression) + .build(), + ), + ) + .await?; + + // Check that file actually used the specified compression + let file = std::fs::File::open(tmp_dir.into_path().join("test.parquet"))?; + + let reader = + parquet::file::serialized_reader::SerializedFileReader::new(file) + .unwrap(); + + let parquet_metadata = reader.metadata(); + + let written_compression = + parquet_metadata.row_group(0).column(0).compression(); + + assert_eq!(written_compression, compression); + } + + Ok(()) + } +} diff --git a/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs b/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs index f983e26d48a4..a16c1ae3333f 100644 --- a/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs +++ b/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs @@ -45,6 +45,7 @@ use arrow::array::{BinaryArray, FixedSizeBinaryArray, GenericListArray}; use arrow::datatypes::{Fields, SchemaRef}; use arrow::error::ArrowError::SchemaError; use arrow::error::Result as ArrowResult; +use datafusion_common::arrow_err; use num_traits::NumCast; use std::collections::BTreeMap; use std::io::Read; @@ -82,38 +83,62 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { fields, mut lookup, .. }) => { for field in fields { - Self::child_schema_lookup(&field.schema, &mut lookup)?; + Self::child_schema_lookup(&field.name, &field.schema, &mut lookup)?; } Ok(lookup) } - _ => Err(DataFusionError::ArrowError(SchemaError( + _ => arrow_err!(SchemaError( "expected avro schema to be a record".to_string(), - ))), + )), } } fn child_schema_lookup<'b>( + parent_field_name: &str, schema: &AvroSchema, schema_lookup: &'b mut BTreeMap, ) -> Result<&'b BTreeMap> { match schema { - AvroSchema::Record(RecordSchema { - name, - fields, - lookup, - .. - }) => { + AvroSchema::Union(us) => { + let has_nullable = us + .find_schema_with_known_schemata::( + &Value::Null, + None, + &None, + ) + .is_some(); + let sub_schemas = us.variants(); + if has_nullable && sub_schemas.len() == 2 { + if let Some(sub_schema) = + sub_schemas.iter().find(|&s| !matches!(s, AvroSchema::Null)) + { + Self::child_schema_lookup( + parent_field_name, + sub_schema, + schema_lookup, + )?; + } + } + } + AvroSchema::Record(RecordSchema { fields, lookup, .. }) => { lookup.iter().for_each(|(field_name, pos)| { schema_lookup - .insert(format!("{}.{}", name.fullname(None), field_name), *pos); + .insert(format!("{}.{}", parent_field_name, field_name), *pos); }); for field in fields { - Self::child_schema_lookup(&field.schema, schema_lookup)?; + let sub_parent_field_name = + format!("{}.{}", parent_field_name, field.name); + Self::child_schema_lookup( + &sub_parent_field_name, + &field.schema, + schema_lookup, + )?; } } AvroSchema::Array(schema) => { - Self::child_schema_lookup(schema, schema_lookup)?; + let sub_parent_field_name = format!("{}.element", parent_field_name); + Self::child_schema_lookup(&sub_parent_field_name, schema, schema_lookup)?; } _ => (), } @@ -147,7 +172,8 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { let rows = rows.iter().collect::>>(); let projection = self.projection.clone().unwrap_or_default(); - let arrays = self.build_struct_array(&rows, self.schema.fields(), &projection); + let arrays = + self.build_struct_array(&rows, "", self.schema.fields(), &projection); let projected_fields = if projection.is_empty() { self.schema.fields().clone() } else { @@ -305,6 +331,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { for row in rows { if let Some(value) = self.field_lookup(col_name, row) { + let value = maybe_resolve_union(value); // value can be an array or a scalar let vals: Vec> = if let Value::String(v) = value { vec![Some(v.to_string())] @@ -444,6 +471,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { /// Build a nested GenericListArray from a list of unnested `Value`s fn build_nested_list_array( &self, + parent_field_name: &str, rows: &[&Value], list_field: &Field, ) -> ArrowResult { @@ -530,13 +558,19 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { .collect::() .into_data(), DataType::List(field) => { - let child = - self.build_nested_list_array::(&flatten_values(rows), field)?; + let child = self.build_nested_list_array::( + parent_field_name, + &flatten_values(rows), + field, + )?; child.to_data() } DataType::LargeList(field) => { - let child = - self.build_nested_list_array::(&flatten_values(rows), field)?; + let child = self.build_nested_list_array::( + parent_field_name, + &flatten_values(rows), + field, + )?; child.to_data() } DataType::Struct(fields) => { @@ -554,16 +588,22 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { let null_struct_array = vec![("null".to_string(), Value::Null)]; let rows: Vec<&Vec<(String, Value)>> = rows .iter() + .map(|v| maybe_resolve_union(v)) .flat_map(|row| { if let Value::Array(values) = row { - values.iter().for_each(|_| { - bit_util::set_bit(&mut null_buffer, struct_index); - struct_index += 1; - }); values .iter() + .map(maybe_resolve_union) .map(|v| match v { - Value::Record(record) => record, + Value::Record(record) => { + bit_util::set_bit(&mut null_buffer, struct_index); + struct_index += 1; + record + } + Value::Null => { + struct_index += 1; + &null_struct_array + } other => panic!("expected Record, got {other:?}"), }) .collect::>>() @@ -573,7 +613,11 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { } }) .collect(); - let arrays = self.build_struct_array(&rows, fields, &[])?; + + let sub_parent_field_name = + format!("{}.{}", parent_field_name, list_field.name()); + let arrays = + self.build_struct_array(&rows, &sub_parent_field_name, fields, &[])?; let data_type = DataType::Struct(fields.clone()); ArrayDataBuilder::new(data_type) .len(rows.len()) @@ -610,6 +654,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { fn build_struct_array( &self, rows: RecordSlice, + parent_field_name: &str, struct_fields: &Fields, projection: &[String], ) -> ArrowResult> { @@ -617,78 +662,83 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { .iter() .filter(|field| projection.is_empty() || projection.contains(field.name())) .map(|field| { + let field_path = if parent_field_name.is_empty() { + field.name().to_string() + } else { + format!("{}.{}", parent_field_name, field.name()) + }; let arr = match field.data_type() { DataType::Null => Arc::new(NullArray::new(rows.len())) as ArrayRef, - DataType::Boolean => self.build_boolean_array(rows, field.name()), + DataType::Boolean => self.build_boolean_array(rows, &field_path), DataType::Float64 => { - self.build_primitive_array::(rows, field.name()) + self.build_primitive_array::(rows, &field_path) } DataType::Float32 => { - self.build_primitive_array::(rows, field.name()) + self.build_primitive_array::(rows, &field_path) } DataType::Int64 => { - self.build_primitive_array::(rows, field.name()) + self.build_primitive_array::(rows, &field_path) } DataType::Int32 => { - self.build_primitive_array::(rows, field.name()) + self.build_primitive_array::(rows, &field_path) } DataType::Int16 => { - self.build_primitive_array::(rows, field.name()) + self.build_primitive_array::(rows, &field_path) } DataType::Int8 => { - self.build_primitive_array::(rows, field.name()) + self.build_primitive_array::(rows, &field_path) } DataType::UInt64 => { - self.build_primitive_array::(rows, field.name()) + self.build_primitive_array::(rows, &field_path) } DataType::UInt32 => { - self.build_primitive_array::(rows, field.name()) + self.build_primitive_array::(rows, &field_path) } DataType::UInt16 => { - self.build_primitive_array::(rows, field.name()) + self.build_primitive_array::(rows, &field_path) } DataType::UInt8 => { - self.build_primitive_array::(rows, field.name()) + self.build_primitive_array::(rows, &field_path) } // TODO: this is incomplete DataType::Timestamp(unit, _) => match unit { TimeUnit::Second => self .build_primitive_array::( rows, - field.name(), + &field_path, ), TimeUnit::Microsecond => self .build_primitive_array::( rows, - field.name(), + &field_path, ), TimeUnit::Millisecond => self .build_primitive_array::( rows, - field.name(), + &field_path, ), TimeUnit::Nanosecond => self .build_primitive_array::( rows, - field.name(), + &field_path, ), }, DataType::Date64 => { - self.build_primitive_array::(rows, field.name()) + self.build_primitive_array::(rows, &field_path) } DataType::Date32 => { - self.build_primitive_array::(rows, field.name()) + self.build_primitive_array::(rows, &field_path) } DataType::Time64(unit) => match unit { TimeUnit::Microsecond => self .build_primitive_array::( rows, - field.name(), + &field_path, ), TimeUnit::Nanosecond => self .build_primitive_array::( rows, - field.name(), + &field_path, ), t => { return Err(ArrowError::SchemaError(format!( @@ -698,14 +748,11 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { }, DataType::Time32(unit) => match unit { TimeUnit::Second => self - .build_primitive_array::( - rows, - field.name(), - ), + .build_primitive_array::(rows, &field_path), TimeUnit::Millisecond => self .build_primitive_array::( rows, - field.name(), + &field_path, ), t => { return Err(ArrowError::SchemaError(format!( @@ -716,7 +763,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { DataType::Utf8 | DataType::LargeUtf8 => Arc::new( rows.iter() .map(|row| { - let maybe_value = self.field_lookup(field.name(), row); + let maybe_value = self.field_lookup(&field_path, row); match maybe_value { None => Ok(None), Some(v) => resolve_string(v), @@ -728,7 +775,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { DataType::Binary | DataType::LargeBinary => Arc::new( rows.iter() .map(|row| { - let maybe_value = self.field_lookup(field.name(), row); + let maybe_value = self.field_lookup(&field_path, row); maybe_value.and_then(resolve_bytes) }) .collect::(), @@ -737,7 +784,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { DataType::FixedSizeBinary(ref size) => { Arc::new(FixedSizeBinaryArray::try_from_sparse_iter_with_size( rows.iter().map(|row| { - let maybe_value = self.field_lookup(field.name(), row); + let maybe_value = self.field_lookup(&field_path, row); maybe_value.and_then(|v| resolve_fixed(v, *size as usize)) }), *size, @@ -746,18 +793,19 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { DataType::List(ref list_field) => { match list_field.data_type() { DataType::Dictionary(ref key_ty, _) => { - self.build_wrapped_list_array(rows, field.name(), key_ty)? + self.build_wrapped_list_array(rows, &field_path, key_ty)? } _ => { // extract rows by name let extracted_rows = rows .iter() .map(|row| { - self.field_lookup(field.name(), row) + self.field_lookup(&field_path, row) .unwrap_or(&Value::Null) }) .collect::>(); self.build_nested_list_array::( + &field_path, &extracted_rows, list_field, )? @@ -767,7 +815,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { DataType::Dictionary(ref key_ty, ref val_ty) => self .build_string_dictionary_array( rows, - field.name(), + &field_path, key_ty, val_ty, )?, @@ -775,21 +823,31 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { let len = rows.len(); let num_bytes = bit_util::ceil(len, 8); let mut null_buffer = MutableBuffer::from_len_zeroed(num_bytes); + let empty_vec = vec![]; let struct_rows = rows .iter() .enumerate() - .map(|(i, row)| (i, self.field_lookup(field.name(), row))) + .map(|(i, row)| (i, self.field_lookup(&field_path, row))) .map(|(i, v)| { - if let Some(Value::Record(value)) = v { - bit_util::set_bit(&mut null_buffer, i); - value - } else { - panic!("expected struct got {v:?}"); + let v = v.map(maybe_resolve_union); + match v { + Some(Value::Record(value)) => { + bit_util::set_bit(&mut null_buffer, i); + value + } + None | Some(Value::Null) => &empty_vec, + other => { + panic!("expected struct got {other:?}"); + } } }) .collect::>>(); - let arrays = - self.build_struct_array(&struct_rows, fields, &[])?; + let arrays = self.build_struct_array( + &struct_rows, + &field_path, + fields, + &[], + )?; // construct a struct array's data in order to set null buffer let data_type = DataType::Struct(fields.clone()); let data = ArrayDataBuilder::new(data_type) @@ -1019,6 +1077,7 @@ mod test { use crate::arrow::datatypes::{Field, TimeUnit}; use crate::datasource::avro_to_arrow::{Reader, ReaderBuilder}; use arrow::datatypes::DataType; + use datafusion_common::assert_batches_eq; use datafusion_common::cast::{ as_int32_array, as_int64_array, as_list_array, as_timestamp_microsecond_array, }; @@ -1079,7 +1138,7 @@ mod test { let a_array = as_list_array(batch.column(col_id_index)).unwrap(); assert_eq!( *a_array.data_type(), - DataType::List(Arc::new(Field::new("bigint", DataType::Int64, true))) + DataType::List(Arc::new(Field::new("element", DataType::Int64, true))) ); let array = a_array.value(0); assert_eq!(*array.data_type(), DataType::Int64); @@ -1101,6 +1160,493 @@ mod test { assert_eq!(batch.num_rows(), 3); } + #[test] + fn test_complex_list() { + let schema = apache_avro::Schema::parse_str( + r#" + { + "type": "record", + "name": "r1", + "fields": [ + { + "name": "headers", + "type": ["null", { + "type": "array", + "items": ["null",{ + "name":"r2", + "type": "record", + "fields":[ + {"name":"name", "type": ["null", "string"], "default": null}, + {"name":"value", "type": ["null", "string"], "default": null} + ] + }] + }], + "default": null + } + ] + }"#, + ) + .unwrap(); + let r1 = apache_avro::to_value(serde_json::json!({ + "headers": [ + { + "name": "a", + "value": "b" + } + ] + })) + .unwrap() + .resolve(&schema) + .unwrap(); + + let mut w = apache_avro::Writer::new(&schema, vec![]); + w.append(r1).unwrap(); + let bytes = w.into_inner().unwrap(); + + let mut reader = ReaderBuilder::new() + .read_schema() + .with_batch_size(2) + .build(std::io::Cursor::new(bytes)) + .unwrap(); + + let batch = reader.next().unwrap().unwrap(); + assert_eq!(batch.num_rows(), 1); + assert_eq!(batch.num_columns(), 1); + let expected = [ + "+-----------------------+", + "| headers |", + "+-----------------------+", + "| [{name: a, value: b}] |", + "+-----------------------+", + ]; + assert_batches_eq!(expected, &[batch]); + } + + #[test] + fn test_complex_struct() { + let schema = apache_avro::Schema::parse_str( + r#" + { + "type": "record", + "name": "r1", + "fields": [ + { + "name": "dns", + "type": [ + "null", + { + "type": "record", + "name": "r13", + "fields": [ + { + "name": "answers", + "type": [ + "null", + { + "type": "array", + "items": [ + "null", + { + "type": "record", + "name": "r292", + "fields": [ + { + "name": "class", + "type": ["null", "string"], + "default": null + }, + { + "name": "data", + "type": ["null", "string"], + "default": null + }, + { + "name": "name", + "type": ["null", "string"], + "default": null + }, + { + "name": "ttl", + "type": ["null", "long"], + "default": null + }, + { + "name": "type", + "type": ["null", "string"], + "default": null + } + ] + } + ] + } + ], + "default": null + }, + { + "name": "header_flags", + "type": [ + "null", + { + "type": "array", + "items": ["null", "string"] + } + ], + "default": null + }, + { + "name": "id", + "type": ["null", "string"], + "default": null + }, + { + "name": "op_code", + "type": ["null", "string"], + "default": null + }, + { + "name": "question", + "type": [ + "null", + { + "type": "record", + "name": "r288", + "fields": [ + { + "name": "class", + "type": ["null", "string"], + "default": null + }, + { + "name": "name", + "type": ["null", "string"], + "default": null + }, + { + "name": "registered_domain", + "type": ["null", "string"], + "default": null + }, + { + "name": "subdomain", + "type": ["null", "string"], + "default": null + }, + { + "name": "top_level_domain", + "type": ["null", "string"], + "default": null + }, + { + "name": "type", + "type": ["null", "string"], + "default": null + } + ] + } + ], + "default": null + }, + { + "name": "resolved_ip", + "type": [ + "null", + { + "type": "array", + "items": ["null", "string"] + } + ], + "default": null + }, + { + "name": "response_code", + "type": ["null", "string"], + "default": null + }, + { + "name": "type", + "type": ["null", "string"], + "default": null + } + ] + } + ], + "default": null + } + ] + }"#, + ) + .unwrap(); + + let jv1 = serde_json::json!({ + "dns": { + "answers": [ + { + "data": "CHNlY3VyaXR5BnVidW50dQMjb20AAAEAAQAAAAgABLl9vic=", + "type": "1" + }, + { + "data": "CHNlY3VyaXR5BnVidW50dQNjb20AAAEAABAAAAgABLl9viQ=", + "type": "1" + }, + { + "data": "CHNlT3VyaXR5BnVidW50dQNjb20AAAEAAQAAAAgABFu9Wyc=", + "type": "1" + } + ], + "question": { + "name": "security.ubuntu.com", + "type": "A" + }, + "resolved_ip": [ + "67.43.156.1", + "67.43.156.2", + "67.43.156.3" + ], + "response_code": "0" + } + }); + let r1 = apache_avro::to_value(jv1) + .unwrap() + .resolve(&schema) + .unwrap(); + + let mut w = apache_avro::Writer::new(&schema, vec![]); + w.append(r1).unwrap(); + let bytes = w.into_inner().unwrap(); + + let mut reader = ReaderBuilder::new() + .read_schema() + .with_batch_size(1) + .build(std::io::Cursor::new(bytes)) + .unwrap(); + + let batch = reader.next().unwrap().unwrap(); + assert_eq!(batch.num_rows(), 1); + assert_eq!(batch.num_columns(), 1); + + let expected = [ + "+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+", + "| dns |", + "+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+", + "| {answers: [{class: , data: CHNlY3VyaXR5BnVidW50dQMjb20AAAEAAQAAAAgABLl9vic=, name: , ttl: , type: 1}, {class: , data: CHNlY3VyaXR5BnVidW50dQNjb20AAAEAABAAAAgABLl9viQ=, name: , ttl: , type: 1}, {class: , data: CHNlT3VyaXR5BnVidW50dQNjb20AAAEAAQAAAAgABFu9Wyc=, name: , ttl: , type: 1}], header_flags: , id: , op_code: , question: {class: , name: security.ubuntu.com, registered_domain: , subdomain: , top_level_domain: , type: A}, resolved_ip: [67.43.156.1, 67.43.156.2, 67.43.156.3], response_code: 0, type: } |", + "+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+", + ]; + assert_batches_eq!(expected, &[batch]); + } + + #[test] + fn test_deep_nullable_struct() { + let schema = apache_avro::Schema::parse_str( + r#" + { + "type": "record", + "name": "r1", + "fields": [ + { + "name": "col1", + "type": [ + "null", + { + "type": "record", + "name": "r2", + "fields": [ + { + "name": "col2", + "type": [ + "null", + { + "type": "record", + "name": "r3", + "fields": [ + { + "name": "col3", + "type": [ + "null", + { + "type": "record", + "name": "r4", + "fields": [ + { + "name": "col4", + "type": [ + "null", + { + "type": "record", + "name": "r5", + "fields": [ + { + "name": "col5", + "type": ["null", "string"] + } + ] + } + ] + } + ] + } + ] + } + ] + } + ] + } + ] + } + ] + } + ] + } + "#, + ) + .unwrap(); + let r1 = apache_avro::to_value(serde_json::json!({ + "col1": { + "col2": { + "col3": { + "col4": { + "col5": "hello" + } + } + } + } + })) + .unwrap() + .resolve(&schema) + .unwrap(); + let r2 = apache_avro::to_value(serde_json::json!({ + "col1": { + "col2": { + "col3": { + "col4": { + "col5": null + } + } + } + } + })) + .unwrap() + .resolve(&schema) + .unwrap(); + let r3 = apache_avro::to_value(serde_json::json!({ + "col1": { + "col2": { + "col3": null + } + } + })) + .unwrap() + .resolve(&schema) + .unwrap(); + let r4 = apache_avro::to_value(serde_json::json!({ "col1": null })) + .unwrap() + .resolve(&schema) + .unwrap(); + + let mut w = apache_avro::Writer::new(&schema, vec![]); + w.append(r1).unwrap(); + w.append(r2).unwrap(); + w.append(r3).unwrap(); + w.append(r4).unwrap(); + let bytes = w.into_inner().unwrap(); + + let mut reader = ReaderBuilder::new() + .read_schema() + .with_batch_size(4) + .build(std::io::Cursor::new(bytes)) + .unwrap(); + + let batch = reader.next().unwrap().unwrap(); + + let expected = [ + "+---------------------------------------+", + "| col1 |", + "+---------------------------------------+", + "| {col2: {col3: {col4: {col5: hello}}}} |", + "| {col2: {col3: {col4: {col5: }}}} |", + "| {col2: {col3: }} |", + "| |", + "+---------------------------------------+", + ]; + assert_batches_eq!(expected, &[batch]); + } + + #[test] + fn test_avro_nullable_struct() { + let schema = apache_avro::Schema::parse_str( + r#" + { + "type": "record", + "name": "r1", + "fields": [ + { + "name": "col1", + "type": [ + "null", + { + "type": "record", + "name": "r2", + "fields": [ + { + "name": "col2", + "type": ["null", "string"] + } + ] + } + ], + "default": null + } + ] + }"#, + ) + .unwrap(); + let r1 = apache_avro::to_value(serde_json::json!({ "col1": null })) + .unwrap() + .resolve(&schema) + .unwrap(); + let r2 = apache_avro::to_value(serde_json::json!({ + "col1": { + "col2": "hello" + } + })) + .unwrap() + .resolve(&schema) + .unwrap(); + let r3 = apache_avro::to_value(serde_json::json!({ + "col1": { + "col2": null + } + })) + .unwrap() + .resolve(&schema) + .unwrap(); + + let mut w = apache_avro::Writer::new(&schema, vec![]); + w.append(r1).unwrap(); + w.append(r2).unwrap(); + w.append(r3).unwrap(); + let bytes = w.into_inner().unwrap(); + + let mut reader = ReaderBuilder::new() + .read_schema() + .with_batch_size(3) + .build(std::io::Cursor::new(bytes)) + .unwrap(); + let batch = reader.next().unwrap().unwrap(); + assert_eq!(batch.num_rows(), 3); + assert_eq!(batch.num_columns(), 1); + + let expected = [ + "+---------------+", + "| col1 |", + "+---------------+", + "| |", + "| {col2: hello} |", + "| {col2: } |", + "+---------------+", + ]; + assert_batches_eq!(expected, &[batch]); + } + #[test] fn test_avro_iterator() { let reader = build_reader("alltypes_plain.avro", 5); diff --git a/datafusion/core/src/datasource/avro_to_arrow/schema.rs b/datafusion/core/src/datasource/avro_to_arrow/schema.rs index f15e378cc699..761e6b62680f 100644 --- a/datafusion/core/src/datasource/avro_to_arrow/schema.rs +++ b/datafusion/core/src/datasource/avro_to_arrow/schema.rs @@ -35,7 +35,7 @@ pub fn to_arrow_schema(avro_schema: &apache_avro::Schema) -> Result { schema_fields.push(schema_to_field_with_props( &field.schema, Some(&field.name), - false, + field.is_nullable(), Some(external_props(&field.schema)), )?) } @@ -73,7 +73,7 @@ fn schema_to_field_with_props( AvroSchema::Bytes => DataType::Binary, AvroSchema::String => DataType::Utf8, AvroSchema::Array(item_schema) => DataType::List(Arc::new( - schema_to_field_with_props(item_schema, None, false, None)?, + schema_to_field_with_props(item_schema, Some("element"), false, None)?, )), AvroSchema::Map(value_schema) => { let value_field = @@ -116,7 +116,7 @@ fn schema_to_field_with_props( DataType::Union(UnionFields::new(type_ids, fields), UnionMode::Dense) } } - AvroSchema::Record(RecordSchema { name, fields, .. }) => { + AvroSchema::Record(RecordSchema { fields, .. }) => { let fields: Result<_> = fields .iter() .map(|field| { @@ -129,7 +129,7 @@ fn schema_to_field_with_props( }*/ schema_to_field_with_props( &field.schema, - Some(&format!("{}.{}", name.fullname(None), field.name)), + Some(&field.name), false, Some(props), ) @@ -442,6 +442,58 @@ mod test { assert_eq!(arrow_schema.unwrap(), expected); } + #[test] + fn test_nested_schema() { + let avro_schema = apache_avro::Schema::parse_str( + r#" + { + "type": "record", + "name": "r1", + "fields": [ + { + "name": "col1", + "type": [ + "null", + { + "type": "record", + "name": "r2", + "fields": [ + { + "name": "col2", + "type": "string" + }, + { + "name": "col3", + "type": ["null", "string"], + "default": null + } + ] + } + ], + "default": null + } + ] + }"#, + ) + .unwrap(); + // should not use Avro Record names. + let expected_arrow_schema = Schema::new(vec![Field::new( + "col1", + arrow::datatypes::DataType::Struct( + vec![ + Field::new("col2", Utf8, false), + Field::new("col3", Utf8, true), + ] + .into(), + ), + true, + )]); + assert_eq!( + to_arrow_schema(&avro_schema).unwrap(), + expected_arrow_schema + ); + } + #[test] fn test_non_record_schema() { let arrow_schema = to_arrow_schema(&AvroSchema::String); diff --git a/datafusion/core/src/datasource/default_table_source.rs b/datafusion/core/src/datasource/default_table_source.rs index 58d0997bb653..fadf01c74c5d 100644 --- a/datafusion/core/src/datasource/default_table_source.rs +++ b/datafusion/core/src/datasource/default_table_source.rs @@ -26,10 +26,12 @@ use arrow::datatypes::SchemaRef; use datafusion_common::{internal_err, Constraints, DataFusionError}; use datafusion_expr::{Expr, TableProviderFilterPushDown, TableSource}; -/// DataFusion default table source, wrapping TableProvider +/// DataFusion default table source, wrapping TableProvider. /// /// This structure adapts a `TableProvider` (physical plan trait) to the `TableSource` -/// (logical plan trait) +/// (logical plan trait) and is necessary because the logical plan is contained in +/// the `datafusion_expr` crate, and is not aware of table providers, which exist in +/// the core `datafusion` crate. pub struct DefaultTableSource { /// table provider pub table_provider: Arc, @@ -43,7 +45,7 @@ impl DefaultTableSource { } impl TableSource for DefaultTableSource { - /// Returns the table source as [`Any`](std::any::Any) so that it can be + /// Returns the table source as [`Any`] so that it can be /// downcast to a specific implementation. fn as_any(&self) -> &dyn Any { self @@ -71,6 +73,10 @@ impl TableSource for DefaultTableSource { fn get_logical_plan(&self) -> Option<&datafusion_expr::LogicalPlan> { self.table_provider.get_logical_plan() } + + fn get_column_default(&self, column: &str) -> Option<&Expr> { + self.table_provider.get_column_default(column) + } } /// Wrap TableProvider in TableSource diff --git a/datafusion/core/src/datasource/empty.rs b/datafusion/core/src/datasource/empty.rs index 77160aa5d1c0..5100987520ee 100644 --- a/datafusion/core/src/datasource/empty.rs +++ b/datafusion/core/src/datasource/empty.rs @@ -77,7 +77,7 @@ impl TableProvider for EmptyTable { // even though there is no data, projections apply let projected_schema = project_schema(&self.schema, projection)?; Ok(Arc::new( - EmptyExec::new(false, projected_schema).with_partitions(self.partitions), + EmptyExec::new(projected_schema).with_partitions(self.partitions), )) } } diff --git a/datafusion/core/src/datasource/file_format/arrow.rs b/datafusion/core/src/datasource/file_format/arrow.rs index 45ecdd6083e7..650f8c844eda 100644 --- a/datafusion/core/src/datasource/file_format/arrow.rs +++ b/datafusion/core/src/datasource/file_format/arrow.rs @@ -15,24 +15,55 @@ // specific language governing permissions and limitations // under the License. -//! Apache Arrow format abstractions +//! [`ArrowFormat`]: Apache Arrow [`FileFormat`] abstractions //! //! Works with files following the [Arrow IPC format](https://arrow.apache.org/docs/format/Columnar.html#ipc-file-format) +use std::any::Any; +use std::borrow::Cow; +use std::fmt::{self, Debug}; +use std::sync::Arc; + use crate::datasource::file_format::FileFormat; -use crate::datasource::physical_plan::{ArrowExec, FileScanConfig}; +use crate::datasource::physical_plan::{ + ArrowExec, FileGroupDisplay, FileScanConfig, FileSinkConfig, +}; use crate::error::Result; use crate::execution::context::SessionState; use crate::physical_plan::ExecutionPlan; + +use arrow::ipc::convert::fb_to_schema; use arrow::ipc::reader::FileReader; -use arrow_schema::{Schema, SchemaRef}; +use arrow::ipc::root_as_message; +use arrow_ipc::writer::IpcWriteOptions; +use arrow_ipc::CompressionType; +use arrow_schema::{ArrowError, Schema, SchemaRef}; + +use bytes::Bytes; +use datafusion_common::{not_impl_err, DataFusionError, FileType, Statistics}; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; + +use crate::physical_plan::{DisplayAs, DisplayFormatType}; use async_trait::async_trait; -use datafusion_common::{FileType, Statistics}; -use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_plan::insert::{DataSink, FileSinkExec}; +use datafusion_physical_plan::metrics::MetricsSet; +use futures::stream::BoxStream; +use futures::StreamExt; use object_store::{GetResultPayload, ObjectMeta, ObjectStore}; -use std::any::Any; -use std::io::{Read, Seek}; -use std::sync::Arc; +use tokio::io::AsyncWriteExt; +use tokio::task::JoinSet; + +use super::file_compression_type::FileCompressionType; +use super::write::demux::start_demuxer_task; +use super::write::{create_writer, SharedBuffer}; + +/// Initial writing buffer size. Note this is just a size hint for efficiency. It +/// will grow beyond the set value if needed. +const INITIAL_BUFFER_BYTES: usize = 1048576; + +/// If the buffered Arrow data exceeds this size, it is flushed to object store +const BUFFER_FLUSH_BYTES: usize = 1024000; /// Arrow `FileFormat` implementation. #[derive(Default, Debug)] @@ -55,13 +86,11 @@ impl FileFormat for ArrowFormat { let r = store.as_ref().get(&object.location).await?; let schema = match r.payload { GetResultPayload::File(mut file, _) => { - read_arrow_schema_from_reader(&mut file)? + let reader = FileReader::try_new(&mut file, None)?; + reader.schema() } - GetResultPayload::Stream(_) => { - // TODO: Fetching entire file to get schema is potentially wasteful - let data = r.bytes().await?; - let mut cursor = std::io::Cursor::new(&data); - read_arrow_schema_from_reader(&mut cursor)? + GetResultPayload::Stream(stream) => { + infer_schema_from_file_stream(stream).await? } }; schemas.push(schema.as_ref().clone()); @@ -74,10 +103,10 @@ impl FileFormat for ArrowFormat { &self, _state: &SessionState, _store: &Arc, - _table_schema: SchemaRef, + table_schema: SchemaRef, _object: &ObjectMeta, ) -> Result { - Ok(Statistics::default()) + Ok(Statistics::new_unknown(&table_schema)) } async fn create_physical_plan( @@ -90,12 +119,372 @@ impl FileFormat for ArrowFormat { Ok(Arc::new(exec)) } + async fn create_writer_physical_plan( + &self, + input: Arc, + _state: &SessionState, + conf: FileSinkConfig, + order_requirements: Option>, + ) -> Result> { + if conf.overwrite { + return not_impl_err!("Overwrites are not implemented yet for Arrow format"); + } + + let sink_schema = conf.output_schema().clone(); + let sink = Arc::new(ArrowFileSink::new(conf)); + + Ok(Arc::new(FileSinkExec::new( + input, + sink, + sink_schema, + order_requirements, + )) as _) + } + fn file_type(&self) -> FileType { FileType::ARROW } } -fn read_arrow_schema_from_reader(reader: R) -> Result { - let reader = FileReader::try_new(reader, None)?; - Ok(reader.schema()) +/// Implements [`DataSink`] for writing to arrow_ipc files +struct ArrowFileSink { + config: FileSinkConfig, +} + +impl ArrowFileSink { + fn new(config: FileSinkConfig) -> Self { + Self { config } + } + + /// Converts table schema to writer schema, which may differ in the case + /// of hive style partitioning where some columns are removed from the + /// underlying files. + fn get_writer_schema(&self) -> Arc { + if !self.config.table_partition_cols.is_empty() { + let schema = self.config.output_schema(); + let partition_names: Vec<_> = self + .config + .table_partition_cols + .iter() + .map(|(s, _)| s) + .collect(); + Arc::new(Schema::new( + schema + .fields() + .iter() + .filter(|f| !partition_names.contains(&f.name())) + .map(|f| (**f).clone()) + .collect::>(), + )) + } else { + self.config.output_schema().clone() + } + } +} + +impl Debug for ArrowFileSink { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ArrowFileSink").finish() + } +} + +impl DisplayAs for ArrowFileSink { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "ArrowFileSink(file_groups=",)?; + FileGroupDisplay(&self.config.file_groups).fmt_as(t, f)?; + write!(f, ")") + } + } + } +} + +#[async_trait] +impl DataSink for ArrowFileSink { + fn as_any(&self) -> &dyn Any { + self + } + + fn metrics(&self) -> Option { + None + } + + async fn write_all( + &self, + data: SendableRecordBatchStream, + context: &Arc, + ) -> Result { + // No props are supported yet, but can be by updating FileTypeWriterOptions + // to populate this struct and use those options to initialize the arrow_ipc::writer::FileWriter + // https://github.com/apache/arrow-datafusion/issues/8635 + let _arrow_props = self.config.file_type_writer_options.try_into_arrow()?; + + let object_store = context + .runtime_env() + .object_store(&self.config.object_store_url)?; + + let part_col = if !self.config.table_partition_cols.is_empty() { + Some(self.config.table_partition_cols.clone()) + } else { + None + }; + + let (demux_task, mut file_stream_rx) = start_demuxer_task( + data, + context, + part_col, + self.config.table_paths[0].clone(), + "arrow".into(), + self.config.single_file_output, + ); + + let mut file_write_tasks: JoinSet> = + JoinSet::new(); + + let ipc_options = + IpcWriteOptions::try_new(64, false, arrow_ipc::MetadataVersion::V5)? + .try_with_compression(Some(CompressionType::LZ4_FRAME))?; + while let Some((path, mut rx)) = file_stream_rx.recv().await { + let shared_buffer = SharedBuffer::new(INITIAL_BUFFER_BYTES); + let mut arrow_writer = arrow_ipc::writer::FileWriter::try_new_with_options( + shared_buffer.clone(), + &self.get_writer_schema(), + ipc_options.clone(), + )?; + let mut object_store_writer = create_writer( + FileCompressionType::UNCOMPRESSED, + &path, + object_store.clone(), + ) + .await?; + file_write_tasks.spawn(async move { + let mut row_count = 0; + while let Some(batch) = rx.recv().await { + row_count += batch.num_rows(); + arrow_writer.write(&batch)?; + let mut buff_to_flush = shared_buffer.buffer.try_lock().unwrap(); + if buff_to_flush.len() > BUFFER_FLUSH_BYTES { + object_store_writer + .write_all(buff_to_flush.as_slice()) + .await?; + buff_to_flush.clear(); + } + } + arrow_writer.finish()?; + let final_buff = shared_buffer.buffer.try_lock().unwrap(); + + object_store_writer.write_all(final_buff.as_slice()).await?; + object_store_writer.shutdown().await?; + Ok(row_count) + }); + } + + let mut row_count = 0; + while let Some(result) = file_write_tasks.join_next().await { + match result { + Ok(r) => { + row_count += r?; + } + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } else { + unreachable!(); + } + } + } + } + + match demux_task.await { + Ok(r) => r?, + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } else { + unreachable!(); + } + } + } + Ok(row_count as u64) + } +} + +const ARROW_MAGIC: [u8; 6] = [b'A', b'R', b'R', b'O', b'W', b'1']; +const CONTINUATION_MARKER: [u8; 4] = [0xff; 4]; + +/// Custom implementation of inferring schema. Should eventually be moved upstream to arrow-rs. +/// See +async fn infer_schema_from_file_stream( + mut stream: BoxStream<'static, object_store::Result>, +) -> Result { + // Expected format: + // - 6 bytes + // - 2 bytes + // - 4 bytes, not present below v0.15.0 + // - 4 bytes + // + // + + // So in first read we need at least all known sized sections, + // which is 6 + 2 + 4 + 4 = 16 bytes. + let bytes = collect_at_least_n_bytes(&mut stream, 16, None).await?; + + // Files should start with these magic bytes + if bytes[0..6] != ARROW_MAGIC { + return Err(ArrowError::ParseError( + "Arrow file does not contian correct header".to_string(), + ))?; + } + + // Since continuation marker bytes added in later versions + let (meta_len, rest_of_bytes_start_index) = if bytes[8..12] == CONTINUATION_MARKER { + (&bytes[12..16], 16) + } else { + (&bytes[8..12], 12) + }; + + let meta_len = [meta_len[0], meta_len[1], meta_len[2], meta_len[3]]; + let meta_len = i32::from_le_bytes(meta_len); + + // Read bytes for Schema message + let block_data = if bytes[rest_of_bytes_start_index..].len() < meta_len as usize { + // Need to read more bytes to decode Message + let mut block_data = Vec::with_capacity(meta_len as usize); + // In case we had some spare bytes in our initial read chunk + block_data.extend_from_slice(&bytes[rest_of_bytes_start_index..]); + let size_to_read = meta_len as usize - block_data.len(); + let block_data = + collect_at_least_n_bytes(&mut stream, size_to_read, Some(block_data)).await?; + Cow::Owned(block_data) + } else { + // Already have the bytes we need + let end_index = meta_len as usize + rest_of_bytes_start_index; + let block_data = &bytes[rest_of_bytes_start_index..end_index]; + Cow::Borrowed(block_data) + }; + + // Decode Schema message + let message = root_as_message(&block_data).map_err(|err| { + ArrowError::ParseError(format!("Unable to read IPC message as metadata: {err:?}")) + })?; + let ipc_schema = message.header_as_schema().ok_or_else(|| { + ArrowError::IpcError("Unable to read IPC message as schema".to_string()) + })?; + let schema = fb_to_schema(ipc_schema); + + Ok(Arc::new(schema)) +} + +async fn collect_at_least_n_bytes( + stream: &mut BoxStream<'static, object_store::Result>, + n: usize, + extend_from: Option>, +) -> Result> { + let mut buf = extend_from.unwrap_or_else(|| Vec::with_capacity(n)); + // If extending existing buffer then ensure we read n additional bytes + let n = n + buf.len(); + while let Some(bytes) = stream.next().await.transpose()? { + buf.extend_from_slice(&bytes); + if buf.len() >= n { + break; + } + } + if buf.len() < n { + return Err(ArrowError::ParseError( + "Unexpected end of byte stream for Arrow IPC file".to_string(), + ))?; + } + Ok(buf) +} + +#[cfg(test)] +mod tests { + use chrono::DateTime; + use object_store::{chunked::ChunkedStore, memory::InMemory, path::Path}; + + use crate::execution::context::SessionContext; + + use super::*; + + #[tokio::test] + async fn test_infer_schema_stream() -> Result<()> { + let mut bytes = std::fs::read("tests/data/example.arrow")?; + bytes.truncate(bytes.len() - 20); // mangle end to show we don't need to read whole file + let location = Path::parse("example.arrow")?; + let in_memory_store: Arc = Arc::new(InMemory::new()); + in_memory_store.put(&location, bytes.into()).await?; + + let session_ctx = SessionContext::new(); + let state = session_ctx.state(); + let object_meta = ObjectMeta { + location, + last_modified: DateTime::default(), + size: usize::MAX, + e_tag: None, + version: None, + }; + + let arrow_format = ArrowFormat {}; + let expected = vec!["f0: Int64", "f1: Utf8", "f2: Boolean"]; + + // Test chunk sizes where too small so we keep having to read more bytes + // And when large enough that first read contains all we need + for chunk_size in [7, 3000] { + let store = Arc::new(ChunkedStore::new(in_memory_store.clone(), chunk_size)); + let inferred_schema = arrow_format + .infer_schema( + &state, + &(store.clone() as Arc), + &[object_meta.clone()], + ) + .await?; + let actual_fields = inferred_schema + .fields() + .iter() + .map(|f| format!("{}: {:?}", f.name(), f.data_type())) + .collect::>(); + assert_eq!(expected, actual_fields); + } + + Ok(()) + } + + #[tokio::test] + async fn test_infer_schema_short_stream() -> Result<()> { + let mut bytes = std::fs::read("tests/data/example.arrow")?; + bytes.truncate(20); // should cause error that file shorter than expected + let location = Path::parse("example.arrow")?; + let in_memory_store: Arc = Arc::new(InMemory::new()); + in_memory_store.put(&location, bytes.into()).await?; + + let session_ctx = SessionContext::new(); + let state = session_ctx.state(); + let object_meta = ObjectMeta { + location, + last_modified: DateTime::default(), + size: usize::MAX, + e_tag: None, + version: None, + }; + + let arrow_format = ArrowFormat {}; + + let store = Arc::new(ChunkedStore::new(in_memory_store.clone(), 7)); + let err = arrow_format + .infer_schema( + &state, + &(store.clone() as Arc), + &[object_meta.clone()], + ) + .await; + + assert!(err.is_err()); + assert_eq!( + "Arrow error: Parser error: Unexpected end of byte stream for Arrow IPC file", + err.unwrap_err().to_string() + ); + + Ok(()) + } } diff --git a/datafusion/core/src/datasource/file_format/avro.rs b/datafusion/core/src/datasource/file_format/avro.rs index e68a4cad2207..6d424bf0b28f 100644 --- a/datafusion/core/src/datasource/file_format/avro.rs +++ b/datafusion/core/src/datasource/file_format/avro.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Apache Avro format abstractions +//! [`AvroFormat`] Apache Avro [`FileFormat`] abstractions use std::any::Any; use std::sync::Arc; @@ -74,10 +74,10 @@ impl FileFormat for AvroFormat { &self, _state: &SessionState, _store: &Arc, - _table_schema: SchemaRef, + table_schema: SchemaRef, _object: &ObjectMeta, ) -> Result { - Ok(Statistics::default()) + Ok(Statistics::new_unknown(&table_schema)) } async fn create_physical_plan( @@ -112,7 +112,7 @@ mod tests { #[tokio::test] async fn read_small_batches() -> Result<()> { let config = SessionConfig::new().with_batch_size(2); - let session_ctx = SessionContext::with_config(config); + let session_ctx = SessionContext::new_with_config(config); let state = session_ctx.state(); let task_ctx = state.task_ctx(); let projection = None; diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index 897174659e13..7a0af3ff0809 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -15,21 +15,34 @@ // specific language governing permissions and limitations // under the License. -//! CSV format abstractions +//! [`CsvFormat`], Comma Separated Value (CSV) [`FileFormat`] abstractions use std::any::Any; use std::collections::HashSet; -use std::fmt; -use std::fmt::Debug; +use std::fmt::{self, Debug}; use std::sync::Arc; +use super::write::orchestration::stateless_multipart_put; +use super::{FileFormat, DEFAULT_SCHEMA_INFER_MAX_RECORD}; +use crate::datasource::file_format::file_compression_type::FileCompressionType; +use crate::datasource::file_format::write::BatchSerializer; +use crate::datasource::physical_plan::{ + CsvExec, FileGroupDisplay, FileScanConfig, FileSinkConfig, +}; +use crate::error::Result; +use crate::execution::context::SessionState; +use crate::physical_plan::insert::{DataSink, FileSinkExec}; +use crate::physical_plan::{DisplayAs, DisplayFormatType, Statistics}; +use crate::physical_plan::{ExecutionPlan, SendableRecordBatchStream}; + +use arrow::array::RecordBatch; use arrow::csv::WriterBuilder; use arrow::datatypes::{DataType, Field, Fields, Schema}; use arrow::{self, datatypes::SchemaRef}; -use arrow_array::RecordBatch; use datafusion_common::{exec_err, not_impl_err, DataFusionError, FileType}; use datafusion_execution::TaskContext; -use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; +use datafusion_physical_plan::metrics::MetricsSet; use async_trait::async_trait; use bytes::{Buf, Bytes}; @@ -37,21 +50,6 @@ use futures::stream::BoxStream; use futures::{pin_mut, Stream, StreamExt, TryStreamExt}; use object_store::{delimited::newline_delimited_stream, ObjectMeta, ObjectStore}; -use super::{FileFormat, DEFAULT_SCHEMA_INFER_MAX_RECORD}; -use crate::datasource::file_format::file_compression_type::FileCompressionType; -use crate::datasource::file_format::write::{ - create_writer, stateless_serialize_and_write_files, BatchSerializer, FileWriterMode, -}; -use crate::datasource::physical_plan::{ - CsvExec, FileGroupDisplay, FileScanConfig, FileSinkConfig, -}; -use crate::error::Result; -use crate::execution::context::SessionState; -use crate::physical_plan::insert::{DataSink, FileSinkExec}; -use crate::physical_plan::{DisplayAs, DisplayFormatType, Statistics}; -use crate::physical_plan::{ExecutionPlan, SendableRecordBatchStream}; -use rand::distributions::{Alphanumeric, DistString}; - /// Character Separated Value `FileFormat` implementation. #[derive(Debug)] pub struct CsvFormat { @@ -235,10 +233,10 @@ impl FileFormat for CsvFormat { &self, _state: &SessionState, _store: &Arc, - _table_schema: SchemaRef, + table_schema: SchemaRef, _object: &ObjectMeta, ) -> Result { - Ok(Statistics::default()) + Ok(Statistics::new_unknown(&table_schema)) } async fn create_physical_plan( @@ -263,6 +261,7 @@ impl FileFormat for CsvFormat { input: Arc, _state: &SessionState, conf: FileSinkConfig, + order_requirements: Option>, ) -> Result> { if conf.overwrite { return not_impl_err!("Overwrites are not implemented yet for CSV"); @@ -275,7 +274,12 @@ impl FileFormat for CsvFormat { let sink_schema = conf.output_schema().clone(); let sink = Arc::new(CsvSink::new(conf)); - Ok(Arc::new(FileSinkExec::new(input, sink, sink_schema)) as _) + Ok(Arc::new(FileSinkExec::new( + input, + sink, + sink_schema, + order_requirements, + )) as _) } fn file_type(&self) -> FileType { @@ -393,8 +397,6 @@ impl Default for CsvSerializer { pub struct CsvSerializer { // CSV writer builder builder: WriterBuilder, - // Inner buffer for avoiding reallocation - buffer: Vec, // Flag to indicate whether there will be a header header: bool, } @@ -405,7 +407,6 @@ impl CsvSerializer { Self { builder: WriterBuilder::new(), header: true, - buffer: Vec::with_capacity(4096), } } @@ -424,26 +425,19 @@ impl CsvSerializer { #[async_trait] impl BatchSerializer for CsvSerializer { - async fn serialize(&mut self, batch: RecordBatch) -> Result { + async fn serialize(&self, batch: RecordBatch, initial: bool) -> Result { + let mut buffer = Vec::with_capacity(4096); let builder = self.builder.clone(); - let mut writer = builder.has_headers(self.header).build(&mut self.buffer); + let header = self.header && initial; + let mut writer = builder.with_header(header).build(&mut buffer); writer.write(&batch)?; drop(writer); - self.header = false; - Ok(Bytes::from(self.buffer.drain(..).collect::>())) - } - - fn duplicate(&mut self) -> Result> { - let new_self = CsvSerializer::new() - .with_builder(self.builder.clone()) - .with_header(self.header); - self.header = false; - Ok(Box::new(new_self)) + Ok(Bytes::from(buffer)) } } /// Implements [`DataSink`] for writing to a CSV file. -struct CsvSink { +pub struct CsvSink { /// Config options for writing data config: FileSinkConfig, } @@ -458,11 +452,7 @@ impl DisplayAs for CsvSink { fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!( - f, - "CsvSink(writer_mode={:?}, file_groups=", - self.config.writer_mode - )?; + write!(f, "CsvSink(file_groups=",)?; FileGroupDisplay(&self.config.file_groups).fmt_as(t, f)?; write!(f, ")") } @@ -471,128 +461,66 @@ impl DisplayAs for CsvSink { } impl CsvSink { - fn new(config: FileSinkConfig) -> Self { + /// Create from config. + pub fn new(config: FileSinkConfig) -> Self { Self { config } } -} -#[async_trait] -impl DataSink for CsvSink { - async fn write_all( + /// Retrieve the inner [`FileSinkConfig`]. + pub fn config(&self) -> &FileSinkConfig { + &self.config + } + + async fn multipartput_all( &self, - data: Vec, + data: SendableRecordBatchStream, context: &Arc, ) -> Result { - let num_partitions = data.len(); let writer_options = self.config.file_type_writer_options.try_into_csv()?; - let (builder, compression) = - (&writer_options.writer_options, &writer_options.compression); - let mut has_header = writer_options.has_header; - let compression = FileCompressionType::from(*compression); - - let object_store = context - .runtime_env() - .object_store(&self.config.object_store_url)?; - // Construct serializer and writer for each file group - let mut serializers: Vec> = vec![]; - let mut writers = vec![]; - match self.config.writer_mode { - FileWriterMode::Append => { - for file_group in &self.config.file_groups { - let mut append_builder = builder.clone(); - // In append mode, consider has_header flag only when file is empty (at the start). - // For other modes, use has_header flag as is. - if file_group.object_meta.size != 0 { - has_header = false; - append_builder = append_builder.has_headers(false); - } - let serializer = CsvSerializer::new() - .with_builder(append_builder) - .with_header(has_header); - serializers.push(Box::new(serializer)); - - let file = file_group.clone(); - let writer = create_writer( - self.config.writer_mode, - compression, - file.object_meta.clone().into(), - object_store.clone(), - ) - .await?; - writers.push(writer); - } - } - FileWriterMode::Put => { - return not_impl_err!("Put Mode is not implemented for CSV Sink yet") - } - FileWriterMode::PutMultipart => { - // Currently assuming only 1 partition path (i.e. not hive-style partitioning on a column) - let base_path = &self.config.table_paths[0]; - match self.config.single_file_output { - false => { - // Uniquely identify this batch of files with a random string, to prevent collisions overwriting files - let write_id = - Alphanumeric.sample_string(&mut rand::thread_rng(), 16); - for part_idx in 0..num_partitions { - let serializer = CsvSerializer::new() - .with_builder(builder.clone()) - .with_header(has_header); - serializers.push(Box::new(serializer)); - let file_path = base_path - .prefix() - .child(format!("{}_{}.csv", write_id, part_idx)); - let object_meta = ObjectMeta { - location: file_path, - last_modified: chrono::offset::Utc::now(), - size: 0, - e_tag: None, - }; - let writer = create_writer( - self.config.writer_mode, - compression, - object_meta.into(), - object_store.clone(), - ) - .await?; - writers.push(writer); - } - } - true => { - let serializer = CsvSerializer::new() - .with_builder(builder.clone()) - .with_header(has_header); - serializers.push(Box::new(serializer)); - let file_path = base_path.prefix(); - let object_meta = ObjectMeta { - location: file_path.clone(), - last_modified: chrono::offset::Utc::now(), - size: 0, - e_tag: None, - }; - let writer = create_writer( - self.config.writer_mode, - compression, - object_meta.into(), - object_store.clone(), - ) - .await?; - writers.push(writer); - } - } - } - } + let builder = &writer_options.writer_options; + + let builder_clone = builder.clone(); + let options_clone = writer_options.clone(); + let get_serializer = move || { + Arc::new( + CsvSerializer::new() + .with_builder(builder_clone.clone()) + .with_header(options_clone.writer_options.header()), + ) as _ + }; - stateless_serialize_and_write_files( + stateless_multipart_put( data, - serializers, - writers, - self.config.single_file_output, - self.config.unbounded_input, + context, + "csv".into(), + Box::new(get_serializer), + &self.config, + writer_options.compression.into(), ) .await } } +#[async_trait] +impl DataSink for CsvSink { + fn as_any(&self) -> &dyn Any { + self + } + + fn metrics(&self) -> Option { + None + } + + async fn write_all( + &self, + data: SendableRecordBatchStream, + context: &Arc, + ) -> Result { + let total_count = self.multipartput_all(data, context).await?; + Ok(total_count) + } +} + #[cfg(test)] mod tests { use super::super::test_util::scan_format; @@ -605,14 +533,15 @@ mod tests { use crate::physical_plan::collect; use crate::prelude::{CsvReadOptions, SessionConfig, SessionContext}; use crate::test_util::arrow_test_data; + use arrow::compute::concat_batches; - use bytes::Bytes; - use chrono::DateTime; use datafusion_common::cast::as_string_array; - use datafusion_common::internal_err; - use datafusion_common::FileType; - use datafusion_common::GetExt; + use datafusion_common::stats::Precision; + use datafusion_common::{internal_err, FileType, GetExt}; use datafusion_expr::{col, lit}; + + use bytes::Bytes; + use chrono::DateTime; use futures::StreamExt; use object_store::local::LocalFileSystem; use object_store::path::Path; @@ -622,7 +551,7 @@ mod tests { #[tokio::test] async fn read_small_batches() -> Result<()> { let config = SessionConfig::new().with_batch_size(2); - let session_ctx = SessionContext::with_config(config); + let session_ctx = SessionContext::new_with_config(config); let state = session_ctx.state(); let task_ctx = state.task_ctx(); // skip column 9 that overflows the automaticly discovered column type of i64 (u64 would work) @@ -642,8 +571,8 @@ mod tests { assert_eq!(tt_batches, 50 /* 100/2 */); // test metadata - assert_eq!(exec.statistics().num_rows, None); - assert_eq!(exec.statistics().total_byte_size, None); + assert_eq!(exec.statistics()?.num_rows, Precision::Absent); + assert_eq!(exec.statistics()?.total_byte_size, Precision::Absent); Ok(()) } @@ -736,6 +665,7 @@ mod tests { last_modified: DateTime::default(), size: usize::MAX, e_tag: None, + version: None, }; let num_rows_to_read = 100; @@ -898,8 +828,8 @@ mod tests { .collect() .await?; let batch = concat_batches(&batches[0].schema(), &batches)?; - let mut serializer = CsvSerializer::new(); - let bytes = serializer.serialize(batch).await?; + let serializer = CsvSerializer::new(); + let bytes = serializer.serialize(batch, true).await?; assert_eq!( "c2,c3\n2,1\n5,-40\n1,29\n1,-85\n5,-82\n4,-111\n3,104\n3,13\n1,38\n4,-38\n", String::from_utf8(bytes.into()).unwrap() @@ -922,8 +852,8 @@ mod tests { .collect() .await?; let batch = concat_batches(&batches[0].schema(), &batches)?; - let mut serializer = CsvSerializer::new().with_header(false); - let bytes = serializer.serialize(batch).await?; + let serializer = CsvSerializer::new().with_header(false); + let bytes = serializer.serialize(batch, true).await?; assert_eq!( "2,1\n5,-40\n1,29\n1,-85\n5,-82\n4,-111\n3,104\n3,13\n1,38\n4,-38\n", String::from_utf8(bytes.into()).unwrap() @@ -960,7 +890,7 @@ mod tests { .with_repartition_file_scans(true) .with_repartition_file_min_size(0) .with_target_partitions(n_partitions); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); let testdata = arrow_test_data(); ctx.register_csv( "aggr", @@ -997,7 +927,7 @@ mod tests { .has_header(true) .file_compression_type(FileCompressionType::GZIP) .file_extension("csv.gz"); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); let testdata = arrow_test_data(); ctx.register_csv( "aggr", @@ -1033,7 +963,7 @@ mod tests { .with_repartition_file_scans(true) .with_repartition_file_min_size(0) .with_target_partitions(n_partitions); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); ctx.register_csv( "empty", "tests/data/empty_0_byte.csv", @@ -1066,7 +996,7 @@ mod tests { .with_repartition_file_scans(true) .with_repartition_file_min_size(0) .with_target_partitions(n_partitions); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); ctx.register_csv( "empty", "tests/data/empty.csv", @@ -1104,7 +1034,7 @@ mod tests { .with_repartition_file_scans(true) .with_repartition_file_min_size(0) .with_target_partitions(n_partitions); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); let file_format = CsvFormat::default().with_has_header(false); let listing_options = ListingOptions::new(Arc::new(file_format)) .with_file_extension(FileType::CSV.get_ext()); @@ -1157,7 +1087,7 @@ mod tests { .with_repartition_file_scans(true) .with_repartition_file_min_size(0) .with_target_partitions(n_partitions); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); let file_format = CsvFormat::default().with_has_header(false); let listing_options = ListingOptions::new(Arc::new(file_format)) .with_file_extension(FileType::CSV.get_ext()); @@ -1202,7 +1132,7 @@ mod tests { .with_repartition_file_scans(true) .with_repartition_file_min_size(0) .with_target_partitions(n_partitions); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); ctx.register_csv( "one_col", @@ -1251,7 +1181,7 @@ mod tests { .with_repartition_file_scans(true) .with_repartition_file_min_size(0) .with_target_partitions(n_partitions); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); ctx.register_csv( "wide_rows", "tests/data/wide_rows.csv", diff --git a/datafusion/core/src/datasource/file_format/file_compression_type.rs b/datafusion/core/src/datasource/file_format/file_compression_type.rs index bd2868767090..3dac7c293050 100644 --- a/datafusion/core/src/datasource/file_format/file_compression_type.rs +++ b/datafusion/core/src/datasource/file_format/file_compression_type.rs @@ -237,7 +237,14 @@ impl FileTypeExt for FileType { match self { FileType::JSON | FileType::CSV => Ok(format!("{}{}", ext, c.get_ext())), - FileType::PARQUET | FileType::AVRO | FileType::ARROW => match c.variant { + FileType::AVRO | FileType::ARROW => match c.variant { + UNCOMPRESSED => Ok(ext), + _ => Err(DataFusionError::Internal( + "FileCompressionType can be specified for CSV/JSON FileType.".into(), + )), + }, + #[cfg(feature = "parquet")] + FileType::PARQUET => match c.variant { UNCOMPRESSED => Ok(ext), _ => Err(DataFusionError::Internal( "FileCompressionType can be specified for CSV/JSON FileType.".into(), @@ -276,10 +283,13 @@ mod tests { ); } + let mut ty_ext_tuple = vec![]; + ty_ext_tuple.push((FileType::AVRO, ".avro")); + #[cfg(feature = "parquet")] + ty_ext_tuple.push((FileType::PARQUET, ".parquet")); + // Cannot specify compression for these file types - for (file_type, extension) in - [(FileType::AVRO, ".avro"), (FileType::PARQUET, ".parquet")] - { + for (file_type, extension) in ty_ext_tuple { assert_eq!( file_type .get_ext_with_compression(FileCompressionType::UNCOMPRESSED) diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs index c715317a9527..8c02955ad363 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -15,53 +15,43 @@ // specific language governing permissions and limitations // under the License. -//! Line delimited JSON format abstractions +//! [`JsonFormat`]: Line delimited JSON [`FileFormat`] abstractions use std::any::Any; - -use bytes::Bytes; -use datafusion_common::not_impl_err; -use datafusion_common::DataFusionError; -use datafusion_common::FileType; -use datafusion_execution::TaskContext; -use rand::distributions::Alphanumeric; -use rand::distributions::DistString; use std::fmt; use std::fmt::Debug; use std::io::BufReader; use std::sync::Arc; +use super::write::orchestration::stateless_multipart_put; +use super::{FileFormat, FileScanConfig}; +use crate::datasource::file_format::file_compression_type::FileCompressionType; +use crate::datasource::file_format::write::BatchSerializer; +use crate::datasource::file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD; +use crate::datasource::physical_plan::FileGroupDisplay; +use crate::datasource::physical_plan::{FileSinkConfig, NdJsonExec}; +use crate::error::Result; +use crate::execution::context::SessionState; +use crate::physical_plan::insert::{DataSink, FileSinkExec}; +use crate::physical_plan::{ + DisplayAs, DisplayFormatType, SendableRecordBatchStream, Statistics, +}; + use arrow::datatypes::Schema; use arrow::datatypes::SchemaRef; use arrow::json; -use arrow::json::reader::infer_json_schema_from_iterator; -use arrow::json::reader::ValueIter; +use arrow::json::reader::{infer_json_schema_from_iterator, ValueIter}; use arrow_array::RecordBatch; -use async_trait::async_trait; -use bytes::Buf; +use datafusion_common::{not_impl_err, DataFusionError, FileType}; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; +use datafusion_physical_plan::metrics::MetricsSet; +use datafusion_physical_plan::ExecutionPlan; -use datafusion_physical_expr::PhysicalExpr; +use async_trait::async_trait; +use bytes::{Buf, Bytes}; use object_store::{GetResultPayload, ObjectMeta, ObjectStore}; -use crate::datasource::physical_plan::FileGroupDisplay; -use crate::physical_plan::insert::DataSink; -use crate::physical_plan::insert::FileSinkExec; -use crate::physical_plan::SendableRecordBatchStream; -use crate::physical_plan::{DisplayAs, DisplayFormatType, Statistics}; - -use super::FileFormat; -use super::FileScanConfig; -use crate::datasource::file_format::file_compression_type::FileCompressionType; -use crate::datasource::file_format::write::{ - create_writer, stateless_serialize_and_write_files, BatchSerializer, FileWriterMode, -}; -use crate::datasource::file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD; -use crate::datasource::physical_plan::FileSinkConfig; -use crate::datasource::physical_plan::NdJsonExec; -use crate::error::Result; -use crate::execution::context::SessionState; -use crate::physical_plan::ExecutionPlan; - /// New line delimited JSON `FileFormat` implementation. #[derive(Debug)] pub struct JsonFormat { @@ -152,10 +142,10 @@ impl FileFormat for JsonFormat { &self, _state: &SessionState, _store: &Arc, - _table_schema: SchemaRef, + table_schema: SchemaRef, _object: &ObjectMeta, ) -> Result { - Ok(Statistics::default()) + Ok(Statistics::new_unknown(&table_schema)) } async fn create_physical_plan( @@ -173,6 +163,7 @@ impl FileFormat for JsonFormat { input: Arc, _state: &SessionState, conf: FileSinkConfig, + order_requirements: Option>, ) -> Result> { if conf.overwrite { return not_impl_err!("Overwrites are not implemented yet for Json"); @@ -182,9 +173,14 @@ impl FileFormat for JsonFormat { return not_impl_err!("Inserting compressed JSON is not implemented yet."); } let sink_schema = conf.output_schema().clone(); - let sink = Arc::new(JsonSink::new(conf, self.file_compression_type)); - - Ok(Arc::new(FileSinkExec::new(input, sink, sink_schema)) as _) + let sink = Arc::new(JsonSink::new(conf)); + + Ok(Arc::new(FileSinkExec::new( + input, + sink, + sink_schema, + order_requirements, + )) as _) } fn file_type(&self) -> FileType { @@ -199,46 +195,34 @@ impl Default for JsonSerializer { } /// Define a struct for serializing Json records to a stream -pub struct JsonSerializer { - // Inner buffer for avoiding reallocation - buffer: Vec, -} +pub struct JsonSerializer {} impl JsonSerializer { /// Constructor for the JsonSerializer object pub fn new() -> Self { - Self { - buffer: Vec::with_capacity(4096), - } + Self {} } } #[async_trait] impl BatchSerializer for JsonSerializer { - async fn serialize(&mut self, batch: RecordBatch) -> Result { - let mut writer = json::LineDelimitedWriter::new(&mut self.buffer); + async fn serialize(&self, batch: RecordBatch, _initial: bool) -> Result { + let mut buffer = Vec::with_capacity(4096); + let mut writer = json::LineDelimitedWriter::new(&mut buffer); writer.write(&batch)?; - //drop(writer); - Ok(Bytes::from(self.buffer.drain(..).collect::>())) - } - - fn duplicate(&mut self) -> Result> { - Ok(Box::new(JsonSerializer::new())) + Ok(Bytes::from(buffer)) } } /// Implements [`DataSink`] for writing to a Json file. -struct JsonSink { +pub struct JsonSink { /// Config options for writing data config: FileSinkConfig, - file_compression_type: FileCompressionType, } impl Debug for JsonSink { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("JsonSink") - .field("file_compression_type", &self.file_compression_type) - .finish() + f.debug_struct("JsonSink").finish() } } @@ -246,11 +230,7 @@ impl DisplayAs for JsonSink { fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!( - f, - "JsonSink(writer_mode={:?}, file_groups=", - self.config.writer_mode - )?; + write!(f, "JsonSink(file_groups=",)?; FileGroupDisplay(&self.config.file_groups).fmt_as(t, f)?; write!(f, ")") } @@ -259,129 +239,72 @@ impl DisplayAs for JsonSink { } impl JsonSink { - fn new(config: FileSinkConfig, file_compression_type: FileCompressionType) -> Self { - Self { - config, - file_compression_type, - } + /// Create from config. + pub fn new(config: FileSinkConfig) -> Self { + Self { config } } -} -#[async_trait] -impl DataSink for JsonSink { - async fn write_all( + /// Retrieve the inner [`FileSinkConfig`]. + pub fn config(&self) -> &FileSinkConfig { + &self.config + } + + async fn multipartput_all( &self, - data: Vec, + data: SendableRecordBatchStream, context: &Arc, ) -> Result { - let num_partitions = data.len(); - - let object_store = context - .runtime_env() - .object_store(&self.config.object_store_url)?; - let writer_options = self.config.file_type_writer_options.try_into_json()?; + let compression = &writer_options.compression; - let compression = FileCompressionType::from(writer_options.compression); + let get_serializer = move || Arc::new(JsonSerializer::new()) as _; - // Construct serializer and writer for each file group - let mut serializers: Vec> = vec![]; - let mut writers = vec![]; - match self.config.writer_mode { - FileWriterMode::Append => { - if self.config.single_file_output { - return Err(DataFusionError::NotImplemented("single_file_output=true is not implemented for JsonSink in Append mode".into())); - } - for file_group in &self.config.file_groups { - let serializer = JsonSerializer::new(); - serializers.push(Box::new(serializer)); - - let file = file_group.clone(); - let writer = create_writer( - self.config.writer_mode, - compression, - file.object_meta.clone().into(), - object_store.clone(), - ) - .await?; - writers.push(writer); - } - } - FileWriterMode::Put => { - return not_impl_err!("Put Mode is not implemented for Json Sink yet") - } - FileWriterMode::PutMultipart => { - // Currently assuming only 1 partition path (i.e. not hive-style partitioning on a column) - let base_path = &self.config.table_paths[0]; - match self.config.single_file_output { - false => { - // Uniquely identify this batch of files with a random string, to prevent collisions overwriting files - let write_id = - Alphanumeric.sample_string(&mut rand::thread_rng(), 16); - for part_idx in 0..num_partitions { - let serializer = JsonSerializer::new(); - serializers.push(Box::new(serializer)); - let file_path = base_path - .prefix() - .child(format!("{}_{}.json", write_id, part_idx)); - let object_meta = ObjectMeta { - location: file_path, - last_modified: chrono::offset::Utc::now(), - size: 0, - e_tag: None, - }; - let writer = create_writer( - self.config.writer_mode, - compression, - object_meta.into(), - object_store.clone(), - ) - .await?; - writers.push(writer); - } - } - true => { - let serializer = JsonSerializer::new(); - serializers.push(Box::new(serializer)); - let file_path = base_path.prefix(); - let object_meta = ObjectMeta { - location: file_path.clone(), - last_modified: chrono::offset::Utc::now(), - size: 0, - e_tag: None, - }; - let writer = create_writer( - self.config.writer_mode, - compression, - object_meta.into(), - object_store.clone(), - ) - .await?; - writers.push(writer); - } - } - } - } - - stateless_serialize_and_write_files( + stateless_multipart_put( data, - serializers, - writers, - self.config.single_file_output, - self.config.unbounded_input, + context, + "json".into(), + Box::new(get_serializer), + &self.config, + (*compression).into(), ) .await } } +#[async_trait] +impl DataSink for JsonSink { + fn as_any(&self) -> &dyn Any { + self + } + + fn metrics(&self) -> Option { + None + } + + async fn write_all( + &self, + data: SendableRecordBatchStream, + context: &Arc, + ) -> Result { + let total_count = self.multipartput_all(data, context).await?; + Ok(total_count) + } +} + #[cfg(test)] mod tests { use super::super::test_util::scan_format; + use arrow::util::pretty; use datafusion_common::cast::as_int64_array; + use datafusion_common::stats::Precision; + use datafusion_common::{assert_batches_eq, internal_err}; use futures::StreamExt; use object_store::local::LocalFileSystem; + use regex::Regex; + use rstest::rstest; use super::*; + use crate::execution::options::NdJsonReadOptions; use crate::physical_plan::collect; use crate::prelude::{SessionConfig, SessionContext}; use crate::test::object_store::local_unpartitioned_file; @@ -389,7 +312,7 @@ mod tests { #[tokio::test] async fn read_small_batches() -> Result<()> { let config = SessionConfig::new().with_batch_size(2); - let session_ctx = SessionContext::with_config(config); + let session_ctx = SessionContext::new_with_config(config); let state = session_ctx.state(); let task_ctx = state.task_ctx(); let projection = None; @@ -408,8 +331,8 @@ mod tests { assert_eq!(tt_batches, 6 /* 12/2 */); // test metadata - assert_eq!(exec.statistics().num_rows, None); - assert_eq!(exec.statistics().total_byte_size, None); + assert_eq!(exec.statistics()?.num_rows, Precision::Absent); + assert_eq!(exec.statistics()?.total_byte_size, Precision::Absent); Ok(()) } @@ -505,4 +428,94 @@ mod tests { .collect::>(); assert_eq!(vec!["a: Int64", "b: Float64", "c: Boolean"], fields); } + + async fn count_num_partitions(ctx: &SessionContext, query: &str) -> Result { + let result = ctx + .sql(&format!("EXPLAIN {query}")) + .await? + .collect() + .await?; + + let plan = format!("{}", &pretty::pretty_format_batches(&result)?); + + let re = Regex::new(r"file_groups=\{(\d+) group").unwrap(); + + if let Some(captures) = re.captures(&plan) { + if let Some(match_) = captures.get(1) { + let count = match_.as_str().parse::().unwrap(); + return Ok(count); + } + } + + internal_err!("Query contains no Exec: file_groups") + } + + #[rstest(n_partitions, case(1), case(2), case(3), case(4))] + #[tokio::test] + async fn it_can_read_ndjson_in_parallel(n_partitions: usize) -> Result<()> { + let config = SessionConfig::new() + .with_repartition_file_scans(true) + .with_repartition_file_min_size(0) + .with_target_partitions(n_partitions); + + let ctx = SessionContext::new_with_config(config); + + let table_path = "tests/data/1.json"; + let options = NdJsonReadOptions::default(); + + ctx.register_json("json_parallel", table_path, options) + .await?; + + let query = "SELECT SUM(a) FROM json_parallel;"; + + let result = ctx.sql(query).await?.collect().await?; + let actual_partitions = count_num_partitions(&ctx, query).await?; + + #[rustfmt::skip] + let expected = [ + "+----------------------+", + "| SUM(json_parallel.a) |", + "+----------------------+", + "| -7 |", + "+----------------------+" + ]; + + assert_batches_eq!(expected, &result); + assert_eq!(n_partitions, actual_partitions); + + Ok(()) + } + + #[rstest(n_partitions, case(1), case(2), case(3), case(4))] + #[tokio::test] + async fn it_can_read_empty_ndjson_in_parallel(n_partitions: usize) -> Result<()> { + let config = SessionConfig::new() + .with_repartition_file_scans(true) + .with_repartition_file_min_size(0) + .with_target_partitions(n_partitions); + + let ctx = SessionContext::new_with_config(config); + + let table_path = "tests/data/empty.json"; + let options = NdJsonReadOptions::default(); + + ctx.register_json("json_parallel_empty", table_path, options) + .await?; + + let query = "SELECT * FROM json_parallel_empty WHERE random() > 0.5;"; + + let result = ctx.sql(query).await?.collect().await?; + let actual_partitions = count_num_partitions(&ctx, query).await?; + + #[rustfmt::skip] + let expected = [ + "++", + "++", + ]; + + assert_batches_eq!(expected, &result); + assert_eq!(1, actual_partitions); + + Ok(()) + } } diff --git a/datafusion/core/src/datasource/file_format/mod.rs b/datafusion/core/src/datasource/file_format/mod.rs index 86f265ab9492..12c9fb91adb1 100644 --- a/datafusion/core/src/datasource/file_format/mod.rs +++ b/datafusion/core/src/datasource/file_format/mod.rs @@ -27,6 +27,7 @@ pub mod csv; pub mod file_compression_type; pub mod json; pub mod options; +#[cfg(feature = "parquet")] pub mod parquet; pub mod write; @@ -41,7 +42,7 @@ use crate::execution::context::SessionState; use crate::physical_plan::{ExecutionPlan, Statistics}; use datafusion_common::{not_impl_err, DataFusionError, FileType}; -use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; use async_trait::async_trait; use object_store::{ObjectMeta, ObjectStore}; @@ -99,6 +100,7 @@ pub trait FileFormat: Send + Sync + fmt::Debug { _input: Arc, _state: &SessionState, _conf: FileSinkConfig, + _order_requirements: Option>, ) -> Result> { not_impl_err!("Writer not implemented for this format") } @@ -122,7 +124,8 @@ pub(crate) mod test_util { use object_store::local::LocalFileSystem; use object_store::path::Path; use object_store::{ - GetOptions, GetResult, GetResultPayload, ListResult, MultipartId, + GetOptions, GetResult, GetResultPayload, ListResult, MultipartId, PutOptions, + PutResult, }; use tokio::io::AsyncWrite; @@ -162,7 +165,6 @@ pub(crate) mod test_util { limit, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, None, ) @@ -187,7 +189,12 @@ pub(crate) mod test_util { #[async_trait] impl ObjectStore for VariableStream { - async fn put(&self, _location: &Path, _bytes: Bytes) -> object_store::Result<()> { + async fn put_opts( + &self, + _location: &Path, + _bytes: Bytes, + _opts: PutOptions, + ) -> object_store::Result { unimplemented!() } @@ -226,6 +233,7 @@ pub(crate) mod test_util { last_modified: Default::default(), size: range.end, e_tag: None, + version: None, }, range: Default::default(), }) @@ -255,11 +263,10 @@ pub(crate) mod test_util { unimplemented!() } - async fn list( + fn list( &self, _prefix: Option<&Path>, - ) -> object_store::Result>> - { + ) -> BoxStream<'_, object_store::Result> { unimplemented!() } diff --git a/datafusion/core/src/datasource/file_format/options.rs b/datafusion/core/src/datasource/file_format/options.rs index 40d9878a0134..d389137785ff 100644 --- a/datafusion/core/src/datasource/file_format/options.rs +++ b/datafusion/core/src/datasource/file_format/options.rs @@ -21,16 +21,15 @@ use std::sync::Arc; use arrow::datatypes::{DataType, Schema, SchemaRef}; use async_trait::async_trait; -use datafusion_common::{plan_err, DataFusionError}; use crate::datasource::file_format::arrow::ArrowFormat; use crate::datasource::file_format::file_compression_type::FileCompressionType; +#[cfg(feature = "parquet")] +use crate::datasource::file_format::parquet::ParquetFormat; use crate::datasource::file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD; -use crate::datasource::listing::{ListingTableInsertMode, ListingTableUrl}; +use crate::datasource::listing::ListingTableUrl; use crate::datasource::{ - file_format::{ - avro::AvroFormat, csv::CsvFormat, json::JsonFormat, parquet::ParquetFormat, - }, + file_format::{avro::AvroFormat, csv::CsvFormat, json::JsonFormat}, listing::ListingOptions, }; use crate::error::Result; @@ -72,12 +71,8 @@ pub struct CsvReadOptions<'a> { pub table_partition_cols: Vec<(String, DataType)>, /// File compression type pub file_compression_type: FileCompressionType, - /// Flag indicating whether this file may be unbounded (as in a FIFO file). - pub infinite: bool, /// Indicates how the file is sorted pub file_sort_order: Vec>, - /// Setting controls how inserts to this file should be handled - pub insert_mode: ListingTableInsertMode, } impl<'a> Default for CsvReadOptions<'a> { @@ -99,9 +94,7 @@ impl<'a> CsvReadOptions<'a> { file_extension: DEFAULT_CSV_EXTENSION, table_partition_cols: vec![], file_compression_type: FileCompressionType::UNCOMPRESSED, - infinite: false, file_sort_order: vec![], - insert_mode: ListingTableInsertMode::AppendToFile, } } @@ -111,12 +104,6 @@ impl<'a> CsvReadOptions<'a> { self } - /// Configure mark_infinite setting - pub fn mark_infinite(mut self, infinite: bool) -> Self { - self.infinite = infinite; - self - } - /// Specify delimiter to use for CSV read pub fn delimiter(mut self, delimiter: u8) -> Self { self.delimiter = delimiter; @@ -184,12 +171,6 @@ impl<'a> CsvReadOptions<'a> { self.file_sort_order = file_sort_order; self } - - /// Configure how insertions to this table should be handled - pub fn insert_mode(mut self, insert_mode: ListingTableInsertMode) -> Self { - self.insert_mode = insert_mode; - self - } } /// Options that control the reading of Parquet files. @@ -219,8 +200,6 @@ pub struct ParquetReadOptions<'a> { pub schema: Option<&'a Schema>, /// Indicates how the file is sorted pub file_sort_order: Vec>, - /// Setting controls how inserts to this file should be handled - pub insert_mode: ListingTableInsertMode, } impl<'a> Default for ParquetReadOptions<'a> { @@ -232,7 +211,6 @@ impl<'a> Default for ParquetReadOptions<'a> { skip_metadata: None, schema: None, file_sort_order: vec![], - insert_mode: ListingTableInsertMode::AppendNewFiles, } } } @@ -272,12 +250,6 @@ impl<'a> ParquetReadOptions<'a> { self.file_sort_order = file_sort_order; self } - - /// Configure how insertions to this table should be handled - pub fn insert_mode(mut self, insert_mode: ListingTableInsertMode) -> Self { - self.insert_mode = insert_mode; - self - } } /// Options that control the reading of ARROW files. @@ -342,8 +314,6 @@ pub struct AvroReadOptions<'a> { pub file_extension: &'a str, /// Partition Columns pub table_partition_cols: Vec<(String, DataType)>, - /// Flag indicating whether this file may be unbounded (as in a FIFO file). - pub infinite: bool, } impl<'a> Default for AvroReadOptions<'a> { @@ -352,7 +322,6 @@ impl<'a> Default for AvroReadOptions<'a> { schema: None, file_extension: DEFAULT_AVRO_EXTENSION, table_partition_cols: vec![], - infinite: false, } } } @@ -367,12 +336,6 @@ impl<'a> AvroReadOptions<'a> { self } - /// Configure mark_infinite setting - pub fn mark_infinite(mut self, infinite: bool) -> Self { - self.infinite = infinite; - self - } - /// Specify schema to use for AVRO read pub fn schema(mut self, schema: &'a Schema) -> Self { self.schema = Some(schema); @@ -403,8 +366,6 @@ pub struct NdJsonReadOptions<'a> { pub infinite: bool, /// Indicates how the file is sorted pub file_sort_order: Vec>, - /// Setting controls how inserts to this file should be handled - pub insert_mode: ListingTableInsertMode, } impl<'a> Default for NdJsonReadOptions<'a> { @@ -417,7 +378,6 @@ impl<'a> Default for NdJsonReadOptions<'a> { file_compression_type: FileCompressionType::UNCOMPRESSED, infinite: false, file_sort_order: vec![], - insert_mode: ListingTableInsertMode::AppendToFile, } } } @@ -464,12 +424,6 @@ impl<'a> NdJsonReadOptions<'a> { self.file_sort_order = file_sort_order; self } - - /// Configure how insertions to this table should be handled - pub fn insert_mode(mut self, insert_mode: ListingTableInsertMode) -> Self { - self.insert_mode = insert_mode; - self - } } #[async_trait] @@ -493,21 +447,17 @@ pub trait ReadOptions<'a> { state: SessionState, table_path: ListingTableUrl, schema: Option<&'a Schema>, - infinite: bool, ) -> Result where 'a: 'async_trait, { - match (schema, infinite) { - (Some(s), _) => Ok(Arc::new(s.to_owned())), - (None, false) => Ok(self - .to_listing_options(config) - .infer_schema(&state, &table_path) - .await?), - (None, true) => { - plan_err!("Schema inference for infinite data sources is not supported.") - } + if let Some(s) = schema { + return Ok(Arc::new(s.to_owned())); } + + self.to_listing_options(config) + .infer_schema(&state, &table_path) + .await } } @@ -527,8 +477,6 @@ impl ReadOptions<'_> for CsvReadOptions<'_> { .with_target_partitions(config.target_partitions()) .with_table_partition_cols(self.table_partition_cols.clone()) .with_file_sort_order(self.file_sort_order.clone()) - .with_infinite_source(self.infinite) - .with_insert_mode(self.insert_mode.clone()) } async fn get_resolved_schema( @@ -537,11 +485,12 @@ impl ReadOptions<'_> for CsvReadOptions<'_> { state: SessionState, table_path: ListingTableUrl, ) -> Result { - self._get_resolved_schema(config, state, table_path, self.schema, self.infinite) + self._get_resolved_schema(config, state, table_path, self.schema) .await } } +#[cfg(feature = "parquet")] #[async_trait] impl ReadOptions<'_> for ParquetReadOptions<'_> { fn to_listing_options(&self, config: &SessionConfig) -> ListingOptions { @@ -554,7 +503,6 @@ impl ReadOptions<'_> for ParquetReadOptions<'_> { .with_target_partitions(config.target_partitions()) .with_table_partition_cols(self.table_partition_cols.clone()) .with_file_sort_order(self.file_sort_order.clone()) - .with_insert_mode(self.insert_mode.clone()) } async fn get_resolved_schema( @@ -563,7 +511,7 @@ impl ReadOptions<'_> for ParquetReadOptions<'_> { state: SessionState, table_path: ListingTableUrl, ) -> Result { - self._get_resolved_schema(config, state, table_path, self.schema, false) + self._get_resolved_schema(config, state, table_path, self.schema) .await } } @@ -579,9 +527,7 @@ impl ReadOptions<'_> for NdJsonReadOptions<'_> { .with_file_extension(self.file_extension) .with_target_partitions(config.target_partitions()) .with_table_partition_cols(self.table_partition_cols.clone()) - .with_infinite_source(self.infinite) .with_file_sort_order(self.file_sort_order.clone()) - .with_insert_mode(self.insert_mode.clone()) } async fn get_resolved_schema( @@ -590,7 +536,7 @@ impl ReadOptions<'_> for NdJsonReadOptions<'_> { state: SessionState, table_path: ListingTableUrl, ) -> Result { - self._get_resolved_schema(config, state, table_path, self.schema, self.infinite) + self._get_resolved_schema(config, state, table_path, self.schema) .await } } @@ -604,7 +550,6 @@ impl ReadOptions<'_> for AvroReadOptions<'_> { .with_file_extension(self.file_extension) .with_target_partitions(config.target_partitions()) .with_table_partition_cols(self.table_partition_cols.clone()) - .with_infinite_source(self.infinite) } async fn get_resolved_schema( @@ -613,7 +558,7 @@ impl ReadOptions<'_> for AvroReadOptions<'_> { state: SessionState, table_path: ListingTableUrl, ) -> Result { - self._get_resolved_schema(config, state, table_path, self.schema, self.infinite) + self._get_resolved_schema(config, state, table_path, self.schema) .await } } @@ -635,7 +580,7 @@ impl ReadOptions<'_> for ArrowReadOptions<'_> { state: SessionState, table_path: ListingTableUrl, ) -> Result { - self._get_resolved_schema(config, state, table_path, self.schema, false) + self._get_resolved_schema(config, state, table_path, self.schema) .await } } diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index ebdf3ea444b1..9729bfa163af 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -15,41 +15,48 @@ // specific language governing permissions and limitations // under the License. -//! Parquet format abstractions +//! [`ParquetFormat`]: Parquet [`FileFormat`] abstractions -use parquet::column::writer::ColumnCloseResult; +use arrow_array::RecordBatch; +use async_trait::async_trait; +use datafusion_common::stats::Precision; +use datafusion_physical_plan::metrics::MetricsSet; +use parquet::arrow::arrow_writer::{ + compute_leaves, get_column_writers, ArrowColumnChunk, ArrowColumnWriter, + ArrowLeafColumn, +}; use parquet::file::writer::SerializedFileWriter; -use rand::distributions::DistString; use std::any::Any; use std::fmt; use std::fmt::Debug; -use std::io::Write; use std::sync::Arc; use tokio::io::{AsyncWrite, AsyncWriteExt}; -use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender}; +use tokio::sync::mpsc::{self, Receiver, Sender}; use tokio::task::{JoinHandle, JoinSet}; use crate::datasource::file_format::file_compression_type::FileCompressionType; +use crate::datasource::statistics::{create_max_min_accs, get_col_stats}; use arrow::datatypes::SchemaRef; use arrow::datatypes::{Fields, Schema}; -use async_trait::async_trait; use bytes::{BufMut, BytesMut}; -use datafusion_common::{exec_err, not_impl_err, plan_err, DataFusionError, FileType}; +use datafusion_common::{exec_err, not_impl_err, DataFusionError, FileType}; use datafusion_execution::TaskContext; -use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; use futures::{StreamExt, TryStreamExt}; use hashbrown::HashMap; +use object_store::path::Path; use object_store::{ObjectMeta, ObjectStore}; -use parquet::arrow::{parquet_to_arrow_schema, AsyncArrowWriter}; +use parquet::arrow::{ + arrow_to_parquet_schema, parquet_to_arrow_schema, AsyncArrowWriter, +}; use parquet::file::footer::{decode_footer, decode_metadata}; use parquet::file::metadata::ParquetMetaData; use parquet::file::properties::WriterProperties; use parquet::file::statistics::Statistics as ParquetStatistics; -use rand::distributions::Alphanumeric; -use super::write::{create_writer, AbortableWrite, FileWriterMode}; -use super::FileFormat; -use super::FileScanConfig; +use super::write::demux::start_demuxer_task; +use super::write::{create_writer, AbortableWrite, SharedBuffer}; +use super::{FileFormat, FileScanConfig}; use crate::arrow::array::{ BooleanArray, Float32Array, Float64Array, Int32Array, Int64Array, }; @@ -57,10 +64,8 @@ use crate::arrow::datatypes::DataType; use crate::config::ConfigOptions; use crate::datasource::physical_plan::{ - FileGroupDisplay, FileMeta, FileSinkConfig, ParquetExec, SchemaAdapter, + FileGroupDisplay, FileSinkConfig, ParquetExec, SchemaAdapter, }; - -use crate::datasource::{create_max_min_accs, get_col_stats}; use crate::error::Result; use crate::execution::context::SessionState; use crate::physical_plan::expressions::{MaxAccumulator, MinAccumulator}; @@ -70,6 +75,17 @@ use crate::physical_plan::{ Statistics, }; +/// Size of the buffer for [`AsyncArrowWriter`]. +const PARQUET_WRITER_BUFFER_SIZE: usize = 10485760; + +/// Initial writing buffer size. Note this is just a size hint for efficiency. It +/// will grow beyond the set value if needed. +const INITIAL_BUFFER_BYTES: usize = 1048576; + +/// When writing parquet files in parallel, if the buffered Parquet data exceeds +/// this size, it is flushed to object store +const BUFFER_FLUSH_BYTES: usize = 1024000; + /// The Apache Parquet `FileFormat` implementation /// /// Note it is recommended these are instead configured on the [`ConfigOptions`] @@ -158,6 +174,16 @@ fn clear_metadata( }) } +async fn fetch_schema_with_location( + store: &dyn ObjectStore, + file: &ObjectMeta, + metadata_size_hint: Option, +) -> Result<(Path, Schema)> { + let loc_path = file.location.clone(); + let schema = fetch_schema(store, file, metadata_size_hint).await?; + Ok((loc_path, schema)) +} + #[async_trait] impl FileFormat for ParquetFormat { fn as_any(&self) -> &dyn Any { @@ -170,13 +196,32 @@ impl FileFormat for ParquetFormat { store: &Arc, objects: &[ObjectMeta], ) -> Result { - let schemas: Vec<_> = futures::stream::iter(objects) - .map(|object| fetch_schema(store.as_ref(), object, self.metadata_size_hint)) + let mut schemas: Vec<_> = futures::stream::iter(objects) + .map(|object| { + fetch_schema_with_location( + store.as_ref(), + object, + self.metadata_size_hint, + ) + }) .boxed() // Workaround https://github.com/rust-lang/rust/issues/64552 .buffered(state.config_options().execution.meta_fetch_concurrency) .try_collect() .await?; + // Schema inference adds fields based the order they are seen + // which depends on the order the files are processed. For some + // object stores (like local file systems) the order returned from list + // is not deterministic. Thus, to ensure deterministic schema inference + // sort the files first. + // https://github.com/apache/arrow-datafusion/pull/6629 + schemas.sort_by(|(location1, _), (location2, _)| location1.cmp(location2)); + + let schemas = schemas + .into_iter() + .map(|(_, schema)| schema) + .collect::>(); + let schema = if self.skip_metadata(state.config_options()) { Schema::try_merge(clear_metadata(schemas)) } else { @@ -229,6 +274,7 @@ impl FileFormat for ParquetFormat { input: Arc, _state: &SessionState, conf: FileSinkConfig, + order_requirements: Option>, ) -> Result> { if conf.overwrite { return not_impl_err!("Overwrites are not implemented yet for Parquet"); @@ -237,7 +283,12 @@ impl FileFormat for ParquetFormat { let sink_schema = conf.output_schema().clone(); let sink = Arc::new(ParquetSink::new(conf)); - Ok(Arc::new(FileSinkExec::new(input, sink, sink_schema)) as _) + Ok(Arc::new(FileSinkExec::new( + input, + sink, + sink_schema, + order_requirements, + )) as _) } fn file_type(&self) -> FileType { @@ -508,7 +559,7 @@ async fn fetch_statistics( let mut num_rows = 0; let mut total_byte_size = 0; - let mut null_counts = vec![0; num_fields]; + let mut null_counts = vec![Precision::Exact(0); num_fields]; let mut has_statistics = false; let schema_adapter = SchemaAdapter::new(table_schema.clone()); @@ -534,7 +585,7 @@ async fn fetch_statistics( schema_adapter.map_column_index(table_idx, &file_schema) { if let Some((null_count, stats)) = column_stats.get(&file_idx) { - *null_cnt += *null_count as usize; + *null_cnt = null_cnt.add(&Precision::Exact(*null_count as usize)); summarize_min_max( &mut max_values, &mut min_values, @@ -548,35 +599,29 @@ async fn fetch_statistics( min_values[table_idx] = None; } } else { - *null_cnt += num_rows as usize; + *null_cnt = null_cnt.add(&Precision::Exact(num_rows as usize)); } } } } let column_stats = if has_statistics { - Some(get_col_stats( - &table_schema, - null_counts, - &mut max_values, - &mut min_values, - )) + get_col_stats(&table_schema, null_counts, &mut max_values, &mut min_values) } else { - None + Statistics::unknown_column(&table_schema) }; let statistics = Statistics { - num_rows: Some(num_rows as usize), - total_byte_size: Some(total_byte_size as usize), + num_rows: Precision::Exact(num_rows as usize), + total_byte_size: Precision::Exact(total_byte_size as usize), column_statistics: column_stats, - is_exact: true, }; Ok(statistics) } /// Implements [`DataSink`] for writing to a parquet file. -struct ParquetSink { +pub struct ParquetSink { /// Config options for writing data config: FileSinkConfig, } @@ -591,11 +636,7 @@ impl DisplayAs for ParquetSink { fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!( - f, - "ParquetSink(writer_mode={:?}, file_groups=", - self.config.writer_mode - )?; + write!(f, "ParquetSink(file_groups=",)?; FileGroupDisplay(&self.config.file_groups).fmt_as(t, f)?; write!(f, ")") } @@ -604,159 +645,79 @@ impl DisplayAs for ParquetSink { } impl ParquetSink { - fn new(config: FileSinkConfig) -> Self { + /// Create from config. + pub fn new(config: FileSinkConfig) -> Self { Self { config } } + /// Retrieve the inner [`FileSinkConfig`]. + pub fn config(&self) -> &FileSinkConfig { + &self.config + } + /// Converts table schema to writer schema, which may differ in the case + /// of hive style partitioning where some columns are removed from the + /// underlying files. + fn get_writer_schema(&self) -> Arc { + if !self.config.table_partition_cols.is_empty() { + let schema = self.config.output_schema(); + let partition_names: Vec<_> = self + .config + .table_partition_cols + .iter() + .map(|(s, _)| s) + .collect(); + Arc::new(Schema::new( + schema + .fields() + .iter() + .filter(|f| !partition_names.contains(&f.name())) + .map(|f| (**f).clone()) + .collect::>(), + )) + } else { + self.config.output_schema().clone() + } + } + /// Creates an AsyncArrowWriter which serializes a parquet file to an ObjectStore /// AsyncArrowWriters are used when individual parquet file serialization is not parallelized async fn create_async_arrow_writer( &self, - file_meta: FileMeta, + location: &Path, object_store: Arc, parquet_props: WriterProperties, ) -> Result< AsyncArrowWriter>, > { - let object = &file_meta.object_meta; - match self.config.writer_mode { - FileWriterMode::Append => { - plan_err!( - "Appending to Parquet files is not supported by the file format!" - ) - } - FileWriterMode::Put => { - not_impl_err!("FileWriterMode::Put is not implemented for ParquetSink") - } - FileWriterMode::PutMultipart => { - let (_, multipart_writer) = object_store - .put_multipart(&object.location) - .await - .map_err(DataFusionError::ObjectStore)?; - let writer = AsyncArrowWriter::try_new( - multipart_writer, - self.config.output_schema.clone(), - 10485760, - Some(parquet_props), - )?; - Ok(writer) - } - } + let (_, multipart_writer) = object_store + .put_multipart(location) + .await + .map_err(DataFusionError::ObjectStore)?; + let writer = AsyncArrowWriter::try_new( + multipart_writer, + self.get_writer_schema(), + PARQUET_WRITER_BUFFER_SIZE, + Some(parquet_props), + )?; + Ok(writer) } +} - /// Creates an AsyncArrowWriter for each partition to be written out - /// AsyncArrowWriters are used when individual parquet file serialization is not parallelized - async fn create_all_async_arrow_writers( - &self, - num_partitions: usize, - parquet_props: &WriterProperties, - object_store: Arc, - ) -> Result< - Vec>>, - > { - // Construct writer for each file group - let mut writers = vec![]; - match self.config.writer_mode { - FileWriterMode::Append => { - return plan_err!( - "Parquet format does not support appending to existing file!" - ) - } - FileWriterMode::Put => { - return not_impl_err!("Put Mode is not implemented for ParquetSink yet") - } - FileWriterMode::PutMultipart => { - // Currently assuming only 1 partition path (i.e. not hive-style partitioning on a column) - let base_path = &self.config.table_paths[0]; - match self.config.single_file_output { - false => { - // Uniquely identify this batch of files with a random string, to prevent collisions overwriting files - let write_id = - Alphanumeric.sample_string(&mut rand::thread_rng(), 16); - for part_idx in 0..num_partitions { - let file_path = base_path - .prefix() - .child(format!("{}_{}.parquet", write_id, part_idx)); - let object_meta = ObjectMeta { - location: file_path, - last_modified: chrono::offset::Utc::now(), - size: 0, - e_tag: None, - }; - let writer = self - .create_async_arrow_writer( - object_meta.into(), - object_store.clone(), - parquet_props.clone(), - ) - .await?; - writers.push(writer); - } - } - true => { - let file_path = base_path.prefix(); - let object_meta = ObjectMeta { - location: file_path.clone(), - last_modified: chrono::offset::Utc::now(), - size: 0, - e_tag: None, - }; - let writer = self - .create_async_arrow_writer( - object_meta.into(), - object_store.clone(), - parquet_props.clone(), - ) - .await?; - writers.push(writer); - } - } - } - } - - Ok(writers) +#[async_trait] +impl DataSink for ParquetSink { + fn as_any(&self) -> &dyn Any { + self } - /// Creates an object store writer for each output partition - /// This is used when parallelizing individual parquet file writes. - async fn create_object_store_writers( - &self, - num_partitions: usize, - object_store: Arc, - ) -> Result>>> { - let mut writers = Vec::new(); - - for _ in 0..num_partitions { - let file_path = self.config.table_paths[0].prefix(); - let object_meta = ObjectMeta { - location: file_path.clone(), - last_modified: chrono::offset::Utc::now(), - size: 0, - e_tag: None, - }; - writers.push( - create_writer( - FileWriterMode::PutMultipart, - FileCompressionType::UNCOMPRESSED, - object_meta.into(), - object_store.clone(), - ) - .await?, - ); - } - - Ok(writers) + fn metrics(&self) -> Option { + None } -} -#[async_trait] -impl DataSink for ParquetSink { async fn write_all( &self, - mut data: Vec, + data: SendableRecordBatchStream, context: &Arc, ) -> Result { - let num_partitions = data.len(); let parquet_props = self .config .file_type_writer_options @@ -767,241 +728,328 @@ impl DataSink for ParquetSink { .runtime_env() .object_store(&self.config.object_store_url)?; - let mut row_count = 0; + let parquet_opts = &context.session_config().options().execution.parquet; + let allow_single_file_parallelism = parquet_opts.allow_single_file_parallelism; - let allow_single_file_parallelism = context - .session_config() - .options() - .execution - .parquet - .allow_single_file_parallelism; - - match self.config.single_file_output { - false => { - let writers = self - .create_all_async_arrow_writers( - num_partitions, - parquet_props, + let part_col = if !self.config.table_partition_cols.is_empty() { + Some(self.config.table_partition_cols.clone()) + } else { + None + }; + + let parallel_options = ParallelParquetWriterOptions { + max_parallel_row_groups: parquet_opts.maximum_parallel_row_group_writers, + max_buffered_record_batches_per_stream: parquet_opts + .maximum_buffered_record_batches_per_stream, + }; + + let (demux_task, mut file_stream_rx) = start_demuxer_task( + data, + context, + part_col, + self.config.table_paths[0].clone(), + "parquet".into(), + self.config.single_file_output, + ); + + let mut file_write_tasks: JoinSet> = + JoinSet::new(); + while let Some((path, mut rx)) = file_stream_rx.recv().await { + if !allow_single_file_parallelism { + let mut writer = self + .create_async_arrow_writer( + &path, object_store.clone(), + parquet_props.clone(), ) .await?; - // TODO parallelize individual parquet serialization when already outputting multiple parquet files - // e.g. if outputting 2 parquet files on a system with 32 threads, spawn 16 tasks for each individual - // file to be serialized. - row_count = output_multiple_parquet_files(writers, data).await?; - } - true => { - if !allow_single_file_parallelism || data.len() <= 1 { - let mut writer = self - .create_all_async_arrow_writers( - num_partitions, - parquet_props, - object_store.clone(), - ) - .await? - .remove(0); - for data_stream in data.iter_mut() { - while let Some(batch) = data_stream.next().await.transpose()? { - row_count += batch.num_rows(); - writer.write(&batch).await?; - } + file_write_tasks.spawn(async move { + let mut row_count = 0; + while let Some(batch) = rx.recv().await { + row_count += batch.num_rows(); + writer.write(&batch).await?; } - writer.close().await?; - } else { - let object_store_writer = self - .create_object_store_writers(1, object_store) - .await? - .remove(0); - row_count = output_single_parquet_file_parallelized( - object_store_writer, - data, - self.config.output_schema.clone(), - parquet_props, + Ok(row_count) + }); + } else { + let writer = create_writer( + // Parquet files as a whole are never compressed, since they + // manage compressed blocks themselves. + FileCompressionType::UNCOMPRESSED, + &path, + object_store.clone(), + ) + .await?; + let schema = self.get_writer_schema(); + let props = parquet_props.clone(); + let parallel_options_clone = parallel_options.clone(); + file_write_tasks.spawn(async move { + output_single_parquet_file_parallelized( + writer, + rx, + schema, + &props, + parallel_options_clone, ) - .await?; + .await + }); + } + } + + let mut row_count = 0; + while let Some(result) = file_write_tasks.join_next().await { + match result { + Ok(r) => { + row_count += r?; + } + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } else { + unreachable!(); + } } } } + match demux_task.await { + Ok(r) => r?, + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } else { + unreachable!(); + } + } + } Ok(row_count as u64) } } -/// This is the return type when joining subtasks which are serializing parquet files -/// into memory buffers. The first part of the tuple is the parquet bytes and the -/// second is how many rows were written into the file. -type ParquetFileSerializedResult = Result<(Vec, usize), DataFusionError>; +/// Consumes a stream of [ArrowLeafColumn] via a channel and serializes them using an [ArrowColumnWriter] +/// Once the channel is exhausted, returns the ArrowColumnWriter. +async fn column_serializer_task( + mut rx: Receiver, + mut writer: ArrowColumnWriter, +) -> Result { + while let Some(col) = rx.recv().await { + writer.write(&col)?; + } + Ok(writer) +} -/// Parallelizes the serialization of a single parquet file, by first serializing N -/// independent RecordBatch streams in parallel to parquet files in memory. Another -/// task then stitches these independent files back together and streams this large -/// single parquet file to an ObjectStore in multiple parts. -async fn output_single_parquet_file_parallelized( - mut object_store_writer: AbortableWrite>, - mut data: Vec, - output_schema: Arc, - parquet_props: &WriterProperties, -) -> Result { - let mut row_count = 0; - // TODO decrease parallelism / buffering: - // https://github.com/apache/arrow-datafusion/issues/7591 - let parallelism = data.len(); - let mut join_handles: Vec> = - Vec::with_capacity(parallelism); - for _ in 0..parallelism { - let buffer: Vec = Vec::new(); - let mut writer = parquet::arrow::arrow_writer::ArrowWriter::try_new( - buffer, - output_schema.clone(), - Some(parquet_props.clone()), - )?; - let mut data_stream = data.remove(0); - join_handles.push(tokio::spawn(async move { - let mut inner_row_count = 0; - while let Some(batch) = data_stream.next().await.transpose()? { - inner_row_count += batch.num_rows(); - writer.write(&batch)?; - } - let out = writer.into_inner()?; - Ok((out, inner_row_count)) - })) +type ColumnJoinHandle = JoinHandle>; +type ColSender = Sender; +/// Spawns a parallel serialization task for each column +/// Returns join handles for each columns serialization task along with a send channel +/// to send arrow arrays to each serialization task. +fn spawn_column_parallel_row_group_writer( + schema: Arc, + parquet_props: Arc, + max_buffer_size: usize, +) -> Result<(Vec, Vec)> { + let schema_desc = arrow_to_parquet_schema(&schema)?; + let col_writers = get_column_writers(&schema_desc, &parquet_props, &schema)?; + let num_columns = col_writers.len(); + + let mut col_writer_handles = Vec::with_capacity(num_columns); + let mut col_array_channels = Vec::with_capacity(num_columns); + for writer in col_writers.into_iter() { + // Buffer size of this channel limits the number of arrays queued up for column level serialization + let (send_array, recieve_array) = + mpsc::channel::(max_buffer_size); + col_array_channels.push(send_array); + col_writer_handles + .push(tokio::spawn(column_serializer_task(recieve_array, writer))) } - let mut writer = None; - let endpoints: (UnboundedSender>, UnboundedReceiver>) = - tokio::sync::mpsc::unbounded_channel(); - let (tx, mut rx) = endpoints; - let writer_join_handle: JoinHandle< - Result< - AbortableWrite>, - DataFusionError, - >, - > = tokio::task::spawn(async move { - while let Some(data) = rx.recv().await { - // TODO write incrementally - // https://github.com/apache/arrow-datafusion/issues/7591 - object_store_writer.write_all(data.as_slice()).await?; + Ok((col_writer_handles, col_array_channels)) +} + +/// Settings related to writing parquet files in parallel +#[derive(Clone)] +struct ParallelParquetWriterOptions { + max_parallel_row_groups: usize, + max_buffered_record_batches_per_stream: usize, +} + +/// This is the return type of calling [ArrowColumnWriter].close() on each column +/// i.e. the Vec of encoded columns which can be appended to a row group +type RBStreamSerializeResult = Result<(Vec, usize)>; + +/// Sends the ArrowArrays in passed [RecordBatch] through the channels to their respective +/// parallel column serializers. +async fn send_arrays_to_col_writers( + col_array_channels: &[ColSender], + rb: &RecordBatch, + schema: Arc, +) -> Result<()> { + for (tx, array, field) in col_array_channels + .iter() + .zip(rb.columns()) + .zip(schema.fields()) + .map(|((a, b), c)| (a, b, c)) + { + for c in compute_leaves(field, array)? { + tx.send(c).await.map_err(|_| { + DataFusionError::Internal("Unable to send array to writer!".into()) + })?; } - Ok(object_store_writer) - }); - let merged_buff = SharedBuffer::new(1048576); - for handle in join_handles { - let join_result = handle.await; - match join_result { - Ok(result) => { - let (out, num_rows) = result?; - let reader = bytes::Bytes::from(out); - row_count += num_rows; - //let reader = File::open(buffer)?; - let metadata = parquet::file::footer::parse_metadata(&reader)?; - let schema = metadata.file_metadata().schema(); - writer = match writer { - Some(writer) => Some(writer), - None => Some(SerializedFileWriter::new( - merged_buff.clone(), - Arc::new(schema.clone()), - Arc::new(parquet_props.clone()), - )?), - }; + } - match &mut writer{ - Some(w) => { - // Note: cannot use .await within this loop as RowGroupMetaData is not Send - // Instead, use a non-blocking channel to send bytes to separate worker - // which will write to ObjectStore. - for rg in metadata.row_groups() { - let mut rg_out = w.next_row_group()?; - for column in rg.columns() { - let result = ColumnCloseResult { - bytes_written: column.compressed_size() as _, - rows_written: rg.num_rows() as _, - metadata: column.clone(), - // TODO need to populate the indexes when writing final file - // see https://github.com/apache/arrow-datafusion/issues/7589 - bloom_filter: None, - column_index: None, - offset_index: None, - }; - rg_out.append_column(&reader, result)?; - let mut buff_to_flush = merged_buff.buffer.try_lock().unwrap(); - if buff_to_flush.len() > 1024000{ - let bytes: Vec = buff_to_flush.drain(..).collect(); - tx.send(bytes).map_err(|_| DataFusionError::Execution("Failed to send bytes to ObjectStore writer".into()))?; - - } - } - rg_out.close()?; - let mut buff_to_flush = merged_buff.buffer.try_lock().unwrap(); - if buff_to_flush.len() > 1024000{ - let bytes: Vec = buff_to_flush.drain(..).collect(); - tx.send(bytes).map_err(|_| DataFusionError::Execution("Failed to send bytes to ObjectStore writer".into()))?; - } - } - }, - None => unreachable!("Parquet writer should always be initialized in first iteration of loop!") + Ok(()) +} + +/// Spawns a tokio task which joins the parallel column writer tasks, +/// and finalizes the row group. +fn spawn_rg_join_and_finalize_task( + column_writer_handles: Vec>>, + rg_rows: usize, +) -> JoinHandle { + tokio::spawn(async move { + let num_cols = column_writer_handles.len(); + let mut finalized_rg = Vec::with_capacity(num_cols); + for handle in column_writer_handles.into_iter() { + match handle.await { + Ok(r) => { + let w = r?; + finalized_rg.push(w.close()?); } - } - Err(e) => { - if e.is_panic() { - std::panic::resume_unwind(e.into_panic()); - } else { - unreachable!(); + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()) + } else { + unreachable!() + } } } } - } - let inner_writer = writer.unwrap().into_inner()?; - let final_buff = inner_writer.buffer.try_lock().unwrap(); - // Explicitly drop tx to signal to rx we are done sending data - drop(tx); + Ok((finalized_rg, rg_rows)) + }) +} - let mut object_store_writer = match writer_join_handle.await { - Ok(r) => r?, - Err(e) => { - if e.is_panic() { - std::panic::resume_unwind(e.into_panic()) +/// This task coordinates the serialization of a parquet file in parallel. +/// As the query produces RecordBatches, these are written to a RowGroup +/// via parallel [ArrowColumnWriter] tasks. Once the desired max rows per +/// row group is reached, the parallel tasks are joined on another separate task +/// and sent to a concatenation task. This task immediately continues to work +/// on the next row group in parallel. So, parquet serialization is parallelized +/// accross both columns and row_groups, with a theoretical max number of parallel tasks +/// given by n_columns * num_row_groups. +fn spawn_parquet_parallel_serialization_task( + mut data: Receiver, + serialize_tx: Sender>, + schema: Arc, + writer_props: Arc, + parallel_options: ParallelParquetWriterOptions, +) -> JoinHandle> { + tokio::spawn(async move { + let max_buffer_rb = parallel_options.max_buffered_record_batches_per_stream; + let max_row_group_rows = writer_props.max_row_group_size(); + let (mut column_writer_handles, mut col_array_channels) = + spawn_column_parallel_row_group_writer( + schema.clone(), + writer_props.clone(), + max_buffer_rb, + )?; + let mut current_rg_rows = 0; + + while let Some(rb) = data.recv().await { + if current_rg_rows + rb.num_rows() < max_row_group_rows { + send_arrays_to_col_writers(&col_array_channels, &rb, schema.clone()) + .await?; + current_rg_rows += rb.num_rows(); } else { - unreachable!() + let rows_left = max_row_group_rows - current_rg_rows; + let a = rb.slice(0, rows_left); + send_arrays_to_col_writers(&col_array_channels, &a, schema.clone()) + .await?; + + // Signal the parallel column writers that the RowGroup is done, join and finalize RowGroup + // on a separate task, so that we can immediately start on the next RG before waiting + // for the current one to finish. + drop(col_array_channels); + let finalize_rg_task = spawn_rg_join_and_finalize_task( + column_writer_handles, + max_row_group_rows, + ); + + serialize_tx.send(finalize_rg_task).await.map_err(|_| { + DataFusionError::Internal( + "Unable to send closed RG to concat task!".into(), + ) + })?; + + let b = rb.slice(rows_left, rb.num_rows() - rows_left); + (column_writer_handles, col_array_channels) = + spawn_column_parallel_row_group_writer( + schema.clone(), + writer_props.clone(), + max_buffer_rb, + )?; + send_arrays_to_col_writers(&col_array_channels, &b, schema.clone()) + .await?; + current_rg_rows = b.num_rows(); } } - }; - object_store_writer.write_all(final_buff.as_slice()).await?; - object_store_writer.shutdown().await?; - println!("done!"); - Ok(row_count) + drop(col_array_channels); + // Handle leftover rows as final rowgroup, which may be smaller than max_row_group_rows + if current_rg_rows > 0 { + let finalize_rg_task = + spawn_rg_join_and_finalize_task(column_writer_handles, current_rg_rows); + + serialize_tx.send(finalize_rg_task).await.map_err(|_| { + DataFusionError::Internal( + "Unable to send closed RG to concat task!".into(), + ) + })?; + } + + Ok(()) + }) } -/// Serializes multiple parquet files independently in parallel from different RecordBatch streams. -/// AsyncArrowWriter is used to coordinate serialization and MultiPart puts to ObjectStore -/// Only a single CPU thread is used to serialize each individual parquet file, so write speed and overall -/// CPU utilization is dependent on the number of output files. -async fn output_multiple_parquet_files( - writers: Vec< - AsyncArrowWriter>, - >, - data: Vec, +/// Consume RowGroups serialized by other parallel tasks and concatenate them in +/// to the final parquet file, while flushing finalized bytes to an [ObjectStore] +async fn concatenate_parallel_row_groups( + mut serialize_rx: Receiver>, + schema: Arc, + writer_props: Arc, + mut object_store_writer: AbortableWrite>, ) -> Result { + let merged_buff = SharedBuffer::new(INITIAL_BUFFER_BYTES); + + let schema_desc = arrow_to_parquet_schema(schema.as_ref())?; + let mut parquet_writer = SerializedFileWriter::new( + merged_buff.clone(), + schema_desc.root_schema_ptr(), + writer_props, + )?; + let mut row_count = 0; - let mut join_set: JoinSet> = JoinSet::new(); - for (mut data_stream, mut writer) in data.into_iter().zip(writers.into_iter()) { - join_set.spawn(async move { - let mut cnt = 0; - while let Some(batch) = data_stream.next().await.transpose()? { - cnt += batch.num_rows(); - writer.write(&batch).await?; + + while let Some(handle) = serialize_rx.recv().await { + let join_result = handle.await; + match join_result { + Ok(result) => { + let mut rg_out = parquet_writer.next_row_group()?; + let (serialized_columns, cnt) = result?; + row_count += cnt; + for chunk in serialized_columns { + chunk.append_to_row_group(&mut rg_out)?; + let mut buff_to_flush = merged_buff.buffer.try_lock().unwrap(); + if buff_to_flush.len() > BUFFER_FLUSH_BYTES { + object_store_writer + .write_all(buff_to_flush.as_slice()) + .await?; + buff_to_flush.clear(); + } + } + rg_out.close()?; } - writer.close().await?; - Ok(cnt) - }); - } - while let Some(result) = join_set.join_next().await { - match result { - Ok(res) => { - row_count += res?; - } // propagate DataFusion error Err(e) => { if e.is_panic() { std::panic::resume_unwind(e.into_panic()); @@ -1012,38 +1060,60 @@ async fn output_multiple_parquet_files( } } - Ok(row_count) -} + let inner_writer = parquet_writer.into_inner()?; + let final_buff = inner_writer.buffer.try_lock().unwrap(); -/// A buffer with interior mutability shared by the SerializedFileWriter and -/// ObjectStore writer -#[derive(Clone)] -struct SharedBuffer { - /// The inner buffer for reading and writing - /// - /// The lock is used to obtain internal mutability, so no worry about the - /// lock contention. - buffer: Arc>>, -} + object_store_writer.write_all(final_buff.as_slice()).await?; + object_store_writer.shutdown().await?; -impl SharedBuffer { - pub fn new(capacity: usize) -> Self { - Self { - buffer: Arc::new(futures::lock::Mutex::new(Vec::with_capacity(capacity))), - } - } + Ok(row_count) } -impl Write for SharedBuffer { - fn write(&mut self, buf: &[u8]) -> std::io::Result { - let mut buffer = self.buffer.try_lock().unwrap(); - Write::write(&mut *buffer, buf) - } +/// Parallelizes the serialization of a single parquet file, by first serializing N +/// independent RecordBatch streams in parallel to RowGroups in memory. Another +/// task then stitches these independent RowGroups together and streams this large +/// single parquet file to an ObjectStore in multiple parts. +async fn output_single_parquet_file_parallelized( + object_store_writer: AbortableWrite>, + data: Receiver, + output_schema: Arc, + parquet_props: &WriterProperties, + parallel_options: ParallelParquetWriterOptions, +) -> Result { + let max_rowgroups = parallel_options.max_parallel_row_groups; + // Buffer size of this channel limits maximum number of RowGroups being worked on in parallel + let (serialize_tx, serialize_rx) = + mpsc::channel::>(max_rowgroups); + + let arc_props = Arc::new(parquet_props.clone()); + let launch_serialization_task = spawn_parquet_parallel_serialization_task( + data, + serialize_tx, + output_schema.clone(), + arc_props.clone(), + parallel_options, + ); + let row_count = concatenate_parallel_row_groups( + serialize_rx, + output_schema.clone(), + arc_props.clone(), + object_store_writer, + ) + .await?; + + match launch_serialization_task.await { + Ok(Ok(_)) => (), + Ok(Err(e)) => return Err(e), + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()) + } else { + unreachable!() + } + } + }; - fn flush(&mut self) -> std::io::Result<()> { - let mut buffer = self.buffer.try_lock().unwrap(); - Write::flush(&mut *buffer) - } + Ok(row_count) } #[cfg(test)] @@ -1067,12 +1137,21 @@ pub(crate) mod test_util { batches: Vec, multi_page: bool, ) -> Result<(Vec, Vec)> { + // we need the tmp files to be sorted as some tests rely on the how the returning files are ordered + // https://github.com/apache/arrow-datafusion/pull/6629 + let tmp_files = { + let mut tmp_files: Vec<_> = (0..batches.len()) + .map(|_| NamedTempFile::new().expect("creating temp file")) + .collect(); + tmp_files.sort_by(|a, b| a.path().cmp(b.path())); + tmp_files + }; + // Each batch writes to their own file let files: Vec<_> = batches .into_iter() - .map(|batch| { - let mut output = NamedTempFile::new().expect("creating temp file"); - + .zip(tmp_files.into_iter()) + .map(|(batch, mut output)| { let builder = WriterProperties::builder(); let props = if multi_page { builder.set_data_page_row_count_limit(ROWS_PER_PAGE) @@ -1098,10 +1177,11 @@ pub(crate) mod test_util { .collect(); let meta: Vec<_> = files.iter().map(local_unpartitioned_file).collect(); + Ok((meta, files)) } - //// write batches chunk_size rows at a time + /// write batches chunk_size rows at a time fn write_in_chunks( writer: &mut ArrowWriter, batch: &RecordBatch, @@ -1142,7 +1222,9 @@ mod tests { use log::error; use object_store::local::LocalFileSystem; use object_store::path::Path; - use object_store::{GetOptions, GetResult, ListResult, MultipartId}; + use object_store::{ + GetOptions, GetResult, ListResult, MultipartId, PutOptions, PutResult, + }; use parquet::arrow::arrow_reader::ArrowReaderOptions; use parquet::arrow::ParquetRecordBatchStreamBuilder; use parquet::file::metadata::{ParquetColumnIndex, ParquetOffsetIndex}; @@ -1171,20 +1253,62 @@ mod tests { let stats = fetch_statistics(store.as_ref(), schema.clone(), &meta[0], None).await?; - assert_eq!(stats.num_rows, Some(3)); - let c1_stats = &stats.column_statistics.as_ref().expect("missing c1 stats")[0]; - let c2_stats = &stats.column_statistics.as_ref().expect("missing c2 stats")[1]; - assert_eq!(c1_stats.null_count, Some(1)); - assert_eq!(c2_stats.null_count, Some(3)); + assert_eq!(stats.num_rows, Precision::Exact(3)); + let c1_stats = &stats.column_statistics[0]; + let c2_stats = &stats.column_statistics[1]; + assert_eq!(c1_stats.null_count, Precision::Exact(1)); + assert_eq!(c2_stats.null_count, Precision::Exact(3)); let stats = fetch_statistics(store.as_ref(), schema, &meta[1], None).await?; - assert_eq!(stats.num_rows, Some(3)); - let c1_stats = &stats.column_statistics.as_ref().expect("missing c1 stats")[0]; - let c2_stats = &stats.column_statistics.as_ref().expect("missing c2 stats")[1]; - assert_eq!(c1_stats.null_count, Some(3)); - assert_eq!(c2_stats.null_count, Some(1)); - assert_eq!(c2_stats.max_value, Some(ScalarValue::Int64(Some(2)))); - assert_eq!(c2_stats.min_value, Some(ScalarValue::Int64(Some(1)))); + assert_eq!(stats.num_rows, Precision::Exact(3)); + let c1_stats = &stats.column_statistics[0]; + let c2_stats = &stats.column_statistics[1]; + assert_eq!(c1_stats.null_count, Precision::Exact(3)); + assert_eq!(c2_stats.null_count, Precision::Exact(1)); + assert_eq!( + c2_stats.max_value, + Precision::Exact(ScalarValue::Int64(Some(2))) + ); + assert_eq!( + c2_stats.min_value, + Precision::Exact(ScalarValue::Int64(Some(1))) + ); + + Ok(()) + } + + #[tokio::test] + async fn is_schema_stable() -> Result<()> { + let c1: ArrayRef = + Arc::new(StringArray::from(vec![Some("Foo"), None, Some("bar")])); + + let c2: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(2), None])); + + let batch1 = + RecordBatch::try_from_iter(vec![("a", c1.clone()), ("b", c1.clone())]) + .unwrap(); + let batch2 = + RecordBatch::try_from_iter(vec![("c", c2.clone()), ("d", c2.clone())]) + .unwrap(); + + let store = Arc::new(LocalFileSystem::new()) as _; + let (meta, _files) = store_parquet(vec![batch1, batch2], false).await?; + + let session = SessionContext::new(); + let ctx = session.state(); + let format = ParquetFormat::default(); + let schema = format.infer_schema(&ctx, &store, &meta).await.unwrap(); + + let order: Vec<_> = ["a", "b", "c", "d"] + .into_iter() + .map(|i| i.to_string()) + .collect(); + let coll: Vec<_> = schema + .all_fields() + .into_iter() + .map(|i| i.name().to_string()) + .collect(); + assert_eq!(coll, order); Ok(()) } @@ -1220,7 +1344,12 @@ mod tests { #[async_trait] impl ObjectStore for RequestCountingObjectStore { - async fn put(&self, _location: &Path, _bytes: Bytes) -> object_store::Result<()> { + async fn put_opts( + &self, + _location: &Path, + _bytes: Bytes, + _opts: PutOptions, + ) -> object_store::Result { Err(object_store::Error::NotImplemented) } @@ -1257,12 +1386,13 @@ mod tests { Err(object_store::Error::NotImplemented) } - async fn list( + fn list( &self, _prefix: Option<&Path>, - ) -> object_store::Result>> - { - Err(object_store::Error::NotImplemented) + ) -> BoxStream<'_, object_store::Result> { + Box::pin(futures::stream::once(async { + Err(object_store::Error::NotImplemented) + })) } async fn list_with_delimiter( @@ -1320,11 +1450,11 @@ mod tests { fetch_statistics(store.upcast().as_ref(), schema.clone(), &meta[0], Some(9)) .await?; - assert_eq!(stats.num_rows, Some(3)); - let c1_stats = &stats.column_statistics.as_ref().expect("missing c1 stats")[0]; - let c2_stats = &stats.column_statistics.as_ref().expect("missing c2 stats")[1]; - assert_eq!(c1_stats.null_count, Some(1)); - assert_eq!(c2_stats.null_count, Some(3)); + assert_eq!(stats.num_rows, Precision::Exact(3)); + let c1_stats = &stats.column_statistics[0]; + let c2_stats = &stats.column_statistics[1]; + assert_eq!(c1_stats.null_count, Precision::Exact(1)); + assert_eq!(c2_stats.null_count, Precision::Exact(3)); let store = Arc::new(RequestCountingObjectStore::new(Arc::new( LocalFileSystem::new(), @@ -1353,11 +1483,11 @@ mod tests { ) .await?; - assert_eq!(stats.num_rows, Some(3)); - let c1_stats = &stats.column_statistics.as_ref().expect("missing c1 stats")[0]; - let c2_stats = &stats.column_statistics.as_ref().expect("missing c2 stats")[1]; - assert_eq!(c1_stats.null_count, Some(1)); - assert_eq!(c2_stats.null_count, Some(3)); + assert_eq!(stats.num_rows, Precision::Exact(3)); + let c1_stats = &stats.column_statistics[0]; + let c2_stats = &stats.column_statistics[1]; + assert_eq!(c1_stats.null_count, Precision::Exact(1)); + assert_eq!(c2_stats.null_count, Precision::Exact(3)); let store = Arc::new(RequestCountingObjectStore::new(Arc::new( LocalFileSystem::new(), @@ -1378,7 +1508,7 @@ mod tests { #[tokio::test] async fn read_small_batches() -> Result<()> { let config = SessionConfig::new().with_batch_size(2); - let session_ctx = SessionContext::with_config(config); + let session_ctx = SessionContext::new_with_config(config); let state = session_ctx.state(); let task_ctx = state.task_ctx(); let projection = None; @@ -1397,8 +1527,8 @@ mod tests { assert_eq!(tt_batches, 4 /* 8/2 */); // test metadata - assert_eq!(exec.statistics().num_rows, Some(8)); - assert_eq!(exec.statistics().total_byte_size, Some(671)); + assert_eq!(exec.statistics()?.num_rows, Precision::Exact(8)); + assert_eq!(exec.statistics()?.total_byte_size, Precision::Exact(671)); Ok(()) } @@ -1406,7 +1536,7 @@ mod tests { #[tokio::test] async fn capture_bytes_scanned_metric() -> Result<()> { let config = SessionConfig::new().with_batch_size(2); - let session = SessionContext::with_config(config); + let session = SessionContext::new_with_config(config); let ctx = session.state(); // Read the full file @@ -1439,9 +1569,8 @@ mod tests { get_exec(&state, "alltypes_plain.parquet", projection, Some(1)).await?; // note: even if the limit is set, the executor rounds up to the batch size - assert_eq!(exec.statistics().num_rows, Some(8)); - assert_eq!(exec.statistics().total_byte_size, Some(671)); - assert!(exec.statistics().is_exact); + assert_eq!(exec.statistics()?.num_rows, Precision::Exact(8)); + assert_eq!(exec.statistics()?.total_byte_size, Precision::Exact(671)); let batches = collect(exec, task_ctx).await?; assert_eq!(1, batches.len()); assert_eq!(11, batches[0].num_columns()); @@ -1733,8 +1862,8 @@ mod tests { // there is only one row group in one file. assert_eq!(page_index.len(), 1); assert_eq!(offset_index.len(), 1); - let page_index = page_index.get(0).unwrap(); - let offset_index = offset_index.get(0).unwrap(); + let page_index = page_index.first().unwrap(); + let offset_index = offset_index.first().unwrap(); // 13 col in one row group assert_eq!(page_index.len(), 13); diff --git a/datafusion/core/src/datasource/file_format/write.rs b/datafusion/core/src/datasource/file_format/write.rs deleted file mode 100644 index 42d18eef634c..000000000000 --- a/datafusion/core/src/datasource/file_format/write.rs +++ /dev/null @@ -1,533 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Module containing helper methods/traits related to enabling -//! write support for the various file formats - -use std::io::Error; -use std::mem; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; - -use crate::datasource::file_format::file_compression_type::FileCompressionType; -use crate::datasource::physical_plan::FileMeta; -use crate::error::Result; -use crate::physical_plan::SendableRecordBatchStream; - -use arrow_array::RecordBatch; -use datafusion_common::{exec_err, internal_err, DataFusionError}; - -use async_trait::async_trait; -use bytes::Bytes; -use datafusion_execution::RecordBatchStream; -use futures::future::BoxFuture; -use futures::FutureExt; -use futures::{ready, StreamExt}; -use object_store::path::Path; -use object_store::{MultipartId, ObjectMeta, ObjectStore}; -use tokio::io::{AsyncWrite, AsyncWriteExt}; -use tokio::sync::mpsc; -use tokio::task::{JoinHandle, JoinSet}; - -/// `AsyncPutWriter` is an object that facilitates asynchronous writing to object stores. -/// It is specifically designed for the `object_store` crate's `put` method and sends -/// whole bytes at once when the buffer is flushed. -pub struct AsyncPutWriter { - /// Object metadata - object_meta: ObjectMeta, - /// A shared reference to the object store - store: Arc, - /// A buffer that stores the bytes to be sent - current_buffer: Vec, - /// Used for async handling in flush method - inner_state: AsyncPutState, -} - -impl AsyncPutWriter { - /// Constructor for the `AsyncPutWriter` object - pub fn new(object_meta: ObjectMeta, store: Arc) -> Self { - Self { - object_meta, - store, - current_buffer: vec![], - // The writer starts out in buffering mode - inner_state: AsyncPutState::Buffer, - } - } - - /// Separate implementation function that unpins the [`AsyncPutWriter`] so - /// that partial borrows work correctly - fn poll_shutdown_inner( - &mut self, - cx: &mut Context<'_>, - ) -> Poll> { - loop { - match &mut self.inner_state { - AsyncPutState::Buffer => { - // Convert the current buffer to bytes and take ownership of it - let bytes = Bytes::from(mem::take(&mut self.current_buffer)); - // Set the inner state to Put variant with the bytes - self.inner_state = AsyncPutState::Put { bytes } - } - AsyncPutState::Put { bytes } => { - // Send the bytes to the object store's put method - return Poll::Ready( - ready!(self - .store - .put(&self.object_meta.location, bytes.clone()) - .poll_unpin(cx)) - .map_err(Error::from), - ); - } - } - } - } -} - -/// An enum that represents the inner state of AsyncPut -enum AsyncPutState { - /// Building Bytes struct in this state - Buffer, - /// Data in the buffer is being sent to the object store - Put { bytes: Bytes }, -} - -impl AsyncWrite for AsyncPutWriter { - // Define the implementation of the AsyncWrite trait for the `AsyncPutWriter` struct - fn poll_write( - mut self: Pin<&mut Self>, - _: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - // Extend the current buffer with the incoming buffer - self.current_buffer.extend_from_slice(buf); - // Return a ready poll with the length of the incoming buffer - Poll::Ready(Ok(buf.len())) - } - - fn poll_flush( - self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll> { - // Return a ready poll with an empty result - Poll::Ready(Ok(())) - } - - fn poll_shutdown( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - // Call the poll_shutdown_inner method to handle the actual sending of data to the object store - self.poll_shutdown_inner(cx) - } -} - -/// Stores data needed during abortion of MultiPart writers -pub(crate) struct MultiPart { - /// A shared reference to the object store - store: Arc, - multipart_id: MultipartId, - location: Path, -} - -impl MultiPart { - /// Create a new `MultiPart` - pub fn new( - store: Arc, - multipart_id: MultipartId, - location: Path, - ) -> Self { - Self { - store, - multipart_id, - location, - } - } -} - -pub(crate) enum AbortMode { - Put, - Append, - MultiPart(MultiPart), -} - -/// A wrapper struct with abort method and writer -pub(crate) struct AbortableWrite { - writer: W, - mode: AbortMode, -} - -impl AbortableWrite { - /// Create a new `AbortableWrite` instance with the given writer, and write mode. - pub(crate) fn new(writer: W, mode: AbortMode) -> Self { - Self { writer, mode } - } - - /// handling of abort for different write modes - pub(crate) fn abort_writer(&self) -> Result>> { - match &self.mode { - AbortMode::Put => Ok(async { Ok(()) }.boxed()), - AbortMode::Append => exec_err!("Cannot abort in append mode"), - AbortMode::MultiPart(MultiPart { - store, - multipart_id, - location, - }) => { - let location = location.clone(); - let multipart_id = multipart_id.clone(); - let store = store.clone(); - Ok(Box::pin(async move { - store - .abort_multipart(&location, &multipart_id) - .await - .map_err(DataFusionError::ObjectStore) - })) - } - } - } -} - -impl AsyncWrite for AbortableWrite { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - Pin::new(&mut self.get_mut().writer).poll_write(cx, buf) - } - - fn poll_flush( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - Pin::new(&mut self.get_mut().writer).poll_flush(cx) - } - - fn poll_shutdown( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - Pin::new(&mut self.get_mut().writer).poll_shutdown(cx) - } -} - -/// An enum that defines different file writer modes. -#[derive(Debug, Clone, Copy)] -pub enum FileWriterMode { - /// Data is appended to an existing file. - Append, - /// Data is written to a new file. - Put, - /// Data is written to a new file in multiple parts. - PutMultipart, -} -/// A trait that defines the methods required for a RecordBatch serializer. -#[async_trait] -pub trait BatchSerializer: Unpin + Send { - /// Asynchronously serializes a `RecordBatch` and returns the serialized bytes. - async fn serialize(&mut self, batch: RecordBatch) -> Result; - /// Duplicates self to support serializing multiple batches in parallel on multiple cores - fn duplicate(&mut self) -> Result> { - Err(DataFusionError::NotImplemented( - "Parallel serialization is not implemented for this file type".into(), - )) - } -} - -/// Returns an [`AbortableWrite`] which writes to the given object store location -/// with the specified compression -pub(crate) async fn create_writer( - writer_mode: FileWriterMode, - file_compression_type: FileCompressionType, - file_meta: FileMeta, - object_store: Arc, -) -> Result>> { - let object = &file_meta.object_meta; - match writer_mode { - // If the mode is append, call the store's append method and return wrapped in - // a boxed trait object. - FileWriterMode::Append => { - let writer = object_store - .append(&object.location) - .await - .map_err(DataFusionError::ObjectStore)?; - let writer = AbortableWrite::new( - file_compression_type.convert_async_writer(writer)?, - AbortMode::Append, - ); - Ok(writer) - } - // If the mode is put, create a new AsyncPut writer and return it wrapped in - // a boxed trait object - FileWriterMode::Put => { - let writer = Box::new(AsyncPutWriter::new(object.clone(), object_store)); - let writer = AbortableWrite::new( - file_compression_type.convert_async_writer(writer)?, - AbortMode::Put, - ); - Ok(writer) - } - // If the mode is put multipart, call the store's put_multipart method and - // return the writer wrapped in a boxed trait object. - FileWriterMode::PutMultipart => { - let (multipart_id, writer) = object_store - .put_multipart(&object.location) - .await - .map_err(DataFusionError::ObjectStore)?; - Ok(AbortableWrite::new( - file_compression_type.convert_async_writer(writer)?, - AbortMode::MultiPart(MultiPart::new( - object_store, - multipart_id, - object.location.clone(), - )), - )) - } - } -} - -type WriterType = AbortableWrite>; -type SerializerType = Box; - -/// Serializes a single data stream in parallel and writes to an ObjectStore -/// concurrently. Data order is preserved. In the event of an error, -/// the ObjectStore writer is returned to the caller in addition to an error, -/// so that the caller may handle aborting failed writes. -async fn serialize_rb_stream_to_object_store( - mut data_stream: Pin>, - mut serializer: Box, - mut writer: AbortableWrite>, - unbounded_input: bool, -) -> std::result::Result<(SerializerType, WriterType, u64), (WriterType, DataFusionError)> -{ - let (tx, mut rx) = - mpsc::channel::>>(100); - - let serialize_task = tokio::spawn(async move { - while let Some(maybe_batch) = data_stream.next().await { - match serializer.duplicate() { - Ok(mut serializer_clone) => { - let handle = tokio::spawn(async move { - let batch = maybe_batch?; - let num_rows = batch.num_rows(); - let bytes = serializer_clone.serialize(batch).await?; - Ok((num_rows, bytes)) - }); - tx.send(handle).await.map_err(|_| { - DataFusionError::Internal( - "Unknown error writing to object store".into(), - ) - })?; - if unbounded_input { - tokio::task::yield_now().await; - } - } - Err(_) => { - return Err(DataFusionError::Internal( - "Unknown error writing to object store".into(), - )) - } - } - } - Ok(serializer) - }); - - let mut row_count = 0; - while let Some(handle) = rx.recv().await { - match handle.await { - Ok(Ok((cnt, bytes))) => { - match writer.write_all(&bytes).await { - Ok(_) => (), - Err(e) => { - return Err(( - writer, - DataFusionError::Execution(format!( - "Error writing to object store: {e}" - )), - )) - } - }; - row_count += cnt; - } - Ok(Err(e)) => { - // Return the writer along with the error - return Err((writer, e)); - } - Err(e) => { - // Handle task panic or cancellation - return Err(( - writer, - DataFusionError::Execution(format!( - "Serialization task panicked or was cancelled: {e}" - )), - )); - } - } - } - - let serializer = match serialize_task.await { - Ok(Ok(serializer)) => serializer, - Ok(Err(e)) => return Err((writer, e)), - Err(_) => { - return Err(( - writer, - DataFusionError::Internal("Unknown error writing to object store".into()), - )) - } - }; - Ok((serializer, writer, row_count as u64)) -} - -/// Contains the common logic for serializing RecordBatches and -/// writing the resulting bytes to an ObjectStore. -/// Serialization is assumed to be stateless, i.e. -/// each RecordBatch can be serialized without any -/// dependency on the RecordBatches before or after. -pub(crate) async fn stateless_serialize_and_write_files( - data: Vec, - mut serializers: Vec, - mut writers: Vec, - single_file_output: bool, - unbounded_input: bool, -) -> Result { - if single_file_output && (serializers.len() != 1 || writers.len() != 1) { - return internal_err!("single_file_output is true, but got more than 1 writer!"); - } - let num_partitions = data.len(); - let num_writers = writers.len(); - if !single_file_output && (num_partitions != num_writers) { - return internal_err!("single_file_ouput is false, but did not get 1 writer for each output partition!"); - } - let mut row_count = 0; - // tracks if any writers encountered an error triggering the need to abort - let mut any_errors = false; - // tracks the specific error triggering abort - let mut triggering_error = None; - // tracks if any errors were encountered in the process of aborting writers. - // if true, we may not have a guarentee that all written data was cleaned up. - let mut any_abort_errors = false; - match single_file_output { - false => { - let mut join_set = JoinSet::new(); - for (data_stream, serializer, writer) in data - .into_iter() - .zip(serializers.into_iter()) - .zip(writers.into_iter()) - .map(|((a, b), c)| (a, b, c)) - { - join_set.spawn(async move { - serialize_rb_stream_to_object_store( - data_stream, - serializer, - writer, - unbounded_input, - ) - .await - }); - } - let mut finished_writers = Vec::with_capacity(num_writers); - while let Some(result) = join_set.join_next().await { - match result { - Ok(res) => match res { - Ok((_, writer, cnt)) => { - finished_writers.push(writer); - row_count += cnt; - } - Err((writer, e)) => { - finished_writers.push(writer); - any_errors = true; - triggering_error = Some(e); - } - }, - Err(e) => { - // Don't panic, instead try to clean up as many writers as possible. - // If we hit this code, ownership of a writer was not joined back to - // this thread, so we cannot clean it up (hence any_abort_errors is true) - any_errors = true; - any_abort_errors = true; - triggering_error = Some(DataFusionError::Internal(format!( - "Unexpected join error while serializing file {e}" - ))); - } - } - } - - // Finalize or abort writers as appropriate - for mut writer in finished_writers.into_iter() { - match any_errors { - true => { - let abort_result = writer.abort_writer(); - if abort_result.is_err() { - any_abort_errors = true; - } - } - false => { - writer.shutdown() - .await - .map_err(|_| DataFusionError::Internal("Error encountered while finalizing writes! Partial results may have been written to ObjectStore!".into()))?; - } - } - } - } - true => { - let mut writer = writers.remove(0); - let mut serializer = serializers.remove(0); - let mut cnt; - for data_stream in data.into_iter() { - (serializer, writer, cnt) = match serialize_rb_stream_to_object_store( - data_stream, - serializer, - writer, - unbounded_input, - ) - .await - { - Ok((s, w, c)) => (s, w, c), - Err((w, e)) => { - any_errors = true; - triggering_error = Some(e); - writer = w; - break; - } - }; - row_count += cnt; - } - match any_errors { - true => { - let abort_result = writer.abort_writer(); - if abort_result.is_err() { - any_abort_errors = true; - } - } - false => writer.shutdown().await?, - } - } - } - - if any_errors { - match any_abort_errors{ - true => return Err(DataFusionError::Internal("Error encountered during writing to ObjectStore and failed to abort all writers. Partial result may have been written.".into())), - false => match triggering_error { - Some(e) => return Err(e), - None => return Err(DataFusionError::Internal("Unknown Error encountered during writing to ObjectStore. All writers succesfully aborted.".into())) - } - } - } - - Ok(row_count) -} diff --git a/datafusion/core/src/datasource/file_format/write/demux.rs b/datafusion/core/src/datasource/file_format/write/demux.rs new file mode 100644 index 000000000000..dbfeb67eaeb9 --- /dev/null +++ b/datafusion/core/src/datasource/file_format/write/demux.rs @@ -0,0 +1,420 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Module containing helper methods/traits related to enabling +//! dividing input stream into multiple output files at execution time + +use std::collections::HashMap; + +use std::sync::Arc; + +use crate::datasource::listing::ListingTableUrl; + +use crate::error::Result; +use crate::physical_plan::SendableRecordBatchStream; + +use arrow_array::builder::UInt64Builder; +use arrow_array::cast::AsArray; +use arrow_array::{downcast_dictionary_array, RecordBatch, StringArray, StructArray}; +use arrow_schema::{DataType, Schema}; +use datafusion_common::cast::as_string_array; +use datafusion_common::DataFusionError; + +use datafusion_execution::TaskContext; + +use futures::StreamExt; +use object_store::path::Path; + +use rand::distributions::DistString; + +use tokio::sync::mpsc::{self, Receiver, Sender, UnboundedReceiver, UnboundedSender}; +use tokio::task::JoinHandle; + +type RecordBatchReceiver = Receiver; +type DemuxedStreamReceiver = UnboundedReceiver<(Path, RecordBatchReceiver)>; + +/// Splits a single [SendableRecordBatchStream] into a dynamically determined +/// number of partitions at execution time. The partitions are determined by +/// factors known only at execution time, such as total number of rows and +/// partition column values. The demuxer task communicates to the caller +/// by sending channels over a channel. The inner channels send RecordBatches +/// which should be contained within the same output file. The outer channel +/// is used to send a dynamic number of inner channels, representing a dynamic +/// number of total output files. The caller is also responsible to monitor +/// the demux task for errors and abort accordingly. The single_file_ouput parameter +/// overrides all other settings to force only a single file to be written. +/// partition_by parameter will additionally split the input based on the unique +/// values of a specific column ``` +/// ┌───────────┐ ┌────────────┐ ┌─────────────┐ +/// ┌──────▶ │ batch 1 ├────▶...──────▶│ Batch a │ │ Output File1│ +/// │ └───────────┘ └────────────┘ └─────────────┘ +/// │ +/// ┌──────────┐ │ ┌───────────┐ ┌────────────┐ ┌─────────────┐ +/// ┌───────────┐ ┌────────────┐ │ │ ├──────▶ │ batch a+1├────▶...──────▶│ Batch b │ │ Output File2│ +/// │ batch 1 ├────▶...──────▶│ Batch N ├─────▶│ Demux ├────────┤ ... └───────────┘ └────────────┘ └─────────────┘ +/// └───────────┘ └────────────┘ │ │ │ +/// └──────────┘ │ ┌───────────┐ ┌────────────┐ ┌─────────────┐ +/// └──────▶ │ batch d ├────▶...──────▶│ Batch n │ │ Output FileN│ +/// └───────────┘ └────────────┘ └─────────────┘ +pub(crate) fn start_demuxer_task( + input: SendableRecordBatchStream, + context: &Arc, + partition_by: Option>, + base_output_path: ListingTableUrl, + file_extension: String, + single_file_output: bool, +) -> (JoinHandle>, DemuxedStreamReceiver) { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let context = context.clone(); + let task: JoinHandle> = match partition_by { + Some(parts) => { + // There could be an arbitrarily large number of parallel hive style partitions being written to, so we cannot + // bound this channel without risking a deadlock. + tokio::spawn(async move { + hive_style_partitions_demuxer( + tx, + input, + context, + parts, + base_output_path, + file_extension, + ) + .await + }) + } + None => tokio::spawn(async move { + row_count_demuxer( + tx, + input, + context, + base_output_path, + file_extension, + single_file_output, + ) + .await + }), + }; + + (task, rx) +} + +/// Dynamically partitions input stream to acheive desired maximum rows per file +async fn row_count_demuxer( + mut tx: UnboundedSender<(Path, Receiver)>, + mut input: SendableRecordBatchStream, + context: Arc, + base_output_path: ListingTableUrl, + file_extension: String, + single_file_output: bool, +) -> Result<()> { + let exec_options = &context.session_config().options().execution; + + let max_rows_per_file = exec_options.soft_max_rows_per_output_file; + let max_buffered_batches = exec_options.max_buffered_batches_per_output_file; + let minimum_parallel_files = exec_options.minimum_parallel_output_files; + let mut part_idx = 0; + let write_id = + rand::distributions::Alphanumeric.sample_string(&mut rand::thread_rng(), 16); + + let mut open_file_streams = Vec::with_capacity(minimum_parallel_files); + + let mut next_send_steam = 0; + let mut row_counts = Vec::with_capacity(minimum_parallel_files); + + // Overrides if single_file_output is set + let minimum_parallel_files = if single_file_output { + 1 + } else { + minimum_parallel_files + }; + + let max_rows_per_file = if single_file_output { + usize::MAX + } else { + max_rows_per_file + }; + + while let Some(rb) = input.next().await.transpose()? { + // ensure we have at least minimum_parallel_files open + if open_file_streams.len() < minimum_parallel_files { + open_file_streams.push(create_new_file_stream( + &base_output_path, + &write_id, + part_idx, + &file_extension, + single_file_output, + max_buffered_batches, + &mut tx, + )?); + row_counts.push(0); + part_idx += 1; + } else if row_counts[next_send_steam] >= max_rows_per_file { + row_counts[next_send_steam] = 0; + open_file_streams[next_send_steam] = create_new_file_stream( + &base_output_path, + &write_id, + part_idx, + &file_extension, + single_file_output, + max_buffered_batches, + &mut tx, + )?; + part_idx += 1; + } + row_counts[next_send_steam] += rb.num_rows(); + open_file_streams[next_send_steam] + .send(rb) + .await + .map_err(|_| { + DataFusionError::Execution( + "Error sending RecordBatch to file stream!".into(), + ) + })?; + + next_send_steam = (next_send_steam + 1) % minimum_parallel_files; + } + Ok(()) +} + +/// Helper for row count demuxer +fn generate_file_path( + base_output_path: &ListingTableUrl, + write_id: &str, + part_idx: usize, + file_extension: &str, + single_file_output: bool, +) -> Path { + if !single_file_output { + base_output_path + .prefix() + .child(format!("{}_{}.{}", write_id, part_idx, file_extension)) + } else { + base_output_path.prefix().to_owned() + } +} + +/// Helper for row count demuxer +fn create_new_file_stream( + base_output_path: &ListingTableUrl, + write_id: &str, + part_idx: usize, + file_extension: &str, + single_file_output: bool, + max_buffered_batches: usize, + tx: &mut UnboundedSender<(Path, Receiver)>, +) -> Result> { + let file_path = generate_file_path( + base_output_path, + write_id, + part_idx, + file_extension, + single_file_output, + ); + let (tx_file, rx_file) = mpsc::channel(max_buffered_batches / 2); + tx.send((file_path, rx_file)).map_err(|_| { + DataFusionError::Execution("Error sending RecordBatch to file stream!".into()) + })?; + Ok(tx_file) +} + +/// Splits an input stream based on the distinct values of a set of columns +/// Assumes standard hive style partition paths such as +/// /col1=val1/col2=val2/outputfile.parquet +async fn hive_style_partitions_demuxer( + tx: UnboundedSender<(Path, Receiver)>, + mut input: SendableRecordBatchStream, + context: Arc, + partition_by: Vec<(String, DataType)>, + base_output_path: ListingTableUrl, + file_extension: String, +) -> Result<()> { + let write_id = + rand::distributions::Alphanumeric.sample_string(&mut rand::thread_rng(), 16); + + let exec_options = &context.session_config().options().execution; + let max_buffered_recordbatches = exec_options.max_buffered_batches_per_output_file; + + // To support non string partition col types, cast the type to &str first + let mut value_map: HashMap, Sender> = HashMap::new(); + + while let Some(rb) = input.next().await.transpose()? { + // First compute partition key for each row of batch, e.g. (col1=val1, col2=val2, ...) + let all_partition_values = compute_partition_keys_by_row(&rb, &partition_by)?; + + // Next compute how the batch should be split up to take each distinct key to its own batch + let take_map = compute_take_arrays(&rb, all_partition_values); + + // Divide up the batch into distinct partition key batches and send each batch + for (part_key, mut builder) in take_map.into_iter() { + // Take method adapted from https://github.com/lancedb/lance/pull/1337/files + // TODO: upstream RecordBatch::take to arrow-rs + let take_indices = builder.finish(); + let struct_array: StructArray = rb.clone().into(); + let parted_batch = RecordBatch::from( + arrow::compute::take(&struct_array, &take_indices, None)?.as_struct(), + ); + + // Get or create channel for this batch + let part_tx = match value_map.get_mut(&part_key) { + Some(part_tx) => part_tx, + None => { + // Create channel for previously unseen distinct partition key and notify consumer of new file + let (part_tx, part_rx) = tokio::sync::mpsc::channel::( + max_buffered_recordbatches, + ); + let file_path = compute_hive_style_file_path( + &part_key, + &partition_by, + &write_id, + &file_extension, + &base_output_path, + ); + + tx.send((file_path, part_rx)).map_err(|_| { + DataFusionError::Execution( + "Error sending new file stream!".into(), + ) + })?; + + value_map.insert(part_key.clone(), part_tx); + value_map + .get_mut(&part_key) + .ok_or(DataFusionError::Internal( + "Key must exist since it was just inserted!".into(), + ))? + } + }; + + // remove partitions columns + let final_batch_to_send = + remove_partition_by_columns(&parted_batch, &partition_by)?; + + // Finally send the partial batch partitioned by distinct value! + part_tx.send(final_batch_to_send).await.map_err(|_| { + DataFusionError::Internal("Unexpected error sending parted batch!".into()) + })?; + } + } + + Ok(()) +} + +fn compute_partition_keys_by_row<'a>( + rb: &'a RecordBatch, + partition_by: &'a [(String, DataType)], +) -> Result>> { + let mut all_partition_values = vec![]; + + for (col, dtype) in partition_by.iter() { + let mut partition_values = vec![]; + let col_array = + rb.column_by_name(col) + .ok_or(DataFusionError::Execution(format!( + "PartitionBy Column {} does not exist in source data!", + col + )))?; + + match dtype { + DataType::Utf8 => { + let array = as_string_array(col_array)?; + for i in 0..rb.num_rows() { + partition_values.push(array.value(i)); + } + } + DataType::Dictionary(_, _) => { + downcast_dictionary_array!( + col_array => { + let array = col_array.downcast_dict::() + .ok_or(DataFusionError::Execution(format!("it is not yet supported to write to hive partitions with datatype {}", + dtype)))?; + + for val in array.values() { + partition_values.push( + val.ok_or(DataFusionError::Execution(format!("Cannot partition by null value for column {}", col)))? + ); + } + }, + _ => unreachable!(), + ) + } + _ => { + return Err(DataFusionError::NotImplemented(format!( + "it is not yet supported to write to hive partitions with datatype {}", + dtype + ))) + } + } + + all_partition_values.push(partition_values); + } + + Ok(all_partition_values) +} + +fn compute_take_arrays( + rb: &RecordBatch, + all_partition_values: Vec>, +) -> HashMap, UInt64Builder> { + let mut take_map = HashMap::new(); + for i in 0..rb.num_rows() { + let mut part_key = vec![]; + for vals in all_partition_values.iter() { + part_key.push(vals[i].to_owned()); + } + let builder = take_map.entry(part_key).or_insert(UInt64Builder::new()); + builder.append_value(i as u64); + } + take_map +} + +fn remove_partition_by_columns( + parted_batch: &RecordBatch, + partition_by: &[(String, DataType)], +) -> Result { + let end_idx = parted_batch.num_columns() - partition_by.len(); + let non_part_cols = &parted_batch.columns()[..end_idx]; + + let partition_names: Vec<_> = partition_by.iter().map(|(s, _)| s).collect(); + let non_part_schema = Schema::new( + parted_batch + .schema() + .fields() + .iter() + .filter(|f| !partition_names.contains(&f.name())) + .map(|f| (**f).clone()) + .collect::>(), + ); + let final_batch_to_send = + RecordBatch::try_new(Arc::new(non_part_schema), non_part_cols.into())?; + + Ok(final_batch_to_send) +} + +fn compute_hive_style_file_path( + part_key: &[String], + partition_by: &[(String, DataType)], + write_id: &str, + file_extension: &str, + base_output_path: &ListingTableUrl, +) -> Path { + let mut file_path = base_output_path.prefix().clone(); + for j in 0..part_key.len() { + file_path = file_path.child(format!("{}={}", partition_by[j].0, part_key[j])); + } + + file_path.child(format!("{}.{}", write_id, file_extension)) +} diff --git a/datafusion/core/src/datasource/file_format/write/mod.rs b/datafusion/core/src/datasource/file_format/write/mod.rs new file mode 100644 index 000000000000..c481f2accf19 --- /dev/null +++ b/datafusion/core/src/datasource/file_format/write/mod.rs @@ -0,0 +1,170 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Module containing helper methods/traits related to enabling +//! write support for the various file formats + +use std::io::{Error, Write}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use crate::datasource::file_format::file_compression_type::FileCompressionType; +use crate::error::Result; + +use arrow_array::RecordBatch; +use datafusion_common::DataFusionError; + +use async_trait::async_trait; +use bytes::Bytes; +use futures::future::BoxFuture; +use object_store::path::Path; +use object_store::{MultipartId, ObjectStore}; +use tokio::io::AsyncWrite; + +pub(crate) mod demux; +pub(crate) mod orchestration; + +/// A buffer with interior mutability shared by the SerializedFileWriter and +/// ObjectStore writer +#[derive(Clone)] +pub(crate) struct SharedBuffer { + /// The inner buffer for reading and writing + /// + /// The lock is used to obtain internal mutability, so no worry about the + /// lock contention. + pub(crate) buffer: Arc>>, +} + +impl SharedBuffer { + pub fn new(capacity: usize) -> Self { + Self { + buffer: Arc::new(futures::lock::Mutex::new(Vec::with_capacity(capacity))), + } + } +} + +impl Write for SharedBuffer { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + let mut buffer = self.buffer.try_lock().unwrap(); + Write::write(&mut *buffer, buf) + } + + fn flush(&mut self) -> std::io::Result<()> { + let mut buffer = self.buffer.try_lock().unwrap(); + Write::flush(&mut *buffer) + } +} + +/// Stores data needed during abortion of MultiPart writers +#[derive(Clone)] +pub(crate) struct MultiPart { + /// A shared reference to the object store + store: Arc, + multipart_id: MultipartId, + location: Path, +} + +impl MultiPart { + /// Create a new `MultiPart` + pub fn new( + store: Arc, + multipart_id: MultipartId, + location: Path, + ) -> Self { + Self { + store, + multipart_id, + location, + } + } +} + +/// A wrapper struct with abort method and writer +pub(crate) struct AbortableWrite { + writer: W, + multipart: MultiPart, +} + +impl AbortableWrite { + /// Create a new `AbortableWrite` instance with the given writer, and write mode. + pub(crate) fn new(writer: W, multipart: MultiPart) -> Self { + Self { writer, multipart } + } + + /// handling of abort for different write modes + pub(crate) fn abort_writer(&self) -> Result>> { + let multi = self.multipart.clone(); + Ok(Box::pin(async move { + multi + .store + .abort_multipart(&multi.location, &multi.multipart_id) + .await + .map_err(DataFusionError::ObjectStore) + })) + } +} + +impl AsyncWrite for AbortableWrite { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.get_mut().writer).poll_write(cx, buf) + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.get_mut().writer).poll_flush(cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.get_mut().writer).poll_shutdown(cx) + } +} + +/// A trait that defines the methods required for a RecordBatch serializer. +#[async_trait] +pub trait BatchSerializer: Sync + Send { + /// Asynchronously serializes a `RecordBatch` and returns the serialized bytes. + /// Parameter `initial` signals whether the given batch is the first batch. + /// This distinction is important for certain serializers (like CSV). + async fn serialize(&self, batch: RecordBatch, initial: bool) -> Result; +} + +/// Returns an [`AbortableWrite`] which writes to the given object store location +/// with the specified compression +pub(crate) async fn create_writer( + file_compression_type: FileCompressionType, + location: &Path, + object_store: Arc, +) -> Result>> { + let (multipart_id, writer) = object_store + .put_multipart(location) + .await + .map_err(DataFusionError::ObjectStore)?; + Ok(AbortableWrite::new( + file_compression_type.convert_async_writer(writer)?, + MultiPart::new(object_store, multipart_id, location.clone()), + )) +} diff --git a/datafusion/core/src/datasource/file_format/write/orchestration.rs b/datafusion/core/src/datasource/file_format/write/orchestration.rs new file mode 100644 index 000000000000..9b820a15b280 --- /dev/null +++ b/datafusion/core/src/datasource/file_format/write/orchestration.rs @@ -0,0 +1,287 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Module containing helper methods/traits related to +//! orchestrating file serialization, streaming to object store, +//! parallelization, and abort handling + +use std::sync::Arc; + +use super::demux::start_demuxer_task; +use super::{create_writer, AbortableWrite, BatchSerializer}; +use crate::datasource::file_format::file_compression_type::FileCompressionType; +use crate::datasource::physical_plan::FileSinkConfig; +use crate::error::Result; +use crate::physical_plan::SendableRecordBatchStream; + +use arrow_array::RecordBatch; +use datafusion_common::{internal_datafusion_err, internal_err, DataFusionError}; +use datafusion_execution::TaskContext; + +use bytes::Bytes; +use tokio::io::{AsyncWrite, AsyncWriteExt}; +use tokio::sync::mpsc::{self, Receiver}; +use tokio::task::{JoinHandle, JoinSet}; +use tokio::try_join; + +type WriterType = AbortableWrite>; +type SerializerType = Arc; + +/// Serializes a single data stream in parallel and writes to an ObjectStore +/// concurrently. Data order is preserved. In the event of an error, +/// the ObjectStore writer is returned to the caller in addition to an error, +/// so that the caller may handle aborting failed writes. +pub(crate) async fn serialize_rb_stream_to_object_store( + mut data_rx: Receiver, + serializer: Arc, + mut writer: AbortableWrite>, +) -> std::result::Result<(WriterType, u64), (WriterType, DataFusionError)> { + let (tx, mut rx) = + mpsc::channel::>>(100); + let serialize_task = tokio::spawn(async move { + // Some serializers (like CSV) handle the first batch differently than + // subsequent batches, so we track that here. + let mut initial = true; + while let Some(batch) = data_rx.recv().await { + let serializer_clone = serializer.clone(); + let handle = tokio::spawn(async move { + let num_rows = batch.num_rows(); + let bytes = serializer_clone.serialize(batch, initial).await?; + Ok((num_rows, bytes)) + }); + if initial { + initial = false; + } + tx.send(handle).await.map_err(|_| { + internal_datafusion_err!("Unknown error writing to object store") + })?; + } + Ok(()) + }); + + let mut row_count = 0; + while let Some(handle) = rx.recv().await { + match handle.await { + Ok(Ok((cnt, bytes))) => { + match writer.write_all(&bytes).await { + Ok(_) => (), + Err(e) => { + return Err(( + writer, + DataFusionError::Execution(format!( + "Error writing to object store: {e}" + )), + )) + } + }; + row_count += cnt; + } + Ok(Err(e)) => { + // Return the writer along with the error + return Err((writer, e)); + } + Err(e) => { + // Handle task panic or cancellation + return Err(( + writer, + DataFusionError::Execution(format!( + "Serialization task panicked or was cancelled: {e}" + )), + )); + } + } + } + + match serialize_task.await { + Ok(Ok(_)) => (), + Ok(Err(e)) => return Err((writer, e)), + Err(_) => { + return Err(( + writer, + internal_datafusion_err!("Unknown error writing to object store"), + )) + } + }; + Ok((writer, row_count as u64)) +} + +type FileWriteBundle = (Receiver, SerializerType, WriterType); +/// Contains the common logic for serializing RecordBatches and +/// writing the resulting bytes to an ObjectStore. +/// Serialization is assumed to be stateless, i.e. +/// each RecordBatch can be serialized without any +/// dependency on the RecordBatches before or after. +pub(crate) async fn stateless_serialize_and_write_files( + mut rx: Receiver, + tx: tokio::sync::oneshot::Sender, +) -> Result<()> { + let mut row_count = 0; + // tracks if any writers encountered an error triggering the need to abort + let mut any_errors = false; + // tracks the specific error triggering abort + let mut triggering_error = None; + // tracks if any errors were encountered in the process of aborting writers. + // if true, we may not have a guarentee that all written data was cleaned up. + let mut any_abort_errors = false; + let mut join_set = JoinSet::new(); + while let Some((data_rx, serializer, writer)) = rx.recv().await { + join_set.spawn(async move { + serialize_rb_stream_to_object_store(data_rx, serializer, writer).await + }); + } + let mut finished_writers = Vec::new(); + while let Some(result) = join_set.join_next().await { + match result { + Ok(res) => match res { + Ok((writer, cnt)) => { + finished_writers.push(writer); + row_count += cnt; + } + Err((writer, e)) => { + finished_writers.push(writer); + any_errors = true; + triggering_error = Some(e); + } + }, + Err(e) => { + // Don't panic, instead try to clean up as many writers as possible. + // If we hit this code, ownership of a writer was not joined back to + // this thread, so we cannot clean it up (hence any_abort_errors is true) + any_errors = true; + any_abort_errors = true; + triggering_error = Some(internal_datafusion_err!( + "Unexpected join error while serializing file {e}" + )); + } + } + } + + // Finalize or abort writers as appropriate + for mut writer in finished_writers.into_iter() { + match any_errors { + true => { + let abort_result = writer.abort_writer(); + if abort_result.is_err() { + any_abort_errors = true; + } + } + false => { + writer.shutdown() + .await + .map_err(|_| internal_datafusion_err!("Error encountered while finalizing writes! Partial results may have been written to ObjectStore!"))?; + } + } + } + + if any_errors { + match any_abort_errors{ + true => return internal_err!("Error encountered during writing to ObjectStore and failed to abort all writers. Partial result may have been written."), + false => match triggering_error { + Some(e) => return Err(e), + None => return internal_err!("Unknown Error encountered during writing to ObjectStore. All writers succesfully aborted.") + } + } + } + + tx.send(row_count).map_err(|_| { + internal_datafusion_err!( + "Error encountered while sending row count back to file sink!" + ) + })?; + Ok(()) +} + +/// Orchestrates multipart put of a dynamic number of output files from a single input stream +/// for any statelessly serialized file type. That is, any file type for which each [RecordBatch] +/// can be serialized independently of all other [RecordBatch]s. +pub(crate) async fn stateless_multipart_put( + data: SendableRecordBatchStream, + context: &Arc, + file_extension: String, + get_serializer: Box Arc + Send>, + config: &FileSinkConfig, + compression: FileCompressionType, +) -> Result { + let object_store = context + .runtime_env() + .object_store(&config.object_store_url)?; + + let single_file_output = config.single_file_output; + let base_output_path = &config.table_paths[0]; + let part_cols = if !config.table_partition_cols.is_empty() { + Some(config.table_partition_cols.clone()) + } else { + None + }; + + let (demux_task, mut file_stream_rx) = start_demuxer_task( + data, + context, + part_cols, + base_output_path.clone(), + file_extension, + single_file_output, + ); + + let rb_buffer_size = &context + .session_config() + .options() + .execution + .max_buffered_batches_per_output_file; + + let (tx_file_bundle, rx_file_bundle) = tokio::sync::mpsc::channel(rb_buffer_size / 2); + let (tx_row_cnt, rx_row_cnt) = tokio::sync::oneshot::channel(); + let write_coordinater_task = tokio::spawn(async move { + stateless_serialize_and_write_files(rx_file_bundle, tx_row_cnt).await + }); + while let Some((location, rb_stream)) = file_stream_rx.recv().await { + let serializer = get_serializer(); + let writer = create_writer(compression, &location, object_store.clone()).await?; + + tx_file_bundle + .send((rb_stream, serializer, writer)) + .await + .map_err(|_| { + internal_datafusion_err!( + "Writer receive file bundle channel closed unexpectedly!" + ) + })?; + } + + // Signal to the write coordinater that no more files are coming + drop(tx_file_bundle); + + match try_join!(write_coordinater_task, demux_task) { + Ok((r1, r2)) => { + r1?; + r2?; + } + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } else { + unreachable!(); + } + } + } + + let total_count = rx_row_cnt.await.map_err(|_| { + internal_datafusion_err!("Did not receieve row count from write coordinater") + })?; + + Ok(total_count) +} diff --git a/datafusion/core/src/datasource/function.rs b/datafusion/core/src/datasource/function.rs new file mode 100644 index 000000000000..2fd352ee4eb3 --- /dev/null +++ b/datafusion/core/src/datasource/function.rs @@ -0,0 +1,56 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! A table that uses a function to generate data + +use super::TableProvider; + +use datafusion_common::Result; +use datafusion_expr::Expr; + +use std::sync::Arc; + +/// A trait for table function implementations +pub trait TableFunctionImpl: Sync + Send { + /// Create a table provider + fn call(&self, args: &[Expr]) -> Result>; +} + +/// A table that uses a function to generate data +pub struct TableFunction { + /// Name of the table function + name: String, + /// Function implementation + fun: Arc, +} + +impl TableFunction { + /// Create a new table function + pub fn new(name: String, fun: Arc) -> Self { + Self { name, fun } + } + + /// Get the name of the table function + pub fn name(&self) -> &str { + &self.name + } + + /// Get the function implementation and generate a table + pub fn create_table_provider(&self, args: &[Expr]) -> Result> { + self.fun.call(args) + } +} diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index 476c58b698d8..68de55e1a410 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -36,10 +36,10 @@ use crate::{error::Result, scalar::ScalarValue}; use super::PartitionedFile; use crate::datasource::listing::ListingTableUrl; +use crate::execution::context::SessionState; use datafusion_common::tree_node::{TreeNode, VisitRecursion}; -use datafusion_common::{Column, DFField, DFSchema, DataFusionError}; -use datafusion_expr::expr::ScalarUDF; -use datafusion_expr::{Expr, Volatility}; +use datafusion_common::{internal_err, Column, DFField, DFSchema, DataFusionError}; +use datafusion_expr::{Expr, ScalarFunctionDefinition, Volatility}; use datafusion_physical_expr::create_physical_expr; use datafusion_physical_expr::execution_props::ExecutionProps; use object_store::path::Path; @@ -53,13 +53,13 @@ use object_store::{ObjectMeta, ObjectStore}; pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { let mut is_applicable = true; expr.apply(&mut |expr| { - Ok(match expr { + match expr { Expr::Column(Column { ref name, .. }) => { is_applicable &= col_names.contains(name); if is_applicable { - VisitRecursion::Skip + Ok(VisitRecursion::Skip) } else { - VisitRecursion::Stop + Ok(VisitRecursion::Stop) } } Expr::Literal(_) @@ -88,25 +88,32 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { | Expr::ScalarSubquery(_) | Expr::GetIndexedField { .. } | Expr::GroupingSet(_) - | Expr::Case { .. } => VisitRecursion::Continue, + | Expr::Case { .. } => Ok(VisitRecursion::Continue), Expr::ScalarFunction(scalar_function) => { - match scalar_function.fun.volatility() { - Volatility::Immutable => VisitRecursion::Continue, - // TODO: Stable functions could be `applicable`, but that would require access to the context - Volatility::Stable | Volatility::Volatile => { - is_applicable = false; - VisitRecursion::Stop + match &scalar_function.func_def { + ScalarFunctionDefinition::BuiltIn(fun) => { + match fun.volatility() { + Volatility::Immutable => Ok(VisitRecursion::Continue), + // TODO: Stable functions could be `applicable`, but that would require access to the context + Volatility::Stable | Volatility::Volatile => { + is_applicable = false; + Ok(VisitRecursion::Stop) + } + } } - } - } - Expr::ScalarUDF(ScalarUDF { fun, .. }) => { - match fun.signature.volatility { - Volatility::Immutable => VisitRecursion::Continue, - // TODO: Stable functions could be `applicable`, but that would require access to the context - Volatility::Stable | Volatility::Volatile => { - is_applicable = false; - VisitRecursion::Stop + ScalarFunctionDefinition::UDF(fun) => { + match fun.signature().volatility { + Volatility::Immutable => Ok(VisitRecursion::Continue), + // TODO: Stable functions could be `applicable`, but that would require access to the context + Volatility::Stable | Volatility::Volatile => { + is_applicable = false; + Ok(VisitRecursion::Stop) + } + } + } + ScalarFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") } } } @@ -115,17 +122,15 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { // - AGGREGATE, WINDOW and SORT should not end up in filter conditions, except maybe in some edge cases // - Can `Wildcard` be considered as a `Literal`? // - ScalarVariable could be `applicable`, but that would require access to the context - Expr::AggregateUDF { .. } - | Expr::AggregateFunction { .. } + Expr::AggregateFunction { .. } | Expr::Sort { .. } | Expr::WindowFunction { .. } - | Expr::Wildcard - | Expr::QualifiedWildcard { .. } + | Expr::Wildcard { .. } | Expr::Placeholder(_) => { is_applicable = false; - VisitRecursion::Stop + Ok(VisitRecursion::Stop) } - }) + } }) .unwrap(); is_applicable @@ -136,12 +141,18 @@ const CONCURRENCY_LIMIT: usize = 100; /// Partition the list of files into `n` groups pub fn split_files( - partitioned_files: Vec, + mut partitioned_files: Vec, n: usize, ) -> Vec> { if partitioned_files.is_empty() { return vec![]; } + + // ObjectStore::list does not guarantee any consistent order and for some + // implementations such as LocalFileSystem, it may be inconsistent. Thus + // Sort files by path to ensure consistent plans when run more than once. + partitioned_files.sort_by(|a, b| a.path().cmp(b.path())); + // effectively this is div with rounding up instead of truncating let chunk_size = (partitioned_files.len() + n - 1) / n; partitioned_files @@ -275,7 +286,10 @@ async fn prune_partitions( // Applies `filter` to `batch` returning `None` on error let do_filter = |filter| -> Option { let expr = create_physical_expr(filter, &df_schema, &schema, &props).ok()?; - Some(expr.evaluate(&batch).ok()?.into_array(partitions.len())) + expr.evaluate(&batch) + .ok()? + .into_array(partitions.len()) + .ok() }; //.Compute the conjunction of the filters, ignoring errors @@ -315,6 +329,7 @@ async fn prune_partitions( /// `filters` might contain expressions that can be resolved only at the /// file level (e.g. Parquet row group pruning). pub async fn pruned_partition_list<'a>( + ctx: &'a SessionState, store: &'a dyn ObjectStore, table_path: &'a ListingTableUrl, filters: &'a [Expr], @@ -325,7 +340,8 @@ pub async fn pruned_partition_list<'a>( if partition_cols.is_empty() { return Ok(Box::pin( table_path - .list_all_files(store, file_extension) + .list_all_files(ctx, store, file_extension) + .await? .map_ok(|object_meta| object_meta.into()), )); } @@ -356,14 +372,13 @@ pub async fn pruned_partition_list<'a>( Some(files) => files, None => { trace!("Recursively listing partition {}", partition.path); - let s = store.list(Some(&partition.path)).await?; - s.try_collect().await? + store.list(Some(&partition.path)).try_collect().await? } }; - let files = files.into_iter().filter(move |o| { let extension_match = o.location.as_ref().ends_with(file_extension); - let glob_match = table_path.contains(&o.location); + // here need to scan subdirectories(`listing_table_ignore_subdirectory` = false) + let glob_match = table_path.contains(&o.location, false); extension_match && glob_match }); @@ -422,7 +437,7 @@ mod tests { use futures::StreamExt; use crate::logical_expr::{case, col, lit}; - use crate::test::object_store::make_test_store; + use crate::test::object_store::make_test_store_and_state; use super::*; @@ -468,12 +483,13 @@ mod tests { #[tokio::test] async fn test_pruned_partition_list_empty() { - let store = make_test_store(&[ + let (store, state) = make_test_store_and_state(&[ ("tablepath/mypartition=val1/notparquetfile", 100), ("tablepath/file.parquet", 100), ]); let filter = Expr::eq(col("mypartition"), lit("val1")); let pruned = pruned_partition_list( + &state, store.as_ref(), &ListingTableUrl::parse("file:///tablepath/").unwrap(), &[filter], @@ -490,13 +506,14 @@ mod tests { #[tokio::test] async fn test_pruned_partition_list() { - let store = make_test_store(&[ + let (store, state) = make_test_store_and_state(&[ ("tablepath/mypartition=val1/file.parquet", 100), ("tablepath/mypartition=val2/file.parquet", 100), ("tablepath/mypartition=val1/other=val3/file.parquet", 100), ]); let filter = Expr::eq(col("mypartition"), lit("val1")); let pruned = pruned_partition_list( + &state, store.as_ref(), &ListingTableUrl::parse("file:///tablepath/").unwrap(), &[filter], @@ -515,24 +532,18 @@ mod tests { f1.object_meta.location.as_ref(), "tablepath/mypartition=val1/file.parquet" ); - assert_eq!( - &f1.partition_values, - &[ScalarValue::Utf8(Some(String::from("val1"))),] - ); + assert_eq!(&f1.partition_values, &[ScalarValue::from("val1")]); let f2 = &pruned[1]; assert_eq!( f2.object_meta.location.as_ref(), "tablepath/mypartition=val1/other=val3/file.parquet" ); - assert_eq!( - f2.partition_values, - &[ScalarValue::Utf8(Some(String::from("val1"))),] - ); + assert_eq!(f2.partition_values, &[ScalarValue::from("val1"),]); } #[tokio::test] async fn test_pruned_partition_list_multi() { - let store = make_test_store(&[ + let (store, state) = make_test_store_and_state(&[ ("tablepath/part1=p1v1/file.parquet", 100), ("tablepath/part1=p1v2/part2=p2v1/file1.parquet", 100), ("tablepath/part1=p1v2/part2=p2v1/file2.parquet", 100), @@ -544,6 +555,7 @@ mod tests { // filter3 cannot be resolved at partition pruning let filter3 = Expr::eq(col("part2"), col("other")); let pruned = pruned_partition_list( + &state, store.as_ref(), &ListingTableUrl::parse("file:///tablepath/").unwrap(), &[filter1, filter2, filter3], @@ -567,10 +579,7 @@ mod tests { ); assert_eq!( &f1.partition_values, - &[ - ScalarValue::Utf8(Some(String::from("p1v2"))), - ScalarValue::Utf8(Some(String::from("p2v1"))) - ] + &[ScalarValue::from("p1v2"), ScalarValue::from("p2v1"),] ); let f2 = &pruned[1]; assert_eq!( @@ -579,10 +588,7 @@ mod tests { ); assert_eq!( &f2.partition_values, - &[ - ScalarValue::Utf8(Some(String::from("p1v2"))), - ScalarValue::Utf8(Some(String::from("p2v1"))) - ] + &[ScalarValue::from("p1v2"), ScalarValue::from("p2v1")] ); } diff --git a/datafusion/core/src/datasource/listing/mod.rs b/datafusion/core/src/datasource/listing/mod.rs index 8b0f021f0277..e7583501f9d9 100644 --- a/datafusion/core/src/datasource/listing/mod.rs +++ b/datafusion/core/src/datasource/listing/mod.rs @@ -31,9 +31,7 @@ use std::pin::Pin; use std::sync::Arc; pub use self::url::ListingTableUrl; -pub use table::{ - ListingOptions, ListingTable, ListingTableConfig, ListingTableInsertMode, -}; +pub use table::{ListingOptions, ListingTable, ListingTableConfig}; /// Stream of files get listed from object store pub type PartitionedFileStream = @@ -42,7 +40,7 @@ pub type PartitionedFileStream = /// Only scan a subset of Row Groups from the Parquet file whose data "midpoint" /// lies within the [start, end) byte offsets. This option can be used to scan non-overlapping /// sections of a Parquet file in parallel. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Hash, Eq, PartialOrd, Ord)] pub struct FileRange { /// Range start pub start: i64, @@ -72,16 +70,16 @@ pub struct PartitionedFile { /// An optional field for user defined per object metadata pub extensions: Option>, } - impl PartitionedFile { /// Create a simple file without metadata or partition - pub fn new(path: String, size: u64) -> Self { + pub fn new(path: impl Into, size: u64) -> Self { Self { object_meta: ObjectMeta { - location: Path::from(path), + location: Path::from(path.into()), last_modified: chrono::Utc.timestamp_nanos(0), size: size as usize, e_tag: None, + version: None, }, partition_values: vec![], range: None, @@ -97,11 +95,13 @@ impl PartitionedFile { last_modified: chrono::Utc.timestamp_nanos(0), size: size as usize, e_tag: None, + version: None, }, partition_values: vec![], - range: Some(FileRange { start, end }), + range: None, extensions: None, } + .with_range(start, end) } /// Return a file reference from the given path @@ -109,6 +109,17 @@ impl PartitionedFile { let size = std::fs::metadata(path.clone())?.len(); Ok(Self::new(path, size)) } + + /// Return the path of this partitioned file + pub fn path(&self) -> &Path { + &self.object_meta.location + } + + /// Update the file to only scan the specified range (in bytes) + pub fn with_range(mut self, start: i64, end: i64) -> Self { + self.range = Some(FileRange { start, end }); + self + } } impl From for PartitionedFile { diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 8360847e1bd1..a7af1bf1be28 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -17,50 +17,51 @@ //! The table implementation. +use std::collections::HashMap; use std::str::FromStr; use std::{any::Any, sync::Arc}; -use arrow::compute::SortOptions; -use arrow::datatypes::{DataType, Field, SchemaBuilder, SchemaRef}; -use arrow_schema::Schema; -use async_trait::async_trait; -use datafusion_common::FileTypeWriterOptions; -use datafusion_common::{internal_err, plan_err, project_schema, SchemaExt, ToDFSchema}; -use datafusion_expr::expr::Sort; -use datafusion_optimizer::utils::conjunction; -use datafusion_physical_expr::{create_physical_expr, LexOrdering, PhysicalSortExpr}; -use futures::{future, stream, StreamExt, TryStreamExt}; +use super::helpers::{expr_applicable_for_cols, pruned_partition_list, split_files}; +use super::PartitionedFile; -use crate::datasource::file_format::file_compression_type::{ - FileCompressionType, FileTypeExt, -}; -use crate::datasource::physical_plan::{ - is_plan_streaming, FileScanConfig, FileSinkConfig, -}; +#[cfg(feature = "parquet")] +use crate::datasource::file_format::parquet::ParquetFormat; use crate::datasource::{ + create_ordering, file_format::{ - arrow::ArrowFormat, avro::AvroFormat, csv::CsvFormat, json::JsonFormat, - parquet::ParquetFormat, FileFormat, + arrow::ArrowFormat, + avro::AvroFormat, + csv::CsvFormat, + file_compression_type::{FileCompressionType, FileTypeExt}, + json::JsonFormat, + FileFormat, }, get_statistics_with_limit, listing::ListingTableUrl, + physical_plan::{FileScanConfig, FileSinkConfig}, TableProvider, TableType, }; -use crate::logical_expr::TableProviderFilterPushDown; -use crate::physical_plan; use crate::{ error::{DataFusionError, Result}, execution::context::SessionState, - logical_expr::Expr, + logical_expr::{utils::conjunction, Expr, TableProviderFilterPushDown}, physical_plan::{empty::EmptyExec, ExecutionPlan, Statistics}, }; -use datafusion_common::FileType; + +use arrow::datatypes::{DataType, Field, SchemaBuilder, SchemaRef}; +use arrow_schema::Schema; +use datafusion_common::{ + internal_err, plan_err, project_schema, Constraints, FileType, FileTypeWriterOptions, + SchemaExt, ToDFSchema, +}; use datafusion_execution::cache::cache_manager::FileStatisticsCache; use datafusion_execution::cache::cache_unit::DefaultFileStatisticsCache; +use datafusion_physical_expr::{ + create_physical_expr, LexOrdering, PhysicalSortRequirement, +}; -use super::PartitionedFile; - -use super::helpers::{expr_applicable_for_cols, pruned_partition_list, split_files}; +use async_trait::async_trait; +use futures::{future, stream, StreamExt, TryStreamExt}; /// Configuration for creating a [`ListingTable`] #[derive(Debug, Clone)] @@ -147,6 +148,7 @@ impl ListingTableConfig { FileType::JSON => Arc::new( JsonFormat::default().with_file_compression_type(file_compression_type), ), + #[cfg(feature = "parquet")] FileType::PARQUET => Arc::new(ParquetFormat::default()), }; @@ -155,7 +157,7 @@ impl ListingTableConfig { /// Infer `ListingOptions` based on `table_path` suffix. pub async fn infer_options(self, state: &SessionState) -> Result { - let store = if let Some(url) = self.table_paths.get(0) { + let store = if let Some(url) = self.table_paths.first() { state.runtime_env().object_store(url)? } else { return Ok(self); @@ -163,9 +165,10 @@ impl ListingTableConfig { let file = self .table_paths - .get(0) + .first() .unwrap() - .list_all_files(store.as_ref(), "") + .list_all_files(state, store.as_ref(), "") + .await? .next() .await .ok_or_else(|| DataFusionError::Internal("No files for table".into()))??; @@ -188,7 +191,7 @@ impl ListingTableConfig { pub async fn infer_schema(self, state: &SessionState) -> Result { match self.options { Some(options) => { - let schema = if let Some(url) = self.table_paths.get(0) { + let schema = if let Some(url) = self.table_paths.first() { options.infer_schema(state, url).await? } else { Arc::new(Schema::empty()) @@ -210,33 +213,6 @@ impl ListingTableConfig { } } -#[derive(Debug, Clone)] -///controls how new data should be inserted to a ListingTable -pub enum ListingTableInsertMode { - ///Data should be appended to an existing file - AppendToFile, - ///Data is appended as new files in existing TablePaths - AppendNewFiles, - ///Throw an error if insert into is attempted on this table - Error, -} - -impl FromStr for ListingTableInsertMode { - type Err = DataFusionError; - fn from_str(s: &str) -> Result { - let s_lower = s.to_lowercase(); - match s_lower.as_str() { - "append_to_file" => Ok(ListingTableInsertMode::AppendToFile), - "append_new_files" => Ok(ListingTableInsertMode::AppendNewFiles), - "error" => Ok(ListingTableInsertMode::Error), - _ => Err(DataFusionError::Plan(format!( - "Unknown or unsupported insert mode {s}. Supported options are \ - append_to_file, append_new_files, and error." - ))), - } - } -} - /// Options for creating a [`ListingTable`] #[derive(Clone, Debug)] pub struct ListingOptions { @@ -270,16 +246,6 @@ pub struct ListingOptions { /// multiple equivalent orderings, the outer `Vec` will have a /// single element. pub file_sort_order: Vec>, - /// Infinite source means that the input is not guaranteed to end. - /// Currently, CSV, JSON, and AVRO formats are supported. - /// In order to support infinite inputs, DataFusion may adjust query - /// plans (e.g. joins) to run the given query in full pipelining mode. - pub infinite_source: bool, - /// This setting controls how inserts to this table should be handled - pub insert_mode: ListingTableInsertMode, - /// This setting when true indicates that the table is backed by a single file. - /// Any inserts to the table may only append to this existing file. - pub single_file: bool, /// This setting holds file format specific options which should be used /// when inserting into this table. pub file_type_write_options: Option, @@ -300,31 +266,10 @@ impl ListingOptions { collect_stat: true, target_partitions: 1, file_sort_order: vec![], - infinite_source: false, - insert_mode: ListingTableInsertMode::AppendToFile, - single_file: false, file_type_write_options: None, } } - /// Set unbounded assumption on [`ListingOptions`] and returns self. - /// - /// ``` - /// use std::sync::Arc; - /// use datafusion::datasource::{listing::ListingOptions, file_format::csv::CsvFormat}; - /// use datafusion::prelude::SessionContext; - /// let ctx = SessionContext::new(); - /// let listing_options = ListingOptions::new(Arc::new( - /// CsvFormat::default() - /// )).with_infinite_source(true); - /// - /// assert_eq!(listing_options.infinite_source, true); - /// ``` - pub fn with_infinite_source(mut self, infinite_source: bool) -> Self { - self.infinite_source = infinite_source; - self - } - /// Set file extension on [`ListingOptions`] and returns self. /// /// ``` @@ -472,18 +417,6 @@ impl ListingOptions { self } - /// Configure how insertions to this table should be handled. - pub fn with_insert_mode(mut self, insert_mode: ListingTableInsertMode) -> Self { - self.insert_mode = insert_mode; - self - } - - /// Configure if this table is backed by a sigle file - pub fn with_single_file(mut self, single_file: bool) -> Self { - self.single_file = single_file; - self - } - /// Configure file format specific writing options. pub fn with_write_options( mut self, @@ -507,7 +440,8 @@ impl ListingOptions { let store = state.runtime_env().object_store(table_path)?; let files: Vec<_> = table_path - .list_all_files(store.as_ref(), &self.file_extension) + .list_all_files(state, store.as_ref(), &self.file_extension) + .await? .try_collect() .await?; @@ -522,7 +456,7 @@ impl ListingOptions { /// /// # Features /// -/// 1. Merges schemas if the files have compatible but not indentical schemas +/// 1. Merges schemas if the files have compatible but not identical schemas /// /// 2. Hive-style partitioning support, where a path such as /// `/files/date=1/1/2022/data.parquet` is injected as a `date` column. @@ -589,7 +523,8 @@ pub struct ListingTable { options: ListingOptions, definition: Option, collected_statistics: FileStatisticsCache, - infinite_source: bool, + constraints: Constraints, + column_defaults: HashMap, } impl ListingTable { @@ -617,7 +552,6 @@ impl ListingTable { for (part_col_name, part_col_type) in &options.table_partition_cols { builder.push(Field::new(part_col_name, part_col_type.clone(), false)); } - let infinite_source = options.infinite_source; let table = Self { table_paths: config.table_paths, @@ -626,12 +560,28 @@ impl ListingTable { options, definition: None, collected_statistics: Arc::new(DefaultFileStatisticsCache::default()), - infinite_source, + constraints: Constraints::empty(), + column_defaults: HashMap::new(), }; Ok(table) } + /// Assign constraints + pub fn with_constraints(mut self, constraints: Constraints) -> Self { + self.constraints = constraints; + self + } + + /// Assign column defaults + pub fn with_column_defaults( + mut self, + column_defaults: HashMap, + ) -> Self { + self.column_defaults = column_defaults; + self + } + /// Set the [`FileStatisticsCache`] used to cache parquet file statistics. /// /// Setting a statistics cache on the `SessionContext` can avoid refetching statistics @@ -662,34 +612,7 @@ impl ListingTable { /// If file_sort_order is specified, creates the appropriate physical expressions fn try_create_output_ordering(&self) -> Result> { - let mut all_sort_orders = vec![]; - - for exprs in &self.options.file_sort_order { - // Construct PhsyicalSortExpr objects from Expr objects: - let sort_exprs = exprs - .iter() - .map(|expr| { - if let Expr::Sort(Sort { expr, asc, nulls_first }) = expr { - if let Expr::Column(col) = expr.as_ref() { - let expr = physical_plan::expressions::col(&col.name, self.table_schema.as_ref())?; - Ok(PhysicalSortExpr { - expr, - options: SortOptions { - descending: !asc, - nulls_first: *nulls_first, - }, - }) - } else { - plan_err!("Expected single column references in output_ordering, got {expr}") - } - } else { - plan_err!("Expected Expr::Sort in output_ordering, but got {expr}") - } - }) - .collect::>>()?; - all_sort_orders.push(sort_exprs); - } - Ok(all_sort_orders) + create_ordering(&self.table_schema, &self.options.file_sort_order) } } @@ -703,6 +626,10 @@ impl TableProvider for ListingTable { Arc::clone(&self.table_schema) } + fn constraints(&self) -> Option<&Constraints> { + Some(&self.constraints) + } + fn table_type(&self) -> TableType { TableType::Base } @@ -721,7 +648,7 @@ impl TableProvider for ListingTable { if partitioned_file_lists.is_empty() { let schema = self.schema(); let projected_schema = project_schema(&schema, projection)?; - return Ok(Arc::new(EmptyExec::new(false, projected_schema))); + return Ok(Arc::new(EmptyExec::new(projected_schema))); } // extract types of partition columns @@ -729,15 +656,7 @@ impl TableProvider for ListingTable { .options .table_partition_cols .iter() - .map(|col| { - Ok(( - col.0.to_owned(), - self.table_schema - .field_with_name(&col.0)? - .data_type() - .clone(), - )) - }) + .map(|col| Ok(self.table_schema.field_with_name(&col.0)?.clone())) .collect::>>()?; let filters = if let Some(expr) = conjunction(filters.to_vec()) { @@ -754,10 +673,10 @@ impl TableProvider for ListingTable { None }; - let object_store_url = if let Some(url) = self.table_paths.get(0) { + let object_store_url = if let Some(url) = self.table_paths.first() { url.object_store() } else { - return Ok(Arc::new(EmptyExec::new(false, Arc::new(Schema::empty())))); + return Ok(Arc::new(EmptyExec::new(Arc::new(Schema::empty())))); }; // create the execution plan self.options @@ -773,7 +692,6 @@ impl TableProvider for ListingTable { limit, output_ordering: self.try_create_output_ordering()?, table_partition_cols, - infinite_source: self.infinite_source, }, filters.as_ref(), ) @@ -813,37 +731,29 @@ impl TableProvider for ListingTable { overwrite: bool, ) -> Result> { // Check that the schema of the plan matches the schema of this table. - if !self.schema().equivalent_names_and_types(&input.schema()) { + if !self + .schema() + .logically_equivalent_names_and_types(&input.schema()) + { return plan_err!( // Return an error if schema of the input query does not match with the table schema. "Inserting query must have the same schema with the table." ); } - if self.table_paths().len() > 1 { + let table_path = &self.table_paths()[0]; + if !table_path.is_collection() { return plan_err!( - "Writing to a table backed by multiple partitions is not supported yet" + "Inserting into a ListingTable backed by a single file is not supported, URL is possibly missing a trailing `/`. \ + To append to an existing file use StreamTable, e.g. by using CREATE UNBOUNDED EXTERNAL TABLE" ); } - // TODO support inserts to sorted tables which preserve sort_order - // Inserts currently make no effort to preserve sort_order. This could lead to - // incorrect query results on the table after inserting incorrectly sorted data. - let unsorted: Vec> = vec![]; - if self.options.file_sort_order != unsorted { - return Err( - DataFusionError::NotImplemented( - "Writing to a sorted listing table via insert into is not supported yet. \ - To write to this table in the meantime, register an equivalent table with \ - file_sort_order = vec![]".into()) - ); - } - - let table_path = &self.table_paths()[0]; // Get the object store for the table path. let store = state.runtime_env().object_store(table_path)?; let file_list_stream = pruned_partition_list( + state, store.as_ref(), table_path, &[], @@ -853,31 +763,6 @@ impl TableProvider for ListingTable { .await?; let file_groups = file_list_stream.try_collect::>().await?; - //if we are writing a single output_partition to a table backed by a single file - //we can append to that file. Otherwise, we can write new files into the directory - //adding new files to the listing table in order to insert to the table. - let input_partitions = input.output_partitioning().partition_count(); - let writer_mode = match self.options.insert_mode { - ListingTableInsertMode::AppendToFile => { - if input_partitions > file_groups.len() { - return Err(DataFusionError::Plan(format!( - "Cannot append {input_partitions} partitions to {} files!", - file_groups.len() - ))); - } - - crate::datasource::file_format::write::FileWriterMode::Append - } - ListingTableInsertMode::AppendNewFiles => { - crate::datasource::file_format::write::FileWriterMode::PutMultipart - } - ListingTableInsertMode::Error => { - return plan_err!( - "Invalid plan attempting write to table with TableWriteMode::Error!" - ); - } - }; - let file_format = self.options().format.as_ref(); let file_type_writer_options = match &self.options().file_type_write_options { @@ -895,24 +780,41 @@ impl TableProvider for ListingTable { file_groups, output_schema: self.schema(), table_partition_cols: self.options.table_partition_cols.clone(), - writer_mode, - // A plan can produce finite number of rows even if it has unbounded sources, like LIMIT - // queries. Thus, we can check if the plan is streaming to ensure file sink input is - // unbounded. When `unbounded_input` flag is `true` for sink, we occasionally call `yield_now` - // to consume data at the input. When `unbounded_input` flag is `false` (e.g non-streaming data), - // all of the data at the input is sink after execution finishes. See discussion for rationale: - // https://github.com/apache/arrow-datafusion/pull/7610#issuecomment-1728979918 - unbounded_input: is_plan_streaming(&input)?, - single_file_output: self.options.single_file, + single_file_output: false, overwrite, file_type_writer_options, }; + let unsorted: Vec> = vec![]; + let order_requirements = if self.options().file_sort_order != unsorted { + // Multiple sort orders in outer vec are equivalent, so we pass only the first one + let ordering = self + .try_create_output_ordering()? + .first() + .ok_or(DataFusionError::Internal( + "Expected ListingTable to have a sort order, but none found!".into(), + ))? + .clone(); + // Converts Vec> into type required by execution plan to specify its required input ordering + Some( + ordering + .into_iter() + .map(PhysicalSortRequirement::from) + .collect::>(), + ) + } else { + None + }; + self.options() .format - .create_writer_physical_plan(input, state, config) + .create_writer_physical_plan(input, state, config, order_requirements) .await } + + fn get_column_default(&self, column: &str) -> Option<&Expr> { + self.column_defaults.get(column) + } } impl ListingTable { @@ -925,14 +827,15 @@ impl ListingTable { filters: &'a [Expr], limit: Option, ) -> Result<(Vec>, Statistics)> { - let store = if let Some(url) = self.table_paths.get(0) { + let store = if let Some(url) = self.table_paths.first() { ctx.runtime_env().object_store(url)? } else { - return Ok((vec![], Statistics::default())); + return Ok((vec![], Statistics::new_unknown(&self.file_schema))); }; // list files (with partitions) let file_list = future::try_join_all(self.table_paths.iter().map(|table_path| { pruned_partition_list( + ctx, store.as_ref(), table_path, filters, @@ -941,14 +844,12 @@ impl ListingTable { ) })) .await?; - let file_list = stream::iter(file_list).flatten(); - // collect the statistics if required by the config let files = file_list .map(|part_file| async { let part_file = part_file?; - let mut statistics_result = Statistics::default(); + let mut statistics_result = Statistics::new_unknown(&self.file_schema); if self.options.collect_stat { let statistics_cache = self.collected_statistics.clone(); match statistics_cache.get_with_extra( @@ -996,58 +897,31 @@ impl ListingTable { #[cfg(test)] mod tests { + use std::collections::HashMap; + use super::*; + #[cfg(feature = "parquet")] + use crate::datasource::file_format::parquet::ParquetFormat; use crate::datasource::{provider_as_source, MemTable}; use crate::execution::options::ArrowReadOptions; use crate::physical_plan::collect; use crate::prelude::*; use crate::{ assert_batches_eq, - datasource::file_format::{ - avro::AvroFormat, file_compression_type::FileTypeExt, parquet::ParquetFormat, - }, - execution::options::ReadOptions, + datasource::file_format::avro::AvroFormat, logical_expr::{col, lit}, test::{columns, object_store::register_test_store}, }; + use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; - use datafusion_common::assert_contains; - use datafusion_common::GetExt; - use datafusion_expr::LogicalPlanBuilder; - use rstest::*; - use std::collections::HashMap; - use std::fs::File; + use arrow_schema::SortOptions; + use datafusion_common::stats::Precision; + use datafusion_common::{assert_contains, GetExt, ScalarValue}; + use datafusion_expr::{BinaryExpr, LogicalPlanBuilder, Operator}; + use datafusion_physical_expr::PhysicalSortExpr; use tempfile::TempDir; - /// It creates dummy file and checks if it can create unbounded input executors. - async fn unbounded_table_helper( - file_type: FileType, - listing_option: ListingOptions, - infinite_data: bool, - ) -> Result<()> { - let ctx = SessionContext::new(); - register_test_store( - &ctx, - &[(&format!("table/file{}", file_type.get_ext()), 100)], - ); - - let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]); - - let table_path = ListingTableUrl::parse("test:///table/").unwrap(); - let config = ListingTableConfig::new(table_path) - .with_listing_options(listing_option) - .with_schema(Arc::new(schema)); - // Create a table - let table = ListingTable::try_new(config)?; - // Create executor from table - let source_exec = table.scan(&ctx.state(), None, &[], None).await?; - - assert_eq!(source_exec.unbounded_output(&[])?, infinite_data); - - Ok(()) - } - #[tokio::test] async fn read_single_file() -> Result<()> { let ctx = SessionContext::new(); @@ -1063,12 +937,13 @@ mod tests { assert_eq!(exec.output_partitioning().partition_count(), 1); // test metadata - assert_eq!(exec.statistics().num_rows, Some(8)); - assert_eq!(exec.statistics().total_byte_size, Some(671)); + assert_eq!(exec.statistics()?.num_rows, Precision::Exact(8)); + assert_eq!(exec.statistics()?.total_byte_size, Precision::Exact(671)); Ok(()) } + #[cfg(feature = "parquet")] #[tokio::test] async fn load_table_stats_by_default() -> Result<()> { let testdata = crate::test_util::parquet_test_data(); @@ -1086,12 +961,13 @@ mod tests { let table = ListingTable::try_new(config)?; let exec = table.scan(&state, None, &[], None).await?; - assert_eq!(exec.statistics().num_rows, Some(8)); - assert_eq!(exec.statistics().total_byte_size, Some(671)); + assert_eq!(exec.statistics()?.num_rows, Precision::Exact(8)); + assert_eq!(exec.statistics()?.total_byte_size, Precision::Exact(671)); Ok(()) } + #[cfg(feature = "parquet")] #[tokio::test] async fn load_table_stats_when_no_stats() -> Result<()> { let testdata = crate::test_util::parquet_test_data(); @@ -1110,12 +986,13 @@ mod tests { let table = ListingTable::try_new(config)?; let exec = table.scan(&state, None, &[], None).await?; - assert_eq!(exec.statistics().num_rows, None); - assert_eq!(exec.statistics().total_byte_size, None); + assert_eq!(exec.statistics()?.num_rows, Precision::Absent); + assert_eq!(exec.statistics()?.total_byte_size, Precision::Absent); Ok(()) } + #[cfg(feature = "parquet")] #[tokio::test] async fn test_try_create_output_ordering() { let testdata = crate::test_util::parquet_test_data(); @@ -1252,99 +1129,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn unbounded_csv_table_without_schema() -> Result<()> { - let tmp_dir = TempDir::new()?; - let file_path = tmp_dir.path().join("dummy.csv"); - File::create(file_path)?; - let ctx = SessionContext::new(); - let error = ctx - .register_csv( - "test", - tmp_dir.path().to_str().unwrap(), - CsvReadOptions::new().mark_infinite(true), - ) - .await - .unwrap_err(); - match error { - DataFusionError::Plan(_) => Ok(()), - val => Err(val), - } - } - - #[tokio::test] - async fn unbounded_json_table_without_schema() -> Result<()> { - let tmp_dir = TempDir::new()?; - let file_path = tmp_dir.path().join("dummy.json"); - File::create(file_path)?; - let ctx = SessionContext::new(); - let error = ctx - .register_json( - "test", - tmp_dir.path().to_str().unwrap(), - NdJsonReadOptions::default().mark_infinite(true), - ) - .await - .unwrap_err(); - match error { - DataFusionError::Plan(_) => Ok(()), - val => Err(val), - } - } - - #[tokio::test] - async fn unbounded_avro_table_without_schema() -> Result<()> { - let tmp_dir = TempDir::new()?; - let file_path = tmp_dir.path().join("dummy.avro"); - File::create(file_path)?; - let ctx = SessionContext::new(); - let error = ctx - .register_avro( - "test", - tmp_dir.path().to_str().unwrap(), - AvroReadOptions::default().mark_infinite(true), - ) - .await - .unwrap_err(); - match error { - DataFusionError::Plan(_) => Ok(()), - val => Err(val), - } - } - - #[rstest] - #[tokio::test] - async fn unbounded_csv_table( - #[values(true, false)] infinite_data: bool, - ) -> Result<()> { - let config = CsvReadOptions::new().mark_infinite(infinite_data); - let session_config = SessionConfig::new().with_target_partitions(1); - let listing_options = config.to_listing_options(&session_config); - unbounded_table_helper(FileType::CSV, listing_options, infinite_data).await - } - - #[rstest] - #[tokio::test] - async fn unbounded_json_table( - #[values(true, false)] infinite_data: bool, - ) -> Result<()> { - let config = NdJsonReadOptions::default().mark_infinite(infinite_data); - let session_config = SessionConfig::new().with_target_partitions(1); - let listing_options = config.to_listing_options(&session_config); - unbounded_table_helper(FileType::JSON, listing_options, infinite_data).await - } - - #[rstest] - #[tokio::test] - async fn unbounded_avro_table( - #[values(true, false)] infinite_data: bool, - ) -> Result<()> { - let config = AvroReadOptions::default().mark_infinite(infinite_data); - let session_config = SessionConfig::new().with_target_partitions(1); - let listing_options = config.to_listing_options(&session_config); - unbounded_table_helper(FileType::AVRO, listing_options, infinite_data).await - } - #[tokio::test] async fn test_assert_list_files_for_scan_grouping() -> Result<()> { // more expected partitions than files @@ -1565,56 +1349,73 @@ mod tests { Ok(()) } - #[tokio::test] - async fn test_insert_into_append_to_json_file() -> Result<()> { - helper_test_insert_into_append_to_existing_files( - FileType::JSON, - FileCompressionType::UNCOMPRESSED, - None, - ) - .await?; - Ok(()) - } - #[tokio::test] async fn test_insert_into_append_new_json_files() -> Result<()> { + let mut config_map: HashMap = HashMap::new(); + config_map.insert("datafusion.execution.batch_size".into(), "10".into()); + config_map.insert( + "datafusion.execution.soft_max_rows_per_output_file".into(), + "10".into(), + ); helper_test_append_new_files_to_table( FileType::JSON, FileCompressionType::UNCOMPRESSED, - None, + Some(config_map), + 2, ) .await?; Ok(()) } #[tokio::test] - async fn test_insert_into_append_to_csv_file() -> Result<()> { - helper_test_insert_into_append_to_existing_files( + async fn test_insert_into_append_new_csv_files() -> Result<()> { + let mut config_map: HashMap = HashMap::new(); + config_map.insert("datafusion.execution.batch_size".into(), "10".into()); + config_map.insert( + "datafusion.execution.soft_max_rows_per_output_file".into(), + "10".into(), + ); + helper_test_append_new_files_to_table( FileType::CSV, FileCompressionType::UNCOMPRESSED, - None, + Some(config_map), + 2, ) .await?; Ok(()) } #[tokio::test] - async fn test_insert_into_append_new_csv_files() -> Result<()> { + async fn test_insert_into_append_2_new_parquet_files_defaults() -> Result<()> { + let mut config_map: HashMap = HashMap::new(); + config_map.insert("datafusion.execution.batch_size".into(), "10".into()); + config_map.insert( + "datafusion.execution.soft_max_rows_per_output_file".into(), + "10".into(), + ); helper_test_append_new_files_to_table( - FileType::CSV, + FileType::PARQUET, FileCompressionType::UNCOMPRESSED, - None, + Some(config_map), + 2, ) .await?; Ok(()) } #[tokio::test] - async fn test_insert_into_append_new_parquet_files_defaults() -> Result<()> { + async fn test_insert_into_append_1_new_parquet_files_defaults() -> Result<()> { + let mut config_map: HashMap = HashMap::new(); + config_map.insert("datafusion.execution.batch_size".into(), "20".into()); + config_map.insert( + "datafusion.execution.soft_max_rows_per_output_file".into(), + "20".into(), + ); helper_test_append_new_files_to_table( FileType::PARQUET, FileCompressionType::UNCOMPRESSED, - None, + Some(config_map), + 1, ) .await?; Ok(()) @@ -1622,13 +1423,8 @@ mod tests { #[tokio::test] async fn test_insert_into_sql_csv_defaults() -> Result<()> { - helper_test_insert_into_sql( - "csv", - FileCompressionType::UNCOMPRESSED, - "OPTIONS (insert_mode 'append_new_files')", - None, - ) - .await?; + helper_test_insert_into_sql("csv", FileCompressionType::UNCOMPRESSED, "", None) + .await?; Ok(()) } @@ -1637,8 +1433,7 @@ mod tests { helper_test_insert_into_sql( "csv", FileCompressionType::UNCOMPRESSED, - "WITH HEADER ROW \ - OPTIONS (insert_mode 'append_new_files')", + "WITH HEADER ROW", None, ) .await?; @@ -1647,13 +1442,8 @@ mod tests { #[tokio::test] async fn test_insert_into_sql_json_defaults() -> Result<()> { - helper_test_insert_into_sql( - "json", - FileCompressionType::UNCOMPRESSED, - "OPTIONS (insert_mode 'append_new_files')", - None, - ) - .await?; + helper_test_insert_into_sql("json", FileCompressionType::UNCOMPRESSED, "", None) + .await?; Ok(()) } @@ -1741,6 +1531,11 @@ mod tests { #[tokio::test] async fn test_insert_into_append_new_parquet_files_session_overrides() -> Result<()> { let mut config_map: HashMap = HashMap::new(); + config_map.insert("datafusion.execution.batch_size".into(), "10".into()); + config_map.insert( + "datafusion.execution.soft_max_rows_per_output_file".into(), + "10".into(), + ); config_map.insert( "datafusion.execution.parquet.compression".into(), "zstd(5)".into(), @@ -1801,10 +1596,12 @@ mod tests { "datafusion.execution.parquet.write_batch_size".into(), "5".into(), ); + config_map.insert("datafusion.execution.batch_size".into(), "1".into()); helper_test_append_new_files_to_table( FileType::PARQUET, FileCompressionType::UNCOMPRESSED, Some(config_map), + 2, ) .await?; Ok(()) @@ -1822,6 +1619,7 @@ mod tests { FileType::PARQUET, FileCompressionType::UNCOMPRESSED, Some(config_map), + 2, ) .await .expect_err("Example should fail!"); @@ -1830,221 +1628,17 @@ mod tests { Ok(()) } - #[tokio::test] - async fn test_insert_into_append_to_parquet_file_fails() -> Result<()> { - let maybe_err = helper_test_insert_into_append_to_existing_files( - FileType::PARQUET, - FileCompressionType::UNCOMPRESSED, - None, - ) - .await; - let _err = - maybe_err.expect_err("Appending to existing parquet file did not fail!"); - Ok(()) - } - - fn load_empty_schema_table( - schema: SchemaRef, - temp_path: &str, - insert_mode: ListingTableInsertMode, - file_format: Arc, - ) -> Result> { - File::create(temp_path)?; - let table_path = ListingTableUrl::parse(temp_path).unwrap(); - - let listing_options = - ListingOptions::new(file_format.clone()).with_insert_mode(insert_mode); - - let config = ListingTableConfig::new(table_path) - .with_listing_options(listing_options) - .with_schema(schema); - - let table = ListingTable::try_new(config)?; - Ok(Arc::new(table)) - } - - /// Logic of testing inserting into listing table by Appending to existing files - /// is the same for all formats/options which support this. This helper allows - /// passing different options to execute the same test with different settings. - async fn helper_test_insert_into_append_to_existing_files( - file_type: FileType, - file_compression_type: FileCompressionType, - session_config_map: Option>, - ) -> Result<()> { - // Create the initial context, schema, and batch. - let session_ctx = match session_config_map { - Some(cfg) => { - let config = SessionConfig::from_string_hash_map(cfg)?; - SessionContext::with_config(config) - } - None => SessionContext::new(), - }; - // Create a new schema with one field called "a" of type Int32 - let schema = Arc::new(Schema::new(vec![Field::new( - "column1", - DataType::Int32, - false, - )])); - - // Create a new batch of data to insert into the table - let batch = RecordBatch::try_new( - schema.clone(), - vec![Arc::new(arrow_array::Int32Array::from(vec![1, 2, 3]))], - )?; - - // Filename with extension - let filename = format!( - "path{}", - file_type - .to_owned() - .get_ext_with_compression(file_compression_type) - .unwrap() - ); - - // Create a temporary directory and a CSV file within it. - let tmp_dir = TempDir::new()?; - let path = tmp_dir.path().join(filename); - - let file_format: Arc = match file_type { - FileType::CSV => Arc::new( - CsvFormat::default().with_file_compression_type(file_compression_type), - ), - FileType::JSON => Arc::new( - JsonFormat::default().with_file_compression_type(file_compression_type), - ), - FileType::PARQUET => Arc::new(ParquetFormat::default()), - FileType::AVRO => Arc::new(AvroFormat {}), - FileType::ARROW => Arc::new(ArrowFormat {}), - }; - - let initial_table = load_empty_schema_table( - schema.clone(), - path.to_str().unwrap(), - ListingTableInsertMode::AppendToFile, - file_format, - )?; - session_ctx.register_table("t", initial_table)?; - // Create and register the source table with the provided schema and inserted data - let source_table = Arc::new(MemTable::try_new( - schema.clone(), - vec![vec![batch.clone(), batch.clone()]], - )?); - session_ctx.register_table("source", source_table.clone())?; - // Convert the source table into a provider so that it can be used in a query - let source = provider_as_source(source_table); - // Create a table scan logical plan to read from the source table - let scan_plan = LogicalPlanBuilder::scan("source", source, None)?.build()?; - // Create an insert plan to insert the source data into the initial table - let insert_into_table = - LogicalPlanBuilder::insert_into(scan_plan, "t", &schema, false)?.build()?; - // Create a physical plan from the insert plan - let plan = session_ctx - .state() - .create_physical_plan(&insert_into_table) - .await?; - - // Execute the physical plan and collect the results - let res = collect(plan, session_ctx.task_ctx()).await?; - // Insert returns the number of rows written, in our case this would be 6. - let expected = [ - "+-------+", - "| count |", - "+-------+", - "| 6 |", - "+-------+", - ]; - - // Assert that the batches read from the file match the expected result. - assert_batches_eq!(expected, &res); - - // Read the records in the table - let batches = session_ctx.sql("select * from t").await?.collect().await?; - - // Define the expected result as a vector of strings. - let expected = [ - "+---------+", - "| column1 |", - "+---------+", - "| 1 |", - "| 2 |", - "| 3 |", - "| 1 |", - "| 2 |", - "| 3 |", - "+---------+", - ]; - - // Assert that the batches read from the file match the expected result. - assert_batches_eq!(expected, &batches); - - // Assert that only 1 file was added to the table - let num_files = tmp_dir.path().read_dir()?.count(); - assert_eq!(num_files, 1); - - // Create a physical plan from the insert plan - let plan = session_ctx - .state() - .create_physical_plan(&insert_into_table) - .await?; - - // Again, execute the physical plan and collect the results - let res = collect(plan, session_ctx.task_ctx()).await?; - // Insert returns the number of rows written, in our case this would be 6. - let expected = [ - "+-------+", - "| count |", - "+-------+", - "| 6 |", - "+-------+", - ]; - - // Assert that the batches read from the file match the expected result. - assert_batches_eq!(expected, &res); - - // Open the CSV file, read its contents as a record batch, and collect the batches into a vector. - let batches = session_ctx.sql("select * from t").await?.collect().await?; - - // Define the expected result after the second append. - let expected = vec![ - "+---------+", - "| column1 |", - "+---------+", - "| 1 |", - "| 2 |", - "| 3 |", - "| 1 |", - "| 2 |", - "| 3 |", - "| 1 |", - "| 2 |", - "| 3 |", - "| 1 |", - "| 2 |", - "| 3 |", - "+---------+", - ]; - - // Assert that the batches read from the file after the second append match the expected result. - assert_batches_eq!(expected, &batches); - - // Assert that no additional files were added to the table - let num_files = tmp_dir.path().read_dir()?.count(); - assert_eq!(num_files, 1); - - // Return Ok if the function - Ok(()) - } - async fn helper_test_append_new_files_to_table( file_type: FileType, file_compression_type: FileCompressionType, session_config_map: Option>, + expected_n_files_per_insert: usize, ) -> Result<()> { // Create the initial context, schema, and batch. let session_ctx = match session_config_map { Some(cfg) => { let config = SessionConfig::from_string_hash_map(cfg)?; - SessionContext::with_config(config) + SessionContext::new_with_config(config) } None => SessionContext::new(), }; @@ -2056,10 +1650,18 @@ mod tests { false, )])); + let filter_predicate = Expr::BinaryExpr(BinaryExpr::new( + Box::new(Expr::Column("column1".into())), + Operator::GtEq, + Box::new(Expr::Literal(ScalarValue::Int32(Some(0)))), + )); + // Create a new batch of data to insert into the table let batch = RecordBatch::try_new( schema.clone(), - vec![Arc::new(arrow_array::Int32Array::from(vec![1, 2, 3]))], + vec![Arc::new(arrow_array::Int32Array::from(vec![ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + ]))], )?; // Register appropriate table depending on file_type we want to test @@ -2071,7 +1673,6 @@ mod tests { "t", tmp_dir.path().to_str().unwrap(), CsvReadOptions::new() - .insert_mode(ListingTableInsertMode::AppendNewFiles) .schema(schema.as_ref()) .file_compression_type(file_compression_type), ) @@ -2083,7 +1684,6 @@ mod tests { "t", tmp_dir.path().to_str().unwrap(), NdJsonReadOptions::default() - .insert_mode(ListingTableInsertMode::AppendNewFiles) .schema(schema.as_ref()) .file_compression_type(file_compression_type), ) @@ -2094,9 +1694,7 @@ mod tests { .register_parquet( "t", tmp_dir.path().to_str().unwrap(), - ParquetReadOptions::default() - .insert_mode(ListingTableInsertMode::AppendNewFiles) - .schema(schema.as_ref()), + ParquetReadOptions::default().schema(schema.as_ref()), ) .await?; } @@ -2105,10 +1703,7 @@ mod tests { .register_avro( "t", tmp_dir.path().to_str().unwrap(), - AvroReadOptions::default() - // TODO implement insert_mode for avro - //.insert_mode(ListingTableInsertMode::AppendNewFiles) - .schema(schema.as_ref()), + AvroReadOptions::default().schema(schema.as_ref()), ) .await?; } @@ -2117,10 +1712,7 @@ mod tests { .register_arrow( "t", tmp_dir.path().to_str().unwrap(), - ArrowReadOptions::default() - // TODO implement insert_mode for arrow - //.insert_mode(ListingTableInsertMode::AppendNewFiles) - .schema(schema.as_ref()), + ArrowReadOptions::default().schema(schema.as_ref()), ) .await?; } @@ -2136,8 +1728,10 @@ mod tests { let source = provider_as_source(source_table); // Create a table scan logical plan to read from the source table let scan_plan = LogicalPlanBuilder::scan("source", source, None)? - .repartition(Partitioning::Hash(vec![Expr::Column("column1".into())], 6))? + .filter(filter_predicate)? .build()?; + // Since logical plan contains a filter, increasing parallelism is helpful. + // Therefore, we will have 8 partitions in the final plan. // Create an insert plan to insert the source data into the initial table let insert_into_table = LogicalPlanBuilder::insert_into(scan_plan, "t", &schema, false)?.build()?; @@ -2146,7 +1740,6 @@ mod tests { .state() .create_physical_plan(&insert_into_table) .await?; - // Execute the physical plan and collect the results let res = collect(plan, session_ctx.task_ctx()).await?; // Insert returns the number of rows written, in our case this would be 6. @@ -2154,7 +1747,7 @@ mod tests { "+-------+", "| count |", "+-------+", - "| 6 |", + "| 20 |", "+-------+", ]; @@ -2171,16 +1764,16 @@ mod tests { "+-------+", "| count |", "+-------+", - "| 6 |", + "| 20 |", "+-------+", ]; // Assert that the batches read from the file match the expected result. assert_batches_eq!(expected, &batches); - // Assert that 6 files were added to the table + // Assert that `target_partition_number` many files were added to the table. let num_files = tmp_dir.path().read_dir()?.count(); - assert_eq!(num_files, 6); + assert_eq!(num_files, expected_n_files_per_insert); // Create a physical plan from the insert plan let plan = session_ctx @@ -2195,7 +1788,7 @@ mod tests { "+-------+", "| count |", "+-------+", - "| 6 |", + "| 20 |", "+-------+", ]; @@ -2214,16 +1807,16 @@ mod tests { "+-------+", "| count |", "+-------+", - "| 12 |", + "| 40 |", "+-------+", ]; // Assert that the batches read from the file after the second append match the expected result. assert_batches_eq!(expected, &batches); - // Assert that another 6 files were added to the table + // Assert that another `target_partition_number` many files were added to the table. let num_files = tmp_dir.path().read_dir()?.count(); - assert_eq!(num_files, 12); + assert_eq!(num_files, expected_n_files_per_insert * 2); // Return Ok if the function Ok(()) @@ -2242,7 +1835,7 @@ mod tests { let session_ctx = match session_config_map { Some(cfg) => { let config = SessionConfig::from_string_hash_map(cfg)?; - SessionContext::with_config(config) + SessionContext::new_with_config(config) } None => SessionContext::new(), }; diff --git a/datafusion/core/src/datasource/listing/url.rs b/datafusion/core/src/datasource/listing/url.rs index 96998de17b5d..766dee7de901 100644 --- a/datafusion/core/src/datasource/listing/url.rs +++ b/datafusion/core/src/datasource/listing/url.rs @@ -18,14 +18,17 @@ use std::fs; use crate::datasource::object_store::ObjectStoreUrl; +use crate::execution::context::SessionState; use datafusion_common::{DataFusionError, Result}; +use datafusion_optimizer::OptimizerConfig; use futures::stream::BoxStream; use futures::{StreamExt, TryStreamExt}; use glob::Pattern; use itertools::Itertools; +use log::debug; use object_store::path::Path; use object_store::{ObjectMeta, ObjectStore}; -use percent_encoding; +use std::sync::Arc; use url::Url; /// A parsed URL identifying files for a listing table, see [`ListingTableUrl::parse`] @@ -43,22 +46,45 @@ pub struct ListingTableUrl { impl ListingTableUrl { /// Parse a provided string as a `ListingTableUrl` /// + /// A URL can either refer to a single object, or a collection of objects with a + /// common prefix, with the presence of a trailing `/` indicating a collection. + /// + /// For example, `file:///foo.txt` refers to the file at `/foo.txt`, whereas + /// `file:///foo/` refers to all the files under the directory `/foo` and its + /// subdirectories. + /// + /// Similarly `s3://BUCKET/blob.csv` refers to `blob.csv` in the S3 bucket `BUCKET`, + /// wherease `s3://BUCKET/foo/` refers to all objects with the prefix `foo/` in the + /// S3 bucket `BUCKET` + /// + /// # URL Encoding + /// + /// URL paths are expected to be URL-encoded. That is, the URL for a file named `bar%2Efoo` + /// would be `file:///bar%252Efoo`, as per the [URL] specification. + /// + /// It should be noted that some tools, such as the AWS CLI, take a different approach and + /// instead interpret the URL path verbatim. For example the object `bar%2Efoo` would be + /// addressed as `s3://BUCKET/bar%252Efoo` using [`ListingTableUrl`] but `s3://BUCKET/bar%2Efoo` + /// when using the aws-cli. + /// /// # Paths without a Scheme /// /// If no scheme is provided, or the string is an absolute filesystem path - /// as determined [`std::path::Path::is_absolute`], the string will be + /// as determined by [`std::path::Path::is_absolute`], the string will be /// interpreted as a path on the local filesystem using the operating /// system's standard path delimiter, i.e. `\` on Windows, `/` on Unix. /// /// If the path contains any of `'?', '*', '['`, it will be considered /// a glob expression and resolved as described in the section below. /// - /// Otherwise, the path will be resolved to an absolute path, returning - /// an error if it does not exist, and converted to a [file URI] + /// Otherwise, the path will be resolved to an absolute path based on the current + /// working directory, and converted to a [file URI]. /// - /// If you wish to specify a path that does not exist on the local - /// machine you must provide it as a fully-qualified [file URI] - /// e.g. `file:///myfile.txt` + /// If the path already exists in the local filesystem this will be used to determine if this + /// [`ListingTableUrl`] refers to a collection or a single object, otherwise the presence + /// of a trailing path delimiter will be used to indicate a directory. For the avoidance + /// of ambiguity it is recommended users always include trailing `/` when intending to + /// refer to a directory. /// /// ## Glob File Paths /// @@ -66,14 +92,13 @@ impl ListingTableUrl { /// be resolved as follows. /// /// The string up to the first path segment containing a glob expression will be extracted, - /// and resolved in the same manner as a normal scheme-less path. That is, resolved to - /// an absolute path on the local filesystem, returning an error if it does not exist, - /// and converted to a [file URI] + /// and resolved in the same manner as a normal scheme-less path above. /// /// The remaining string will be interpreted as a [`glob::Pattern`] and used as a /// filter when listing files from object storage /// /// [file URI]: https://en.wikipedia.org/wiki/File_URI_scheme + /// [URL]: https://url.spec.whatwg.org/ pub fn parse(s: impl AsRef) -> Result { let s = s.as_ref(); @@ -83,7 +108,7 @@ impl ListingTableUrl { } match Url::parse(s) { - Ok(url) => Ok(Self::new(url, None)), + Ok(url) => Self::try_new(url, None), Err(url::ParseError::RelativeUrlWithoutBase) => Self::parse_path(s), Err(e) => Err(DataFusionError::External(Box::new(e))), } @@ -92,6 +117,7 @@ impl ListingTableUrl { /// Get object store for specified input_url /// if input_url is actually not a url, we assume it is a local file path /// if we have a local path, create it if not exists so ListingTableUrl::parse works + #[deprecated(note = "Use parse")] pub fn parse_create_local_if_not_exists( s: impl AsRef, is_directory: bool, @@ -107,6 +133,10 @@ impl ListingTableUrl { if is_directory { fs::create_dir_all(path)?; } else { + // ensure parent directory exists + if let Some(parent) = path.parent() { + fs::create_dir_all(parent)?; + } fs::File::create(path)?; } } @@ -117,7 +147,7 @@ impl ListingTableUrl { /// Creates a new [`ListingTableUrl`] interpreting `s` as a filesystem path fn parse_path(s: &str) -> Result { - let (prefix, glob) = match split_glob_expression(s) { + let (path, glob) = match split_glob_expression(s) { Some((prefix, glob)) => { let glob = Pattern::new(glob) .map_err(|e| DataFusionError::External(Box::new(e)))?; @@ -126,24 +156,19 @@ impl ListingTableUrl { None => (s, None), }; - let path = std::path::Path::new(prefix).canonicalize()?; - let url = if path.is_dir() { - Url::from_directory_path(path) - } else { - Url::from_file_path(path) - } - .map_err(|_| DataFusionError::Internal(format!("Can not open path: {s}")))?; - // TODO: Currently we do not have an IO-related error variant that accepts () - // or a string. Once we have such a variant, change the error type above. - Ok(Self::new(url, glob)) + let url = url_from_filesystem_path(path).ok_or_else(|| { + DataFusionError::External( + format!("Failed to convert path to URL: {path}").into(), + ) + })?; + + Self::try_new(url, glob) } /// Creates a new [`ListingTableUrl`] from a url and optional glob expression - fn new(url: Url, glob: Option) -> Self { - let decoded_path = - percent_encoding::percent_decode_str(url.path()).decode_utf8_lossy(); - let prefix = Path::from(decoded_path.as_ref()); - Self { url, prefix, glob } + fn try_new(url: Url, glob: Option) -> Result { + let prefix = Path::from_url_path(url.path())?; + Ok(Self { url, prefix, glob }) } /// Returns the URL scheme @@ -151,25 +176,46 @@ impl ListingTableUrl { self.url.scheme() } - /// Return the prefix from which to list files + /// Return the URL path not excluding any glob expression + /// + /// If [`Self::is_collection`], this is the listing prefix + /// Otherwise, this is the path to the object pub fn prefix(&self) -> &Path { &self.prefix } /// Returns `true` if `path` matches this [`ListingTableUrl`] - pub fn contains(&self, path: &Path) -> bool { + pub fn contains(&self, path: &Path, ignore_subdirectory: bool) -> bool { match self.strip_prefix(path) { Some(mut segments) => match &self.glob { Some(glob) => { - let stripped = segments.join("/"); - glob.matches(&stripped) + if ignore_subdirectory { + segments + .next() + .map_or(false, |file_name| glob.matches(file_name)) + } else { + let stripped = segments.join("/"); + glob.matches(&stripped) + } + } + None => { + if ignore_subdirectory { + let has_subdirectory = segments.collect::>().len() > 1; + !has_subdirectory + } else { + true + } } - None => true, }, None => false, } } + /// Returns `true` if `path` refers to a collection of objects + pub fn is_collection(&self) -> bool { + self.url.as_str().ends_with('/') + } + /// Strips the prefix of this [`ListingTableUrl`] from the provided path, returning /// an iterator of the remaining path segments pub(crate) fn strip_prefix<'a, 'b: 'a>( @@ -185,28 +231,42 @@ impl ListingTableUrl { } /// List all files identified by this [`ListingTableUrl`] for the provided `file_extension` - pub(crate) fn list_all_files<'a>( + pub(crate) async fn list_all_files<'a>( &'a self, + ctx: &'a SessionState, store: &'a dyn ObjectStore, file_extension: &'a str, - ) -> BoxStream<'a, Result> { + ) -> Result>> { + let exec_options = &ctx.options().execution; + let ignore_subdirectory = exec_options.listing_table_ignore_subdirectory; // If the prefix is a file, use a head request, otherwise list - let is_dir = self.url.as_str().ends_with('/'); - let list = match is_dir { - true => futures::stream::once(store.list(Some(&self.prefix))) - .try_flatten() - .boxed(), + let list = match self.is_collection() { + true => match ctx.runtime_env().cache_manager.get_list_files_cache() { + None => store.list(Some(&self.prefix)), + Some(cache) => { + if let Some(res) = cache.get(&self.prefix) { + debug!("Hit list all files cache"); + futures::stream::iter(res.as_ref().clone().into_iter().map(Ok)) + .boxed() + } else { + let list_res = store.list(Some(&self.prefix)); + let vec = list_res.try_collect::>().await?; + cache.put(&self.prefix, Arc::new(vec.clone())); + futures::stream::iter(vec.into_iter().map(Ok)).boxed() + } + } + }, false => futures::stream::once(store.head(&self.prefix)).boxed(), }; - - list.map_err(Into::into) + Ok(list .try_filter(move |meta| { let path = &meta.location; let extension_match = path.as_ref().ends_with(file_extension); - let glob_match = self.contains(path); + let glob_match = self.contains(path, ignore_subdirectory); futures::future::ready(extension_match && glob_match) }) - .boxed() + .map_err(DataFusionError::ObjectStore) + .boxed()) } /// Returns this [`ListingTableUrl`] as a string @@ -221,6 +281,34 @@ impl ListingTableUrl { } } +/// Creates a file URL from a potentially relative filesystem path +fn url_from_filesystem_path(s: &str) -> Option { + let path = std::path::Path::new(s); + let is_dir = match path.exists() { + true => path.is_dir(), + // Fallback to inferring from trailing separator + false => std::path::is_separator(s.chars().last()?), + }; + + let from_absolute_path = |p| { + let first = match is_dir { + true => Url::from_directory_path(p).ok(), + false => Url::from_file_path(p).ok(), + }?; + + // By default from_*_path preserve relative path segments + // We therefore parse the URL again to resolve these + Url::parse(first.as_str()).ok() + }; + + if path.is_absolute() { + return from_absolute_path(path); + } + + let absolute = std::env::current_dir().ok()?.join(path); + from_absolute_path(&absolute) +} + impl AsRef for ListingTableUrl { fn as_ref(&self) -> &str { self.url.as_ref() @@ -268,6 +356,7 @@ fn split_glob_expression(path: &str) -> Option<(&str, &str)> { #[cfg(test)] mod tests { use super::*; + use tempfile::tempdir; #[test] fn test_prefix_path() { @@ -300,7 +389,57 @@ mod tests { assert_eq!(url.prefix.as_ref(), "foo/bar"); let url = ListingTableUrl::parse("file:///foo/😺").unwrap(); - assert_eq!(url.prefix.as_ref(), "foo/%F0%9F%98%BA"); + assert_eq!(url.prefix.as_ref(), "foo/😺"); + + let url = ListingTableUrl::parse("file:///foo/bar%2Efoo").unwrap(); + assert_eq!(url.prefix.as_ref(), "foo/bar.foo"); + + let url = ListingTableUrl::parse("file:///foo/bar%2Efoo").unwrap(); + assert_eq!(url.prefix.as_ref(), "foo/bar.foo"); + + let url = ListingTableUrl::parse("file:///foo/bar%252Ffoo").unwrap(); + assert_eq!(url.prefix.as_ref(), "foo/bar%2Ffoo"); + + let url = ListingTableUrl::parse("file:///foo/a%252Fb.txt").unwrap(); + assert_eq!(url.prefix.as_ref(), "foo/a%2Fb.txt"); + + let dir = tempdir().unwrap(); + let path = dir.path().join("bar%2Ffoo"); + std::fs::File::create(&path).unwrap(); + + let url = ListingTableUrl::parse(path.to_str().unwrap()).unwrap(); + assert!(url.prefix.as_ref().ends_with("bar%2Ffoo"), "{}", url.prefix); + + let url = ListingTableUrl::parse("file:///foo/../a%252Fb.txt").unwrap(); + assert_eq!(url.prefix.as_ref(), "a%2Fb.txt"); + + let url = + ListingTableUrl::parse("file:///foo/./bar/../../baz/./test.txt").unwrap(); + assert_eq!(url.prefix.as_ref(), "baz/test.txt"); + + let workdir = std::env::current_dir().unwrap(); + let t = workdir.join("non-existent"); + let a = ListingTableUrl::parse(t.to_str().unwrap()).unwrap(); + let b = ListingTableUrl::parse("non-existent").unwrap(); + assert_eq!(a, b); + assert!(a.prefix.as_ref().ends_with("non-existent")); + + let t = workdir.parent().unwrap(); + let a = ListingTableUrl::parse(t.to_str().unwrap()).unwrap(); + let b = ListingTableUrl::parse("..").unwrap(); + assert_eq!(a, b); + + let t = t.join("bar"); + let a = ListingTableUrl::parse(t.to_str().unwrap()).unwrap(); + let b = ListingTableUrl::parse("../bar").unwrap(); + assert_eq!(a, b); + assert!(a.prefix.as_ref().ends_with("bar")); + + let t = t.join(".").join("foo").join("..").join("baz"); + let a = ListingTableUrl::parse(t.to_str().unwrap()).unwrap(); + let b = ListingTableUrl::parse("../bar/./foo/../baz").unwrap(); + assert_eq!(a, b); + assert!(a.prefix.as_ref().ends_with("bar/baz")); } #[test] diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index 9c438a47943f..e8ffece320d7 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -21,42 +21,34 @@ use std::path::Path; use std::str::FromStr; use std::sync::Arc; -use arrow::datatypes::{DataType, SchemaRef}; -use async_trait::async_trait; -use datafusion_common::file_options::{FileTypeWriterOptions, StatementOptions}; -use datafusion_common::DataFusionError; -use datafusion_expr::CreateExternalTable; - -use crate::datasource::file_format::arrow::ArrowFormat; -use crate::datasource::file_format::avro::AvroFormat; -use crate::datasource::file_format::csv::CsvFormat; -use crate::datasource::file_format::file_compression_type::FileCompressionType; -use crate::datasource::file_format::json::JsonFormat; +#[cfg(feature = "parquet")] use crate::datasource::file_format::parquet::ParquetFormat; -use crate::datasource::file_format::FileFormat; +use crate::datasource::file_format::{ + arrow::ArrowFormat, avro::AvroFormat, csv::CsvFormat, + file_compression_type::FileCompressionType, json::JsonFormat, FileFormat, +}; use crate::datasource::listing::{ ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, }; use crate::datasource::provider::TableProviderFactory; use crate::datasource::TableProvider; use crate::execution::context::SessionState; -use datafusion_common::FileType; -use super::listing::ListingTableInsertMode; +use arrow::datatypes::{DataType, SchemaRef}; +use datafusion_common::file_options::{FileTypeWriterOptions, StatementOptions}; +use datafusion_common::{arrow_datafusion_err, plan_err, DataFusionError, FileType}; +use datafusion_expr::CreateExternalTable; + +use async_trait::async_trait; /// A `TableProviderFactory` capable of creating new `ListingTable`s +#[derive(Debug, Default)] pub struct ListingTableFactory {} impl ListingTableFactory { /// Creates a new `ListingTableFactory` pub fn new() -> Self { - Self {} - } -} - -impl Default for ListingTableFactory { - fn default() -> Self { - Self::new() + Self::default() } } @@ -75,12 +67,21 @@ impl TableProviderFactory for ListingTableFactory { let file_extension = get_extension(cmd.location.as_str()); let file_format: Arc = match file_type { - FileType::CSV => Arc::new( - CsvFormat::default() + FileType::CSV => { + let mut statement_options = StatementOptions::from(&cmd.options); + let mut csv_format = CsvFormat::default() .with_has_header(cmd.has_header) .with_delimiter(cmd.delimiter as u8) - .with_file_compression_type(file_compression_type), - ), + .with_file_compression_type(file_compression_type); + if let Some(quote) = statement_options.take_str_option("quote") { + csv_format = csv_format.with_quote(quote.as_bytes()[0]) + } + if let Some(escape) = statement_options.take_str_option("escape") { + csv_format = csv_format.with_escape(Some(escape.as_bytes()[0])) + } + Arc::new(csv_format) + } + #[cfg(feature = "parquet")] FileType::PARQUET => Arc::new(ParquetFormat::default()), FileType::AVRO => Arc::new(AvroFormat), FileType::JSON => Arc::new( @@ -113,7 +114,7 @@ impl TableProviderFactory for ListingTableFactory { .map(|col| { schema .field_with_name(col) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) }) .collect::>>()? .into_iter() @@ -132,39 +133,17 @@ impl TableProviderFactory for ListingTableFactory { (Some(schema), table_partition_cols) }; - // look for 'infinite' as an option - let infinite_source = cmd.unbounded; - let mut statement_options = StatementOptions::from(&cmd.options); - // Extract ListingTable specific options if present or set default - let unbounded = if infinite_source { - statement_options.take_str_option("unbounded"); - infinite_source - } else { - statement_options - .take_bool_option("unbounded")? - .unwrap_or(false) - }; - - let create_local_path = statement_options - .take_bool_option("create_local_path")? - .unwrap_or(false); - let single_file = statement_options - .take_bool_option("single_file")? - .unwrap_or(false); - - let explicit_insert_mode = statement_options.take_str_option("insert_mode"); - let insert_mode = match explicit_insert_mode { - Some(mode) => ListingTableInsertMode::from_str(mode.as_str()), - None => match file_type { - FileType::CSV => Ok(ListingTableInsertMode::AppendToFile), - FileType::PARQUET => Ok(ListingTableInsertMode::AppendNewFiles), - FileType::AVRO => Ok(ListingTableInsertMode::AppendNewFiles), - FileType::JSON => Ok(ListingTableInsertMode::AppendToFile), - FileType::ARROW => Ok(ListingTableInsertMode::AppendNewFiles), - }, - }?; + // Backwards compatibility (#8547), discard deprecated options + statement_options.take_bool_option("single_file")?; + if let Some(s) = statement_options.take_str_option("insert_mode") { + if !s.eq_ignore_ascii_case("append_new_files") { + return plan_err!("Unknown or unsupported insert mode {s}. Only append_new_files supported"); + } + } + statement_options.take_bool_option("create_local_path")?; + statement_options.take_str_option("unbounded"); let file_type = file_format.file_type(); @@ -181,10 +160,9 @@ impl TableProviderFactory for ListingTableFactory { FileType::CSV => { let mut csv_writer_options = file_type_writer_options.try_into_csv()?.clone(); - csv_writer_options.has_header = cmd.has_header; csv_writer_options.writer_options = csv_writer_options .writer_options - .has_headers(cmd.has_header) + .with_header(cmd.has_header) .with_delimiter(cmd.delimiter.try_into().map_err(|_| { DataFusionError::Internal( "Unable to convert CSV delimiter into u8".into(), @@ -199,18 +177,13 @@ impl TableProviderFactory for ListingTableFactory { json_writer_options.compression = cmd.file_compression_type; FileTypeWriterOptions::JSON(json_writer_options) } + #[cfg(feature = "parquet")] FileType::PARQUET => file_type_writer_options, FileType::ARROW => file_type_writer_options, FileType::AVRO => file_type_writer_options, }; - let table_path = match create_local_path { - true => ListingTableUrl::parse_create_local_if_not_exists( - &cmd.location, - !single_file, - ), - false => ListingTableUrl::parse(&cmd.location), - }?; + let table_path = ListingTableUrl::parse(&cmd.location)?; let options = ListingOptions::new(file_format) .with_collect_stat(state.config().collect_statistics()) @@ -218,10 +191,7 @@ impl TableProviderFactory for ListingTableFactory { .with_target_partitions(state.config().target_partitions()) .with_table_partition_cols(table_partition_cols) .with_file_sort_order(cmd.order_exprs.clone()) - .with_insert_mode(insert_mode) - .with_single_file(single_file) - .with_write_options(file_type_writer_options) - .with_infinite_source(unbounded); + .with_write_options(file_type_writer_options); let resolved_schema = match provided_schema { None => options.infer_schema(state, &table_path).await?, @@ -232,7 +202,10 @@ impl TableProviderFactory for ListingTableFactory { .with_schema(resolved_schema); let provider = ListingTable::try_new(config)? .with_cache(state.runtime_env().cache_manager.get_file_statistic_cache()); - let table = provider.with_definition(cmd.definition.clone()); + let table = provider + .with_definition(cmd.definition.clone()) + .with_constraints(cmd.constraints.clone()) + .with_column_defaults(cmd.column_defaults.clone()); Ok(Arc::new(table)) } } @@ -248,13 +221,13 @@ fn get_extension(path: &str) -> String { #[cfg(test)] mod tests { - use super::*; - use std::collections::HashMap; + use super::*; use crate::execution::context::SessionContext; + use datafusion_common::parsers::CompressionTypeVariant; - use datafusion_common::{DFSchema, OwnedTableReference}; + use datafusion_common::{Constraints, DFSchema, OwnedTableReference}; #[tokio::test] async fn test_create_using_non_std_file_ext() { @@ -282,6 +255,8 @@ mod tests { order_exprs: vec![], unbounded: false, options: HashMap::new(), + constraints: Constraints::empty(), + column_defaults: HashMap::new(), }; let table_provider = factory.create(&state, &cmd).await.unwrap(); let listing_table = table_provider diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index 337a8cabc269..7c61cc536860 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -17,9 +17,11 @@ //! [`MemTable`] for querying `Vec` by DataFusion. +use datafusion_physical_plan::metrics::MetricsSet; use futures::StreamExt; use log::debug; use std::any::Any; +use std::collections::HashMap; use std::fmt::{self, Debug}; use std::sync::Arc; @@ -54,7 +56,8 @@ pub type PartitionData = Arc>>; pub struct MemTable { schema: SchemaRef, pub(crate) batches: Vec, - constraints: Option, + constraints: Constraints, + column_defaults: HashMap, } impl MemTable { @@ -77,15 +80,23 @@ impl MemTable { .into_iter() .map(|e| Arc::new(RwLock::new(e))) .collect::>(), - constraints: None, + constraints: Constraints::empty(), + column_defaults: HashMap::new(), }) } /// Assign constraints pub fn with_constraints(mut self, constraints: Constraints) -> Self { - if !constraints.is_empty() { - self.constraints = Some(constraints); - } + self.constraints = constraints; + self + } + + /// Assign column defaults + pub fn with_column_defaults( + mut self, + column_defaults: HashMap, + ) -> Self { + self.column_defaults = column_defaults; self } @@ -164,7 +175,7 @@ impl TableProvider for MemTable { } fn constraints(&self) -> Option<&Constraints> { - self.constraints.as_ref() + Some(&self.constraints) } fn table_type(&self) -> TableType { @@ -210,7 +221,10 @@ impl TableProvider for MemTable { ) -> Result> { // Create a physical plan from the logical plan. // Check that the schema of the plan matches the schema of this table. - if !self.schema().equivalent_names_and_types(&input.schema()) { + if !self + .schema() + .logically_equivalent_names_and_types(&input.schema()) + { return plan_err!( "Inserting query must have the same schema with the table." ); @@ -223,8 +237,13 @@ impl TableProvider for MemTable { input, sink, self.schema.clone(), + None, ))) } + + fn get_column_default(&self, column: &str) -> Option<&Expr> { + self.column_defaults.get(column) + } } /// Implements for writing to a [`MemTable`] @@ -260,9 +279,17 @@ impl MemSink { #[async_trait] impl DataSink for MemSink { + fn as_any(&self) -> &dyn Any { + self + } + + fn metrics(&self) -> Option { + None + } + async fn write_all( &self, - mut data: Vec, + mut data: SendableRecordBatchStream, _context: &Arc, ) -> Result { let num_partitions = self.batches.len(); @@ -272,14 +299,10 @@ impl DataSink for MemSink { let mut new_batches = vec![vec![]; num_partitions]; let mut i = 0; let mut row_count = 0; - let num_parts = data.len(); - // TODO parallelize outer and inner loops - for data_part in data.iter_mut().take(num_parts) { - while let Some(batch) = data_part.next().await.transpose()? { - row_count += batch.num_rows(); - new_batches[i].push(batch); - i = (i + 1) % num_partitions; - } + while let Some(batch) = data.next().await.transpose()? { + row_count += batch.num_rows(); + new_batches[i].push(batch); + i = (i + 1) % num_partitions; } // write the outputs into the batches @@ -400,7 +423,7 @@ mod tests { .scan(&session_ctx.state(), Some(&projection), &[], None) .await { - Err(DataFusionError::ArrowError(ArrowError::SchemaError(e))) => { + Err(DataFusionError::ArrowError(ArrowError::SchemaError(e), _)) => { assert_eq!( "\"project index 4 out of bounds, max field 3\"", format!("{e:?}") diff --git a/datafusion/core/src/datasource/mod.rs b/datafusion/core/src/datasource/mod.rs index 35f56536510c..2e516cc36a01 100644 --- a/datafusion/core/src/datasource/mod.rs +++ b/datafusion/core/src/datasource/mod.rs @@ -23,12 +23,14 @@ pub mod avro_to_arrow; pub mod default_table_source; pub mod empty; pub mod file_format; +pub mod function; pub mod listing; pub mod listing_table_factory; pub mod memory; pub mod physical_plan; pub mod provider; mod statistics; +pub mod stream; pub mod streaming; pub mod view; @@ -43,4 +45,46 @@ pub use self::provider::TableProvider; pub use self::view::ViewTable; pub use crate::logical_expr::TableType; pub use statistics::get_statistics_with_limit; -pub(crate) use statistics::{create_max_min_accs, get_col_stats}; + +use arrow_schema::{Schema, SortOptions}; +use datafusion_common::{plan_err, DataFusionError, Result}; +use datafusion_expr::Expr; +use datafusion_physical_expr::{expressions, LexOrdering, PhysicalSortExpr}; + +fn create_ordering( + schema: &Schema, + sort_order: &[Vec], +) -> Result> { + let mut all_sort_orders = vec![]; + + for exprs in sort_order { + // Construct PhysicalSortExpr objects from Expr objects: + let mut sort_exprs = vec![]; + for expr in exprs { + match expr { + Expr::Sort(sort) => match sort.expr.as_ref() { + Expr::Column(col) => match expressions::col(&col.name, schema) { + Ok(expr) => { + sort_exprs.push(PhysicalSortExpr { + expr, + options: SortOptions { + descending: !sort.asc, + nulls_first: sort.nulls_first, + }, + }); + } + // Cannot find expression in the projected_schema, stop iterating + // since rest of the orderings are violated + Err(_) => break, + } + expr => return plan_err!("Expected single column references in output_ordering, got {expr}"), + } + expr => return plan_err!("Expected Expr::Sort in output_ordering, but got {expr}"), + } + } + if !sort_exprs.is_empty() { + all_sort_orders.push(sort_exprs); + } + } + Ok(all_sort_orders) +} diff --git a/datafusion/core/src/datasource/physical_plan/arrow_file.rs b/datafusion/core/src/datasource/physical_plan/arrow_file.rs index a47376248ed3..ae1e879d0da1 100644 --- a/datafusion/core/src/datasource/physical_plan/arrow_file.rs +++ b/datafusion/core/src/datasource/physical_plan/arrow_file.rs @@ -16,6 +16,10 @@ // under the License. //! Execution plan for reading Arrow files + +use std::any::Any; +use std::sync::Arc; + use crate::datasource::physical_plan::{ FileMeta, FileOpenFuture, FileOpener, FileScanConfig, }; @@ -24,17 +28,14 @@ use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::physical_plan::{ DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, }; + use arrow_schema::SchemaRef; use datafusion_common::Statistics; use datafusion_execution::TaskContext; -use datafusion_physical_expr::{ - ordering_equivalence_properties_helper, LexOrdering, OrderingEquivalenceProperties, - PhysicalSortExpr, -}; +use datafusion_physical_expr::{EquivalenceProperties, LexOrdering, PhysicalSortExpr}; + use futures::StreamExt; use object_store::{GetResultPayload, ObjectStore}; -use std::any::Any; -use std::sync::Arc; /// Execution plan for scanning Arrow data source #[derive(Debug, Clone)] @@ -92,18 +93,14 @@ impl ExecutionPlan for ArrowExec { Partitioning::UnknownPartitioning(self.base_config.file_groups.len()) } - fn unbounded_output(&self, _: &[bool]) -> Result { - Ok(self.base_config().infinite_source) - } - fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { self.projected_output_ordering .first() .map(|ordering| ordering.as_slice()) } - fn ordering_equivalence_properties(&self) -> OrderingEquivalenceProperties { - ordering_equivalence_properties_helper( + fn equivalence_properties(&self) -> EquivalenceProperties { + EquivalenceProperties::new_with_orderings( self.schema(), &self.projected_output_ordering, ) @@ -143,8 +140,8 @@ impl ExecutionPlan for ArrowExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Statistics { - self.projected_statistics.clone() + fn statistics(&self) -> Result { + Ok(self.projected_statistics.clone()) } } diff --git a/datafusion/core/src/datasource/physical_plan/avro.rs b/datafusion/core/src/datasource/physical_plan/avro.rs index 93655e8665f0..e448bf39f427 100644 --- a/datafusion/core/src/datasource/physical_plan/avro.rs +++ b/datafusion/core/src/datasource/physical_plan/avro.rs @@ -16,6 +16,11 @@ // under the License. //! Execution plan for reading line-delimited Avro files + +use std::any::Any; +use std::sync::Arc; + +use super::FileScanConfig; use crate::error::Result; use crate::physical_plan::expressions::PhysicalSortExpr; use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; @@ -23,17 +28,10 @@ use crate::physical_plan::{ DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, }; -use datafusion_execution::TaskContext; use arrow::datatypes::SchemaRef; -use datafusion_physical_expr::{ - ordering_equivalence_properties_helper, LexOrdering, OrderingEquivalenceProperties, -}; - -use std::any::Any; -use std::sync::Arc; - -use super::FileScanConfig; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; /// Execution plan for scanning Avro data source #[derive(Debug, Clone)] @@ -91,18 +89,14 @@ impl ExecutionPlan for AvroExec { Partitioning::UnknownPartitioning(self.base_config.file_groups.len()) } - fn unbounded_output(&self, _: &[bool]) -> Result { - Ok(self.base_config().infinite_source) - } - fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { self.projected_output_ordering .first() .map(|ordering| ordering.as_slice()) } - fn ordering_equivalence_properties(&self) -> OrderingEquivalenceProperties { - ordering_equivalence_properties_helper( + fn equivalence_properties(&self) -> EquivalenceProperties { + EquivalenceProperties::new_with_orderings( self.schema(), &self.projected_output_ordering, ) @@ -154,8 +148,8 @@ impl ExecutionPlan for AvroExec { Ok(Box::pin(stream)) } - fn statistics(&self) -> Statistics { - self.projected_statistics.clone() + fn statistics(&self) -> Result { + Ok(self.projected_statistics.clone()) } fn metrics(&self) -> Option { @@ -272,13 +266,12 @@ mod tests { let avro_exec = AvroExec::new(FileScanConfig { object_store_url: ObjectStoreUrl::local_filesystem(), file_groups: vec![vec![meta.into()]], + statistics: Statistics::new_unknown(&file_schema), file_schema, - statistics: Statistics::default(), projection: Some(vec![0, 1, 2]), limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }); assert_eq!(avro_exec.output_partitioning().partition_count(), 1); let mut results = avro_exec @@ -344,13 +337,12 @@ mod tests { let avro_exec = AvroExec::new(FileScanConfig { object_store_url, file_groups: vec![vec![meta.into()]], + statistics: Statistics::new_unknown(&file_schema), file_schema, - statistics: Statistics::default(), projection, limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }); assert_eq!(avro_exec.output_partitioning().partition_count(), 1); @@ -408,8 +400,7 @@ mod tests { .await?; let mut partitioned_file = PartitionedFile::from(meta); - partitioned_file.partition_values = - vec![ScalarValue::Utf8(Some("2021-10-26".to_owned()))]; + partitioned_file.partition_values = vec![ScalarValue::from("2021-10-26")]; let avro_exec = AvroExec::new(FileScanConfig { // select specific columns of the files as well as the partitioning @@ -417,12 +408,11 @@ mod tests { projection: Some(vec![0, 1, file_schema.fields().len(), 2]), object_store_url, file_groups: vec![vec![partitioned_file]], + statistics: Statistics::new_unknown(&file_schema), file_schema, - statistics: Statistics::default(), limit: None, - table_partition_cols: vec![("date".to_owned(), DataType::Utf8)], + table_partition_cols: vec![Field::new("date", DataType::Utf8, false)], output_ordering: vec![], - infinite_source: false, }); assert_eq!(avro_exec.output_partitioning().partition_count(), 1); diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index dfc6acdde073..b28bc7d56688 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ -17,6 +17,12 @@ //! Execution plan for reading CSV files +use std::any::Any; +use std::io::{Read, Seek, SeekFrom}; +use std::sync::Arc; +use std::task::Poll; + +use super::{calculate_range, FileGroupPartitioner, FileScanConfig, RangeCalculation}; use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::datasource::listing::{FileRange, ListingTableUrl}; use crate::datasource::physical_plan::file_stream::{ @@ -30,25 +36,17 @@ use crate::physical_plan::{ DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, }; + use arrow::csv; use arrow::datatypes::SchemaRef; use datafusion_execution::TaskContext; -use datafusion_physical_expr::{ - ordering_equivalence_properties_helper, LexOrdering, OrderingEquivalenceProperties, -}; -use tokio::io::AsyncWriteExt; - -use super::FileScanConfig; +use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; use bytes::{Buf, Bytes}; -use futures::ready; -use futures::{StreamExt, TryStreamExt}; +use datafusion_common::config::ConfigOptions; +use futures::{ready, StreamExt, TryStreamExt}; use object_store::{GetOptions, GetResultPayload, ObjectStore}; -use std::any::Any; -use std::io::{Read, Seek, SeekFrom}; -use std::ops::Range; -use std::sync::Arc; -use std::task::Poll; +use tokio::io::AsyncWriteExt; use tokio::task::JoinSet; /// Execution plan for scanning a CSV file @@ -117,34 +115,6 @@ impl CsvExec { pub fn escape(&self) -> Option { self.escape } - - /// Redistribute files across partitions according to their size - /// See comments on `repartition_file_groups()` for more detail. - /// - /// Return `None` if can't get repartitioned(empty/compressed file). - pub fn get_repartitioned( - &self, - target_partitions: usize, - repartition_file_min_size: usize, - ) -> Option { - // Parallel execution on compressed CSV file is not supported yet. - if self.file_compression_type.is_compressed() { - return None; - } - - let repartitioned_file_groups_option = FileScanConfig::repartition_file_groups( - self.base_config.file_groups.clone(), - target_partitions, - repartition_file_min_size, - ); - - if let Some(repartitioned_file_groups) = repartitioned_file_groups_option { - let mut new_plan = self.clone(); - new_plan.base_config.file_groups = repartitioned_file_groups; - return Some(new_plan); - } - None - } } impl DisplayAs for CsvExec { @@ -175,10 +145,6 @@ impl ExecutionPlan for CsvExec { Partitioning::UnknownPartitioning(self.base_config.file_groups.len()) } - fn unbounded_output(&self, _: &[bool]) -> Result { - Ok(self.base_config().infinite_source) - } - /// See comments on `impl ExecutionPlan for ParquetExec`: output order can't be fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { self.projected_output_ordering @@ -186,8 +152,8 @@ impl ExecutionPlan for CsvExec { .map(|ordering| ordering.as_slice()) } - fn ordering_equivalence_properties(&self) -> OrderingEquivalenceProperties { - ordering_equivalence_properties_helper( + fn equivalence_properties(&self) -> EquivalenceProperties { + EquivalenceProperties::new_with_orderings( self.schema(), &self.projected_output_ordering, ) @@ -205,6 +171,35 @@ impl ExecutionPlan for CsvExec { Ok(self) } + /// Redistribute files across partitions according to their size + /// See comments on [`FileGroupPartitioner`] for more detail. + /// + /// Return `None` if can't get repartitioned(empty/compressed file). + fn repartitioned( + &self, + target_partitions: usize, + config: &ConfigOptions, + ) -> Result>> { + let repartition_file_min_size = config.optimizer.repartition_file_min_size; + // Parallel execution on compressed CSV file is not supported yet. + if self.file_compression_type.is_compressed() { + return Ok(None); + } + + let repartitioned_file_groups_option = FileGroupPartitioner::new() + .with_target_partitions(target_partitions) + .with_preserve_order_within_groups(self.output_ordering().is_some()) + .with_repartition_file_min_size(repartition_file_min_size) + .repartition_file_groups(&self.base_config.file_groups); + + if let Some(repartitioned_file_groups) = repartitioned_file_groups_option { + let mut new_plan = self.clone(); + new_plan.base_config.file_groups = repartitioned_file_groups; + return Ok(Some(Arc::new(new_plan))); + } + Ok(None) + } + fn execute( &self, partition: usize, @@ -234,8 +229,8 @@ impl ExecutionPlan for CsvExec { Ok(Box::pin(stream) as SendableRecordBatchStream) } - fn statistics(&self) -> Statistics { - self.projected_statistics.clone() + fn statistics(&self) -> Result { + Ok(self.projected_statistics.clone()) } fn metrics(&self) -> Option { @@ -289,7 +284,7 @@ impl CsvConfig { let mut builder = csv::ReaderBuilder::new(self.file_schema.clone()) .with_delimiter(self.delimiter) .with_batch_size(self.batch_size) - .has_header(self.has_header) + .with_header(self.has_header) .with_quote(self.quote); if let Some(proj) = &self.file_projection { @@ -322,47 +317,6 @@ impl CsvOpener { } } -/// Returns the offset of the first newline in the object store range [start, end), or the end offset if no newline is found. -async fn find_first_newline( - object_store: &Arc, - location: &object_store::path::Path, - start_byte: usize, - end_byte: usize, -) -> Result { - let options = GetOptions { - range: Some(Range { - start: start_byte, - end: end_byte, - }), - ..Default::default() - }; - - let r = object_store.get_opts(location, options).await?; - let mut input = r.into_stream(); - - let mut buffered = Bytes::new(); - let mut index = 0; - - loop { - if buffered.is_empty() { - match input.next().await { - Some(Ok(b)) => buffered = b, - Some(Err(e)) => return Err(e.into()), - None => return Ok(index), - }; - } - - for byte in &buffered { - if *byte == b'\n' { - return Ok(index); - } - index += 1; - } - - buffered.advance(buffered.len()); - } -} - impl FileOpener for CsvOpener { /// Open a partitioned CSV file. /// @@ -412,44 +366,20 @@ impl FileOpener for CsvOpener { ); } + let store = self.config.object_store.clone(); + Ok(Box::pin(async move { - let file_size = file_meta.object_meta.size; // Current partition contains bytes [start_byte, end_byte) (might contain incomplete lines at boundaries) - let range = match file_meta.range { - None => None, - Some(FileRange { start, end }) => { - let (start, end) = (start as usize, end as usize); - // Partition byte range is [start, end), the boundary might be in the middle of - // some line. Need to find out the exact line boundaries. - let start_delta = if start != 0 { - find_first_newline( - &config.object_store, - file_meta.location(), - start - 1, - file_size, - ) - .await? - } else { - 0 - }; - let end_delta = if end != file_size { - find_first_newline( - &config.object_store, - file_meta.location(), - end - 1, - file_size, - ) - .await? - } else { - 0 - }; - let range = start + start_delta..end + end_delta; - if range.start == range.end { - return Ok( - futures::stream::poll_fn(move |_| Poll::Ready(None)).boxed() - ); - } - Some(range) + + let calculated_range = calculate_range(&file_meta, &store).await?; + + let range = match calculated_range { + RangeCalculation::Range(None) => None, + RangeCalculation::Range(Some(range)) => Some(range), + RangeCalculation::TerminateEarly => { + return Ok( + futures::stream::poll_fn(move |_| Poll::Ready(None)).boxed() + ) } }; @@ -457,10 +387,8 @@ impl FileOpener for CsvOpener { range, ..Default::default() }; - let result = config - .object_store - .get_opts(file_meta.location(), options) - .await?; + + let result = store.get_opts(file_meta.location(), options).await?; match result.payload { GetResultPayload::File(mut file, _) => { @@ -538,7 +466,7 @@ pub async fn plan_to_csv( let mut write_headers = true; while let Some(batch) = stream.next().await.transpose()? { let mut writer = csv::WriterBuilder::new() - .has_headers(write_headers) + .with_header(write_headers) .build(buffer); writer.write(&batch)?; buffer = writer.into_inner(); @@ -871,9 +799,8 @@ mod tests { let mut config = partitioned_csv_config(file_schema, file_groups)?; // Add partition columns - config.table_partition_cols = vec![("date".to_owned(), DataType::Utf8)]; - config.file_groups[0][0].partition_values = - vec![ScalarValue::Utf8(Some("2021-10-26".to_owned()))]; + config.table_partition_cols = vec![Field::new("date", DataType::Utf8, false)]; + config.file_groups[0][0].partition_values = vec![ScalarValue::from("2021-10-26")]; // We should be able to project on the partition column // Which is supposed to be after the file fields @@ -1079,8 +1006,9 @@ mod tests { async fn write_csv_results() -> Result<()> { // create partitioned input file and context let tmp_dir = TempDir::new()?; - let ctx = - SessionContext::with_config(SessionConfig::new().with_target_partitions(8)); + let ctx = SessionContext::new_with_config( + SessionConfig::new().with_target_partitions(8), + ); let schema = populate_csv_partitions(&tmp_dir, 8, ".csv")?; diff --git a/datafusion/core/src/datasource/physical_plan/file_groups.rs b/datafusion/core/src/datasource/physical_plan/file_groups.rs new file mode 100644 index 000000000000..6456bd5c7276 --- /dev/null +++ b/datafusion/core/src/datasource/physical_plan/file_groups.rs @@ -0,0 +1,826 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Logic for managing groups of [`PartitionedFile`]s in DataFusion + +use crate::datasource::listing::{FileRange, PartitionedFile}; +use itertools::Itertools; +use std::cmp::min; +use std::collections::BinaryHeap; +use std::iter::repeat_with; + +/// Repartition input files into `target_partitions` partitions, if total file size exceed +/// `repartition_file_min_size` +/// +/// This partitions evenly by file byte range, and does not have any knowledge +/// of how data is laid out in specific files. The specific `FileOpener` are +/// responsible for the actual partitioning on specific data source type. (e.g. +/// the `CsvOpener` will read lines overlap with byte range as well as +/// handle boundaries to ensure all lines will be read exactly once) +/// +/// # Example +/// +/// For example, if there are two files `A` and `B` that we wish to read with 4 +/// partitions (with 4 threads) they will be divided as follows: +/// +/// ```text +/// ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// ┌─────────────────┐ +/// │ │ │ │ +/// │ File A │ +/// │ │ Range: 0-2MB │ │ +/// │ │ +/// │ └─────────────────┘ │ +/// ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// ┌─────────────────┐ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// │ │ ┌─────────────────┐ +/// │ │ │ │ │ │ +/// │ │ │ File A │ +/// │ │ │ │ Range 2-4MB │ │ +/// │ │ │ │ +/// │ │ │ └─────────────────┘ │ +/// │ File A (7MB) │ ────────▶ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// │ │ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// │ │ ┌─────────────────┐ +/// │ │ │ │ │ │ +/// │ │ │ File A │ +/// │ │ │ │ Range: 4-6MB │ │ +/// │ │ │ │ +/// │ │ │ └─────────────────┘ │ +/// └─────────────────┘ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// ┌─────────────────┐ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// │ File B (1MB) │ ┌─────────────────┐ +/// │ │ │ │ File A │ │ +/// └─────────────────┘ │ Range: 6-7MB │ +/// │ └─────────────────┘ │ +/// ┌─────────────────┐ +/// │ │ File B (1MB) │ │ +/// │ │ +/// │ └─────────────────┘ │ +/// ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// +/// If target_partitions = 4, +/// divides into 4 groups +/// ``` +/// +/// # Maintaining Order +/// +/// Within each group files are read sequentially. Thus, if the overall order of +/// tuples must be preserved, multiple files can not be mixed in the same group. +/// +/// In this case, the code will split the largest files evenly into any +/// available empty groups, but the overall distribution may not not be as even +/// as as even as if the order did not need to be preserved. +/// +/// ```text +/// ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// ┌─────────────────┐ +/// │ │ │ │ +/// │ File A │ +/// │ │ Range: 0-2MB │ │ +/// │ │ +/// ┌─────────────────┐ │ └─────────────────┘ │ +/// │ │ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// │ │ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// │ │ ┌─────────────────┐ +/// │ │ │ │ │ │ +/// │ │ │ File A │ +/// │ │ │ │ Range 2-4MB │ │ +/// │ File A (6MB) │ ────────▶ │ │ +/// │ (ordered) │ │ └─────────────────┘ │ +/// │ │ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// │ │ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// │ │ ┌─────────────────┐ +/// │ │ │ │ │ │ +/// │ │ │ File A │ +/// │ │ │ │ Range: 4-6MB │ │ +/// └─────────────────┘ │ │ +/// ┌─────────────────┐ │ └─────────────────┘ │ +/// │ File B (1MB) │ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// │ (ordered) │ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// └─────────────────┘ ┌─────────────────┐ +/// │ │ File B (1MB) │ │ +/// │ │ +/// │ └─────────────────┘ │ +/// ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// +/// If target_partitions = 4, +/// divides into 4 groups +/// ``` +#[derive(Debug, Clone, Copy)] +pub struct FileGroupPartitioner { + /// how many partitions should be created + target_partitions: usize, + /// the minimum size for a file to be repartitioned. + repartition_file_min_size: usize, + /// if the order when reading the files must be preserved + preserve_order_within_groups: bool, +} + +impl Default for FileGroupPartitioner { + fn default() -> Self { + Self::new() + } +} + +impl FileGroupPartitioner { + /// Creates a new [`FileGroupPartitioner`] with default values: + /// 1. `target_partitions = 1` + /// 2. `repartition_file_min_size = 10MB` + /// 3. `preserve_order_within_groups = false` + pub fn new() -> Self { + Self { + target_partitions: 1, + repartition_file_min_size: 10 * 1024 * 1024, + preserve_order_within_groups: false, + } + } + + /// Set the target partitions + pub fn with_target_partitions(mut self, target_partitions: usize) -> Self { + self.target_partitions = target_partitions; + self + } + + /// Set the minimum size at which to repartition a file + pub fn with_repartition_file_min_size( + mut self, + repartition_file_min_size: usize, + ) -> Self { + self.repartition_file_min_size = repartition_file_min_size; + self + } + + /// Set whether the order of tuples within a file must be preserved + pub fn with_preserve_order_within_groups( + mut self, + preserve_order_within_groups: bool, + ) -> Self { + self.preserve_order_within_groups = preserve_order_within_groups; + self + } + + /// Repartition input files according to the settings on this [`FileGroupPartitioner`]. + /// + /// If no repartitioning is needed or possible, return `None`. + pub fn repartition_file_groups( + &self, + file_groups: &[Vec], + ) -> Option>> { + if file_groups.is_empty() { + return None; + } + + // Perform redistribution only in case all files should be read from beginning to end + let has_ranges = file_groups.iter().flatten().any(|f| f.range.is_some()); + if has_ranges { + return None; + } + + // special case when order must be preserved + if self.preserve_order_within_groups { + self.repartition_preserving_order(file_groups) + } else { + self.repartition_evenly_by_size(file_groups) + } + } + + /// Evenly repartition files across partitions by size, ignoring any + /// existing grouping / ordering + fn repartition_evenly_by_size( + &self, + file_groups: &[Vec], + ) -> Option>> { + let target_partitions = self.target_partitions; + let repartition_file_min_size = self.repartition_file_min_size; + let flattened_files = file_groups.iter().flatten().collect::>(); + + let total_size = flattened_files + .iter() + .map(|f| f.object_meta.size as i64) + .sum::(); + if total_size < (repartition_file_min_size as i64) || total_size == 0 { + return None; + } + + let target_partition_size = + (total_size as usize + (target_partitions) - 1) / (target_partitions); + + let current_partition_index: usize = 0; + let current_partition_size: usize = 0; + + // Partition byte range evenly for all `PartitionedFile`s + let repartitioned_files = flattened_files + .into_iter() + .scan( + (current_partition_index, current_partition_size), + |state, source_file| { + let mut produced_files = vec![]; + let mut range_start = 0; + while range_start < source_file.object_meta.size { + let range_end = min( + range_start + (target_partition_size - state.1), + source_file.object_meta.size, + ); + + let mut produced_file = source_file.clone(); + produced_file.range = Some(FileRange { + start: range_start as i64, + end: range_end as i64, + }); + produced_files.push((state.0, produced_file)); + + if state.1 + (range_end - range_start) >= target_partition_size { + state.0 += 1; + state.1 = 0; + } else { + state.1 += range_end - range_start; + } + range_start = range_end; + } + Some(produced_files) + }, + ) + .flatten() + .group_by(|(partition_idx, _)| *partition_idx) + .into_iter() + .map(|(_, group)| group.map(|(_, vals)| vals).collect_vec()) + .collect_vec(); + + Some(repartitioned_files) + } + + /// Redistribute file groups across size preserving order + fn repartition_preserving_order( + &self, + file_groups: &[Vec], + ) -> Option>> { + // Can't repartition and preserve order if there are more groups + // than partitions + if file_groups.len() >= self.target_partitions { + return None; + } + let num_new_groups = self.target_partitions - file_groups.len(); + + // If there is only a single file + if file_groups.len() == 1 && file_groups[0].len() == 1 { + return self.repartition_evenly_by_size(file_groups); + } + + // Find which files could be split (single file groups) + let mut heap: BinaryHeap<_> = file_groups + .iter() + .enumerate() + .filter_map(|(group_index, group)| { + // ignore groups that do not have exactly 1 file + if group.len() == 1 { + Some(ToRepartition { + source_index: group_index, + file_size: group[0].object_meta.size, + new_groups: vec![group_index], + }) + } else { + None + } + }) + .collect(); + + // No files can be redistributed + if heap.is_empty() { + return None; + } + + // Add new empty groups to which we will redistribute ranges of existing files + let mut file_groups: Vec<_> = file_groups + .iter() + .cloned() + .chain(repeat_with(Vec::new).take(num_new_groups)) + .collect(); + + // Divide up empty groups + for (group_index, group) in file_groups.iter().enumerate() { + if !group.is_empty() { + continue; + } + // Pick the file that has the largest ranges to read so far + let mut largest_group = heap.pop().unwrap(); + largest_group.new_groups.push(group_index); + heap.push(largest_group); + } + + // Distribute files to their newly assigned groups + while let Some(to_repartition) = heap.pop() { + let range_size = to_repartition.range_size() as i64; + let ToRepartition { + source_index, + file_size, + new_groups, + } = to_repartition; + assert_eq!(file_groups[source_index].len(), 1); + let original_file = file_groups[source_index].pop().unwrap(); + + let last_group = new_groups.len() - 1; + let mut range_start: i64 = 0; + let mut range_end: i64 = range_size; + for (i, group_index) in new_groups.into_iter().enumerate() { + let target_group = &mut file_groups[group_index]; + assert!(target_group.is_empty()); + + // adjust last range to include the entire file + if i == last_group { + range_end = file_size as i64; + } + target_group + .push(original_file.clone().with_range(range_start, range_end)); + range_start = range_end; + range_end += range_size; + } + } + + Some(file_groups) + } +} + +/// Tracks how a individual file will be repartitioned +#[derive(Debug, Clone, PartialEq, Eq)] +struct ToRepartition { + /// the index from which the original file will be taken + source_index: usize, + /// the size of the original file + file_size: usize, + /// indexes of which group(s) will this be distributed to (including `source_index`) + new_groups: Vec, +} + +impl ToRepartition { + // how big will each file range be when this file is read in its new groups? + fn range_size(&self) -> usize { + self.file_size / self.new_groups.len() + } +} + +impl PartialOrd for ToRepartition { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +/// Order based on individual range +impl Ord for ToRepartition { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.range_size().cmp(&other.range_size()) + } +} + +#[cfg(test)] +mod test { + use super::*; + + /// Empty file won't get partitioned + #[test] + fn repartition_empty_file_only() { + let partitioned_file_empty = pfile("empty", 0); + let file_group = vec![vec![partitioned_file_empty.clone()]]; + + let partitioned_files = FileGroupPartitioner::new() + .with_target_partitions(4) + .with_repartition_file_min_size(0) + .repartition_file_groups(&file_group); + + assert_partitioned_files(None, partitioned_files); + } + + /// Repartition when there is a empty file in file groups + #[test] + fn repartition_empty_files() { + let pfile_a = pfile("a", 10); + let pfile_b = pfile("b", 10); + let pfile_empty = pfile("empty", 0); + + let empty_first = vec![ + vec![pfile_empty.clone()], + vec![pfile_a.clone()], + vec![pfile_b.clone()], + ]; + let empty_middle = vec![ + vec![pfile_a.clone()], + vec![pfile_empty.clone()], + vec![pfile_b.clone()], + ]; + let empty_last = vec![vec![pfile_a], vec![pfile_b], vec![pfile_empty]]; + + // Repartition file groups into x partitions + let expected_2 = vec![ + vec![pfile("a", 10).with_range(0, 10)], + vec![pfile("b", 10).with_range(0, 10)], + ]; + let expected_3 = vec![ + vec![pfile("a", 10).with_range(0, 7)], + vec![ + pfile("a", 10).with_range(7, 10), + pfile("b", 10).with_range(0, 4), + ], + vec![pfile("b", 10).with_range(4, 10)], + ]; + + let file_groups_tests = [empty_first, empty_middle, empty_last]; + + for fg in file_groups_tests { + let all_expected = [(2, expected_2.clone()), (3, expected_3.clone())]; + for (n_partition, expected) in all_expected { + let actual = FileGroupPartitioner::new() + .with_target_partitions(n_partition) + .with_repartition_file_min_size(10) + .repartition_file_groups(&fg); + + assert_partitioned_files(Some(expected), actual); + } + } + } + + #[test] + fn repartition_single_file() { + // Single file, single partition into multiple partitions + let single_partition = vec![vec![pfile("a", 123)]]; + + let actual = FileGroupPartitioner::new() + .with_target_partitions(4) + .with_repartition_file_min_size(10) + .repartition_file_groups(&single_partition); + + let expected = Some(vec![ + vec![pfile("a", 123).with_range(0, 31)], + vec![pfile("a", 123).with_range(31, 62)], + vec![pfile("a", 123).with_range(62, 93)], + vec![pfile("a", 123).with_range(93, 123)], + ]); + assert_partitioned_files(expected, actual); + } + + #[test] + fn repartition_too_much_partitions() { + // Single file, single partition into 96 partitions + let partitioned_file = pfile("a", 8); + let single_partition = vec![vec![partitioned_file]]; + + let actual = FileGroupPartitioner::new() + .with_target_partitions(96) + .with_repartition_file_min_size(5) + .repartition_file_groups(&single_partition); + + let expected = Some(vec![ + vec![pfile("a", 8).with_range(0, 1)], + vec![pfile("a", 8).with_range(1, 2)], + vec![pfile("a", 8).with_range(2, 3)], + vec![pfile("a", 8).with_range(3, 4)], + vec![pfile("a", 8).with_range(4, 5)], + vec![pfile("a", 8).with_range(5, 6)], + vec![pfile("a", 8).with_range(6, 7)], + vec![pfile("a", 8).with_range(7, 8)], + ]); + + assert_partitioned_files(expected, actual); + } + + #[test] + fn repartition_multiple_partitions() { + // Multiple files in single partition after redistribution + let source_partitions = vec![vec![pfile("a", 40)], vec![pfile("b", 60)]]; + + let actual = FileGroupPartitioner::new() + .with_target_partitions(3) + .with_repartition_file_min_size(10) + .repartition_file_groups(&source_partitions); + + let expected = Some(vec![ + vec![pfile("a", 40).with_range(0, 34)], + vec![ + pfile("a", 40).with_range(34, 40), + pfile("b", 60).with_range(0, 28), + ], + vec![pfile("b", 60).with_range(28, 60)], + ]); + assert_partitioned_files(expected, actual); + } + + #[test] + fn repartition_same_num_partitions() { + // "Rebalance" files across partitions + let source_partitions = vec![vec![pfile("a", 40)], vec![pfile("b", 60)]]; + + let actual = FileGroupPartitioner::new() + .with_target_partitions(2) + .with_repartition_file_min_size(10) + .repartition_file_groups(&source_partitions); + + let expected = Some(vec![ + vec![ + pfile("a", 40).with_range(0, 40), + pfile("b", 60).with_range(0, 10), + ], + vec![pfile("b", 60).with_range(10, 60)], + ]); + assert_partitioned_files(expected, actual); + } + + #[test] + fn repartition_no_action_ranges() { + // No action due to Some(range) in second file + let source_partitions = vec![ + vec![pfile("a", 123)], + vec![pfile("b", 144).with_range(1, 50)], + ]; + + let actual = FileGroupPartitioner::new() + .with_target_partitions(65) + .with_repartition_file_min_size(10) + .repartition_file_groups(&source_partitions); + + assert_partitioned_files(None, actual) + } + + #[test] + fn repartition_no_action_min_size() { + // No action due to target_partition_size + let single_partition = vec![vec![pfile("a", 123)]]; + + let actual = FileGroupPartitioner::new() + .with_target_partitions(65) + .with_repartition_file_min_size(500) + .repartition_file_groups(&single_partition); + + assert_partitioned_files(None, actual) + } + + #[test] + fn repartition_no_action_zero_files() { + // No action due to no files + let empty_partition = vec![]; + + let partitioner = FileGroupPartitioner::new() + .with_target_partitions(65) + .with_repartition_file_min_size(500); + + assert_partitioned_files(None, repartition_test(partitioner, empty_partition)) + } + + #[test] + fn repartition_ordered_no_action_too_few_partitions() { + // No action as there are no new groups to redistribute to + let input_partitions = vec![vec![pfile("a", 100)], vec![pfile("b", 200)]]; + + let actual = FileGroupPartitioner::new() + .with_preserve_order_within_groups(true) + .with_target_partitions(2) + .with_repartition_file_min_size(10) + .repartition_file_groups(&input_partitions); + + assert_partitioned_files(None, actual) + } + + #[test] + fn repartition_ordered_no_action_file_too_small() { + // No action as there are no new groups to redistribute to + let single_partition = vec![vec![pfile("a", 100)]]; + + let actual = FileGroupPartitioner::new() + .with_preserve_order_within_groups(true) + .with_target_partitions(2) + // file is too small to repartition + .with_repartition_file_min_size(1000) + .repartition_file_groups(&single_partition); + + assert_partitioned_files(None, actual) + } + + #[test] + fn repartition_ordered_one_large_file() { + // "Rebalance" the single large file across partitions + let source_partitions = vec![vec![pfile("a", 100)]]; + + let actual = FileGroupPartitioner::new() + .with_preserve_order_within_groups(true) + .with_target_partitions(3) + .with_repartition_file_min_size(10) + .repartition_file_groups(&source_partitions); + + let expected = Some(vec![ + vec![pfile("a", 100).with_range(0, 34)], + vec![pfile("a", 100).with_range(34, 68)], + vec![pfile("a", 100).with_range(68, 100)], + ]); + assert_partitioned_files(expected, actual); + } + + #[test] + fn repartition_ordered_one_large_one_small_file() { + // "Rebalance" the single large file across empty partitions, but can't split + // small file + let source_partitions = vec![vec![pfile("a", 100)], vec![pfile("b", 30)]]; + + let actual = FileGroupPartitioner::new() + .with_preserve_order_within_groups(true) + .with_target_partitions(4) + .with_repartition_file_min_size(10) + .repartition_file_groups(&source_partitions); + + let expected = Some(vec![ + // scan first third of "a" + vec![pfile("a", 100).with_range(0, 33)], + // only b in this group (can't do this) + vec![pfile("b", 30).with_range(0, 30)], + // second third of "a" + vec![pfile("a", 100).with_range(33, 66)], + // final third of "a" + vec![pfile("a", 100).with_range(66, 100)], + ]); + assert_partitioned_files(expected, actual); + } + + #[test] + fn repartition_ordered_two_large_files() { + // "Rebalance" two large files across empty partitions, but can't mix them + let source_partitions = vec![vec![pfile("a", 100)], vec![pfile("b", 100)]]; + + let actual = FileGroupPartitioner::new() + .with_preserve_order_within_groups(true) + .with_target_partitions(4) + .with_repartition_file_min_size(10) + .repartition_file_groups(&source_partitions); + + let expected = Some(vec![ + // scan first half of "a" + vec![pfile("a", 100).with_range(0, 50)], + // scan first half of "b" + vec![pfile("b", 100).with_range(0, 50)], + // second half of "a" + vec![pfile("a", 100).with_range(50, 100)], + // second half of "b" + vec![pfile("b", 100).with_range(50, 100)], + ]); + assert_partitioned_files(expected, actual); + } + + #[test] + fn repartition_ordered_two_large_one_small_files() { + // "Rebalance" two large files and one small file across empty partitions + let source_partitions = vec![ + vec![pfile("a", 100)], + vec![pfile("b", 100)], + vec![pfile("c", 30)], + ]; + + let partitioner = FileGroupPartitioner::new() + .with_preserve_order_within_groups(true) + .with_repartition_file_min_size(10); + + // with 4 partitions, can only split the first large file "a" + let actual = partitioner + .with_target_partitions(4) + .repartition_file_groups(&source_partitions); + + let expected = Some(vec![ + // scan first half of "a" + vec![pfile("a", 100).with_range(0, 50)], + // All of "b" + vec![pfile("b", 100).with_range(0, 100)], + // All of "c" + vec![pfile("c", 30).with_range(0, 30)], + // second half of "a" + vec![pfile("a", 100).with_range(50, 100)], + ]); + assert_partitioned_files(expected, actual); + + // With 5 partitions, we can split both "a" and "b", but they can't be intermixed + let actual = partitioner + .with_target_partitions(5) + .repartition_file_groups(&source_partitions); + + let expected = Some(vec![ + // scan first half of "a" + vec![pfile("a", 100).with_range(0, 50)], + // scan first half of "b" + vec![pfile("b", 100).with_range(0, 50)], + // All of "c" + vec![pfile("c", 30).with_range(0, 30)], + // second half of "a" + vec![pfile("a", 100).with_range(50, 100)], + // second half of "b" + vec![pfile("b", 100).with_range(50, 100)], + ]); + assert_partitioned_files(expected, actual); + } + + #[test] + fn repartition_ordered_one_large_one_small_existing_empty() { + // "Rebalance" files using existing empty partition + let source_partitions = + vec![vec![pfile("a", 100)], vec![], vec![pfile("b", 40)], vec![]]; + + let actual = FileGroupPartitioner::new() + .with_preserve_order_within_groups(true) + .with_target_partitions(5) + .with_repartition_file_min_size(10) + .repartition_file_groups(&source_partitions); + + // Of the three available groups (2 original empty and 1 new from the + // target partitions), assign two to "a" and one to "b" + let expected = Some(vec![ + // Scan of "a" across three groups + vec![pfile("a", 100).with_range(0, 33)], + vec![pfile("a", 100).with_range(33, 66)], + // scan first half of "b" + vec![pfile("b", 40).with_range(0, 20)], + // final third of "a" + vec![pfile("a", 100).with_range(66, 100)], + // second half of "b" + vec![pfile("b", 40).with_range(20, 40)], + ]); + assert_partitioned_files(expected, actual); + } + #[test] + fn repartition_ordered_existing_group_multiple_files() { + // groups with multiple files in a group can not be changed, but can divide others + let source_partitions = vec![ + // two files in an existing partition + vec![pfile("a", 100), pfile("b", 100)], + vec![pfile("c", 40)], + ]; + + let actual = FileGroupPartitioner::new() + .with_preserve_order_within_groups(true) + .with_target_partitions(3) + .with_repartition_file_min_size(10) + .repartition_file_groups(&source_partitions); + + // Of the three available groups (2 original empty and 1 new from the + // target partitions), assign two to "a" and one to "b" + let expected = Some(vec![ + // don't try and rearrange files in the existing partition + // assuming that the caller had a good reason to put them that way. + // (it is technically possible to split off ranges from the files if desired) + vec![pfile("a", 100), pfile("b", 100)], + // first half of "c" + vec![pfile("c", 40).with_range(0, 20)], + // second half of "c" + vec![pfile("c", 40).with_range(20, 40)], + ]); + assert_partitioned_files(expected, actual); + } + + /// Asserts that the two groups of `ParititonedFile` are the same + /// (PartitionedFile doesn't implement PartialEq) + fn assert_partitioned_files( + expected: Option>>, + actual: Option>>, + ) { + match (expected, actual) { + (None, None) => {} + (Some(_), None) => panic!("Expected Some, got None"), + (None, Some(_)) => panic!("Expected None, got Some"), + (Some(expected), Some(actual)) => { + let expected_string = format!("{:#?}", expected); + let actual_string = format!("{:#?}", actual); + assert_eq!(expected_string, actual_string); + } + } + } + + /// returns a partitioned file with the specified path and size + fn pfile(path: impl Into, file_size: u64) -> PartitionedFile { + PartitionedFile::new(path, file_size) + } + + /// repartition the file groups both with and without preserving order + /// asserting they return the same value and returns that value + fn repartition_test( + partitioner: FileGroupPartitioner, + file_groups: Vec>, + ) -> Option>> { + let repartitioned = partitioner.repartition_file_groups(&file_groups); + + let repartitioned_preserving_sort = partitioner + .with_preserve_order_within_groups(true) + .repartition_file_groups(&file_groups); + + assert_partitioned_files( + repartitioned.clone(), + repartitioned_preserving_sort.clone(), + ); + repartitioned + } +} diff --git a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs index 819bfabae290..516755e4d293 100644 --- a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs +++ b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs @@ -18,10 +18,12 @@ //! [`FileScanConfig`] to configure scanning of possibly partitioned //! file sources. -use crate::datasource::{ - listing::{FileRange, PartitionedFile}, - object_store::ObjectStoreUrl, +use std::{ + borrow::Cow, collections::HashMap, fmt::Debug, marker::PhantomData, sync::Arc, vec, }; + +use super::{get_projected_output_ordering, FileGroupPartitioner}; +use crate::datasource::{listing::PartitionedFile, object_store::ObjectStoreUrl}; use crate::{ error::{DataFusionError, Result}, scalar::ScalarValue, @@ -30,20 +32,13 @@ use crate::{ use arrow::array::{ArrayData, BufferBuilder}; use arrow::buffer::Buffer; use arrow::datatypes::{ArrowNativeType, UInt16Type}; -use arrow_array::{ArrayRef, DictionaryArray, RecordBatch}; +use arrow_array::{ArrayRef, DictionaryArray, RecordBatch, RecordBatchOptions}; use arrow_schema::{DataType, Field, Schema, SchemaRef}; -use datafusion_common::exec_err; -use datafusion_common::{ColumnStatistics, Statistics}; +use datafusion_common::stats::Precision; +use datafusion_common::{exec_err, ColumnStatistics, Statistics}; use datafusion_physical_expr::LexOrdering; -use itertools::Itertools; use log::warn; -use std::{ - borrow::Cow, cmp::min, collections::HashMap, fmt::Debug, marker::PhantomData, - sync::Arc, vec, -}; - -use super::get_projected_output_ordering; /// Convert type to a type suitable for use as a [`ListingTable`] /// partition column. Returns `Dictionary(UInt16, val_type)`, which is @@ -101,11 +96,9 @@ pub struct FileScanConfig { /// all records after filtering are returned. pub limit: Option, /// The partitioning columns - pub table_partition_cols: Vec<(String, DataType)>, + pub table_partition_cols: Vec, /// All equivalent lexicographical orderings that describe the schema. pub output_ordering: Vec, - /// Indicates whether this plan may produce an infinite stream of records. - pub infinite_source: bool, } impl FileScanConfig { @@ -130,30 +123,22 @@ impl FileScanConfig { let mut table_cols_stats = vec![]; for idx in proj_iter { if idx < self.file_schema.fields().len() { - table_fields.push(self.file_schema.field(idx).clone()); - if let Some(file_cols_stats) = &self.statistics.column_statistics { - table_cols_stats.push(file_cols_stats[idx].clone()) - } else { - table_cols_stats.push(ColumnStatistics::default()) - } + let field = self.file_schema.field(idx); + table_fields.push(field.clone()); + table_cols_stats.push(self.statistics.column_statistics[idx].clone()) } else { let partition_idx = idx - self.file_schema.fields().len(); - table_fields.push(Field::new( - &self.table_partition_cols[partition_idx].0, - self.table_partition_cols[partition_idx].1.to_owned(), - false, - )); + table_fields.push(self.table_partition_cols[partition_idx].to_owned()); // TODO provide accurate stat for partition column (#1186) - table_cols_stats.push(ColumnStatistics::default()) + table_cols_stats.push(ColumnStatistics::new_unknown()) } } let table_stats = Statistics { - num_rows: self.statistics.num_rows, - is_exact: self.statistics.is_exact, + num_rows: self.statistics.num_rows.clone(), // TODO correct byte size? - total_byte_size: None, - column_statistics: Some(table_cols_stats), + total_byte_size: Precision::Absent, + column_statistics: table_cols_stats, }; let table_schema = Arc::new( @@ -184,79 +169,17 @@ impl FileScanConfig { }) } - /// Repartition all input files into `target_partitions` partitions, if total file size exceed - /// `repartition_file_min_size` - /// `target_partitions` and `repartition_file_min_size` directly come from configuration. - /// - /// This function only try to partition file byte range evenly, and let specific `FileOpener` to - /// do actual partition on specific data source type. (e.g. `CsvOpener` will only read lines - /// overlap with byte range but also handle boundaries to ensure all lines will be read exactly once) + #[allow(missing_docs)] + #[deprecated(since = "33.0.0", note = "Use SessionContext::new_with_config")] pub fn repartition_file_groups( file_groups: Vec>, target_partitions: usize, repartition_file_min_size: usize, ) -> Option>> { - let flattened_files = file_groups.iter().flatten().collect::>(); - - // Perform redistribution only in case all files should be read from beginning to end - let has_ranges = flattened_files.iter().any(|f| f.range.is_some()); - if has_ranges { - return None; - } - - let total_size = flattened_files - .iter() - .map(|f| f.object_meta.size as i64) - .sum::(); - if total_size < (repartition_file_min_size as i64) || total_size == 0 { - return None; - } - - let target_partition_size = - (total_size as usize + (target_partitions) - 1) / (target_partitions); - - let current_partition_index: usize = 0; - let current_partition_size: usize = 0; - - // Partition byte range evenly for all `PartitionedFile`s - let repartitioned_files = flattened_files - .into_iter() - .scan( - (current_partition_index, current_partition_size), - |state, source_file| { - let mut produced_files = vec![]; - let mut range_start = 0; - while range_start < source_file.object_meta.size { - let range_end = min( - range_start + (target_partition_size - state.1), - source_file.object_meta.size, - ); - - let mut produced_file = source_file.clone(); - produced_file.range = Some(FileRange { - start: range_start as i64, - end: range_end as i64, - }); - produced_files.push((state.0, produced_file)); - - if state.1 + (range_end - range_start) >= target_partition_size { - state.0 += 1; - state.1 = 0; - } else { - state.1 += range_end - range_start; - } - range_start = range_end; - } - Some(produced_files) - }, - ) - .flatten() - .group_by(|(partition_idx, _)| *partition_idx) - .into_iter() - .map(|(_, group)| group.map(|(_, vals)| vals).collect_vec()) - .collect_vec(); - - Some(repartitioned_files) + FileGroupPartitioner::new() + .with_target_partitions(target_partitions) + .with_repartition_file_min_size(repartition_file_min_size) + .repartition_file_groups(&file_groups) } } @@ -344,10 +267,16 @@ impl PartitionColumnProjector { &mut self.key_buffer_cache, partition_value.as_ref(), file_batch.num_rows(), - ), + )?, ) } - RecordBatch::try_new(Arc::clone(&self.projected_schema), cols).map_err(Into::into) + + RecordBatch::try_new_with_options( + Arc::clone(&self.projected_schema), + cols, + &RecordBatchOptions::new().with_row_count(Some(file_batch.num_rows())), + ) + .map_err(Into::into) } } @@ -398,11 +327,11 @@ fn create_dict_array( dict_val: &ScalarValue, len: usize, data_type: DataType, -) -> ArrayRef +) -> Result where T: ArrowNativeType, { - let dict_vals = dict_val.to_array(); + let dict_vals = dict_val.to_array()?; let sliced_key_buffer = buffer_gen.get_buffer(len); @@ -411,16 +340,16 @@ where .len(len) .add_buffer(sliced_key_buffer); builder = builder.add_child_data(dict_vals.to_data()); - Arc::new(DictionaryArray::::from( + Ok(Arc::new(DictionaryArray::::from( builder.build().unwrap(), - )) + ))) } fn create_output_array( key_buffer_cache: &mut ZeroBufferGenerators, val: &ScalarValue, len: usize, -) -> ArrayRef { +) -> Result { if let ScalarValue::Dictionary(key_type, dict_val) = &val { match key_type.as_ref() { DataType::Int8 => { @@ -507,11 +436,11 @@ mod tests { let conf = config_for_projection( Arc::clone(&file_schema), None, - Statistics::default(), - vec![( + Statistics::new_unknown(&file_schema), + to_partition_cols(vec![( "date".to_owned(), wrap_partition_type_in_dict(DataType::Utf8), - )], + )]), ); let (proj_schema, proj_statistics, _) = conf.project(); @@ -522,10 +451,7 @@ mod tests { "partition columns are the last columns" ); assert_eq!( - proj_statistics - .column_statistics - .expect("projection creates column statistics") - .len(), + proj_statistics.column_statistics.len(), file_schema.fields().len() + 1 ); // TODO implement tests for partition column statistics once implemented @@ -537,6 +463,35 @@ mod tests { assert_eq!(col_indices, None); } + #[test] + fn physical_plan_config_no_projection_tab_cols_as_field() { + let file_schema = aggr_test_schema(); + + // make a table_partition_col as a field + let table_partition_col = + Field::new("date", wrap_partition_type_in_dict(DataType::Utf8), true) + .with_metadata(HashMap::from_iter(vec![( + "key_whatever".to_owned(), + "value_whatever".to_owned(), + )])); + + let conf = config_for_projection( + Arc::clone(&file_schema), + None, + Statistics::new_unknown(&file_schema), + vec![table_partition_col.clone()], + ); + + // verify the proj_schema inlcudes the last column and exactly the same the field it is defined + let (proj_schema, _proj_statistics, _) = conf.project(); + assert_eq!(proj_schema.fields().len(), file_schema.fields().len() + 1); + assert_eq!( + *proj_schema.field(file_schema.fields().len()), + table_partition_col, + "partition columns are the last columns and ust have all values defined in created field" + ); + } + #[test] fn physical_plan_config_with_projection() { let file_schema = aggr_test_schema(); @@ -544,23 +499,21 @@ mod tests { Arc::clone(&file_schema), Some(vec![file_schema.fields().len(), 0]), Statistics { - num_rows: Some(10), + num_rows: Precision::Inexact(10), // assign the column index to distinct_count to help assert // the source statistic after the projection - column_statistics: Some( - (0..file_schema.fields().len()) - .map(|i| ColumnStatistics { - distinct_count: Some(i), - ..Default::default() - }) - .collect(), - ), - ..Default::default() + column_statistics: (0..file_schema.fields().len()) + .map(|i| ColumnStatistics { + distinct_count: Precision::Inexact(i), + ..Default::default() + }) + .collect(), + total_byte_size: Precision::Absent, }, - vec![( + to_partition_cols(vec![( "date".to_owned(), wrap_partition_type_in_dict(DataType::Utf8), - )], + )]), ); let (proj_schema, proj_statistics, _) = conf.project(); @@ -568,13 +521,11 @@ mod tests { columns(&proj_schema), vec!["date".to_owned(), "c1".to_owned()] ); - let proj_stat_cols = proj_statistics - .column_statistics - .expect("projection creates column statistics"); + let proj_stat_cols = proj_statistics.column_statistics; assert_eq!(proj_stat_cols.len(), 2); // TODO implement tests for proj_stat_cols[0] once partition column // statistics are implemented - assert_eq!(proj_stat_cols[1].distinct_count, Some(0)); + assert_eq!(proj_stat_cols[1].distinct_count, Precision::Inexact(0)); let col_names = conf.projected_file_column_names(); assert_eq!(col_names, Some(vec!["c1".to_owned()])); @@ -615,8 +566,8 @@ mod tests { file_batch.schema().fields().len(), file_batch.schema().fields().len() + 2, ]), - Statistics::default(), - partition_cols.clone(), + Statistics::new_unknown(&file_batch.schema()), + to_partition_cols(partition_cols.clone()), ); let (proj_schema, ..) = conf.project(); // created a projector for that projected schema @@ -634,15 +585,9 @@ mod tests { // file_batch is ok here because we kept all the file cols in the projection file_batch, &[ - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "2021".to_owned(), - ))), - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "10".to_owned(), - ))), - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "26".to_owned(), - ))), + wrap_partition_value_in_dict(ScalarValue::from("2021")), + wrap_partition_value_in_dict(ScalarValue::from("10")), + wrap_partition_value_in_dict(ScalarValue::from("26")), ], ) .expect("Projection of partition columns into record batch failed"); @@ -668,15 +613,9 @@ mod tests { // file_batch is ok here because we kept all the file cols in the projection file_batch, &[ - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "2021".to_owned(), - ))), - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "10".to_owned(), - ))), - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "27".to_owned(), - ))), + wrap_partition_value_in_dict(ScalarValue::from("2021")), + wrap_partition_value_in_dict(ScalarValue::from("10")), + wrap_partition_value_in_dict(ScalarValue::from("27")), ], ) .expect("Projection of partition columns into record batch failed"); @@ -704,15 +643,9 @@ mod tests { // file_batch is ok here because we kept all the file cols in the projection file_batch, &[ - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "2021".to_owned(), - ))), - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "10".to_owned(), - ))), - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "28".to_owned(), - ))), + wrap_partition_value_in_dict(ScalarValue::from("2021")), + wrap_partition_value_in_dict(ScalarValue::from("10")), + wrap_partition_value_in_dict(ScalarValue::from("28")), ], ) .expect("Projection of partition columns into record batch failed"); @@ -738,9 +671,9 @@ mod tests { // file_batch is ok here because we kept all the file cols in the projection file_batch, &[ - ScalarValue::Utf8(Some("2021".to_owned())), - ScalarValue::Utf8(Some("10".to_owned())), - ScalarValue::Utf8(Some("26".to_owned())), + ScalarValue::from("2021"), + ScalarValue::from("10"), + ScalarValue::from("26"), ], ) .expect("Projection of partition columns into record batch failed"); @@ -761,7 +694,7 @@ mod tests { file_schema: SchemaRef, projection: Option>, statistics: Statistics, - table_partition_cols: Vec<(String, DataType)>, + table_partition_cols: Vec, ) -> FileScanConfig { FileScanConfig { file_schema, @@ -772,10 +705,17 @@ mod tests { statistics, table_partition_cols, output_ordering: vec![], - infinite_source: false, } } + /// Convert partition columns from Vec to Vec + fn to_partition_cols(table_partition_cols: Vec<(String, DataType)>) -> Vec { + table_partition_cols + .iter() + .map(|(name, dtype)| Field::new(name, dtype.clone(), false)) + .collect::>() + } + /// returns record batch with 3 columns of i32 in memory pub fn build_table_i32( a: (&str, &Vec), diff --git a/datafusion/core/src/datasource/physical_plan/file_stream.rs b/datafusion/core/src/datasource/physical_plan/file_stream.rs index af304f40dd86..bb4c8313642c 100644 --- a/datafusion/core/src/datasource/physical_plan/file_stream.rs +++ b/datafusion/core/src/datasource/physical_plan/file_stream.rs @@ -112,7 +112,7 @@ enum FileStreamState { /// The idle state, no file is currently being read Idle, /// Currently performing asynchronous IO to obtain a stream of RecordBatch - /// for a given parquet file + /// for a given file Open { /// A [`FileOpenFuture`] returned by [`FileOpener::open`] future: FileOpenFuture, @@ -259,7 +259,7 @@ impl FileStream { &config .table_partition_cols .iter() - .map(|x| x.0.clone()) + .map(|x| x.name().clone()) .collect::>(), ); @@ -518,9 +518,8 @@ impl RecordBatchStream for FileStream { #[cfg(test)] mod tests { - use arrow_schema::Schema; - use datafusion_common::internal_err; - use datafusion_common::DataFusionError; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; use super::*; use crate::datasource::file_format::write::BatchSerializer; @@ -533,8 +532,8 @@ mod tests { test::{make_partition, object_store::register_test_store}, }; - use std::sync::atomic::{AtomicUsize, Ordering}; - use std::sync::Arc; + use arrow_schema::Schema; + use datafusion_common::{internal_err, DataFusionError, Statistics}; use async_trait::async_trait; use bytes::Bytes; @@ -659,14 +658,13 @@ mod tests { let config = FileScanConfig { object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), + statistics: Statistics::new_unknown(&file_schema), file_schema, file_groups: vec![file_group], - statistics: Default::default(), projection: None, limit: self.limit, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }; let metrics_set = ExecutionPlanMetricsSet::new(); let file_stream = FileStream::new(&config, 0, self.opener, &metrics_set) @@ -993,7 +991,7 @@ mod tests { #[async_trait] impl BatchSerializer for TestSerializer { - async fn serialize(&mut self, _batch: RecordBatch) -> Result { + async fn serialize(&self, _batch: RecordBatch, _initial: bool) -> Result { Ok(self.bytes.clone()) } } diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs index 537855704a76..529632dab85a 100644 --- a/datafusion/core/src/datasource/physical_plan/json.rs +++ b/datafusion/core/src/datasource/physical_plan/json.rs @@ -16,6 +16,13 @@ // under the License. //! Execution plan for reading line-delimited JSON files + +use std::any::Any; +use std::io::{BufReader, Read, Seek, SeekFrom}; +use std::sync::Arc; +use std::task::Poll; + +use super::{calculate_range, FileGroupPartitioner, FileScanConfig, RangeCalculation}; use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::datasource::listing::ListingTableUrl; use crate::datasource::physical_plan::file_stream::{ @@ -29,27 +36,19 @@ use crate::physical_plan::{ DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, }; -use datafusion_execution::TaskContext; use arrow::json::ReaderBuilder; use arrow::{datatypes::SchemaRef, json}; -use datafusion_physical_expr::{ - ordering_equivalence_properties_helper, LexOrdering, OrderingEquivalenceProperties, -}; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; use bytes::{Buf, Bytes}; -use futures::{ready, stream, StreamExt, TryStreamExt}; -use object_store; +use futures::{ready, StreamExt, TryStreamExt}; +use object_store::{self, GetOptions}; use object_store::{GetResultPayload, ObjectStore}; -use std::any::Any; -use std::io::BufReader; -use std::sync::Arc; -use std::task::Poll; use tokio::io::AsyncWriteExt; use tokio::task::JoinSet; -use super::FileScanConfig; - /// Execution plan for scanning NdJson data source #[derive(Debug, Clone)] pub struct NdJsonExec { @@ -111,18 +110,14 @@ impl ExecutionPlan for NdJsonExec { Partitioning::UnknownPartitioning(self.base_config.file_groups.len()) } - fn unbounded_output(&self, _: &[bool]) -> Result { - Ok(self.base_config.infinite_source) - } - fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { self.projected_output_ordering .first() .map(|ordering| ordering.as_slice()) } - fn ordering_equivalence_properties(&self) -> OrderingEquivalenceProperties { - ordering_equivalence_properties_helper( + fn equivalence_properties(&self) -> EquivalenceProperties { + EquivalenceProperties::new_with_orderings( self.schema(), &self.projected_output_ordering, ) @@ -139,6 +134,30 @@ impl ExecutionPlan for NdJsonExec { Ok(self) } + fn repartitioned( + &self, + target_partitions: usize, + config: &datafusion_common::config::ConfigOptions, + ) -> Result>> { + let repartition_file_min_size = config.optimizer.repartition_file_min_size; + let preserve_order_within_groups = self.output_ordering().is_some(); + let file_groups = &self.base_config.file_groups; + + let repartitioned_file_groups_option = FileGroupPartitioner::new() + .with_target_partitions(target_partitions) + .with_preserve_order_within_groups(preserve_order_within_groups) + .with_repartition_file_min_size(repartition_file_min_size) + .repartition_file_groups(file_groups); + + if let Some(repartitioned_file_groups) = repartitioned_file_groups_option { + let mut new_plan = self.clone(); + new_plan.base_config.file_groups = repartitioned_file_groups; + return Ok(Some(Arc::new(new_plan))); + } + + Ok(None) + } + fn execute( &self, partition: usize, @@ -163,8 +182,8 @@ impl ExecutionPlan for NdJsonExec { Ok(Box::pin(stream) as SendableRecordBatchStream) } - fn statistics(&self) -> Statistics { - self.projected_statistics.clone() + fn statistics(&self) -> Result { + Ok(self.projected_statistics.clone()) } fn metrics(&self) -> Option { @@ -198,54 +217,89 @@ impl JsonOpener { } impl FileOpener for JsonOpener { + /// Open a partitioned NDJSON file. + /// + /// If `file_meta.range` is `None`, the entire file is opened. + /// Else `file_meta.range` is `Some(FileRange{start, end})`, which corresponds to the byte range [start, end) within the file. + /// + /// Note: `start` or `end` might be in the middle of some lines. In such cases, the following rules + /// are applied to determine which lines to read: + /// 1. The first line of the partition is the line in which the index of the first character >= `start`. + /// 2. The last line of the partition is the line in which the byte at position `end - 1` resides. + /// + /// See [`CsvOpener`](super::CsvOpener) for an example. fn open(&self, file_meta: FileMeta) -> Result { let store = self.object_store.clone(); let schema = self.projected_schema.clone(); let batch_size = self.batch_size; - let file_compression_type = self.file_compression_type.to_owned(); + Ok(Box::pin(async move { - let r = store.get(file_meta.location()).await?; - match r.payload { - GetResultPayload::File(file, _) => { - let bytes = file_compression_type.convert_read(file)?; + let calculated_range = calculate_range(&file_meta, &store).await?; + + let range = match calculated_range { + RangeCalculation::Range(None) => None, + RangeCalculation::Range(Some(range)) => Some(range), + RangeCalculation::TerminateEarly => { + return Ok( + futures::stream::poll_fn(move |_| Poll::Ready(None)).boxed() + ) + } + }; + + let options = GetOptions { + range, + ..Default::default() + }; + + let result = store.get_opts(file_meta.location(), options).await?; + + match result.payload { + GetResultPayload::File(mut file, _) => { + let bytes = match file_meta.range { + None => file_compression_type.convert_read(file)?, + Some(_) => { + file.seek(SeekFrom::Start(result.range.start as _))?; + let limit = result.range.end - result.range.start; + file_compression_type.convert_read(file.take(limit as u64))? + } + }; + let reader = ReaderBuilder::new(schema) .with_batch_size(batch_size) .build(BufReader::new(bytes))?; + Ok(futures::stream::iter(reader).boxed()) } GetResultPayload::Stream(s) => { + let s = s.map_err(DataFusionError::from); + let mut decoder = ReaderBuilder::new(schema) .with_batch_size(batch_size) .build_decoder()?; - - let s = s.map_err(DataFusionError::from); let mut input = file_compression_type.convert_stream(s.boxed())?.fuse(); - let mut buffered = Bytes::new(); + let mut buffer = Bytes::new(); - let s = stream::poll_fn(move |cx| { + let s = futures::stream::poll_fn(move |cx| { loop { - if buffered.is_empty() { - buffered = match ready!(input.poll_next_unpin(cx)) { - Some(Ok(b)) => b, + if buffer.is_empty() { + match ready!(input.poll_next_unpin(cx)) { + Some(Ok(b)) => buffer = b, Some(Err(e)) => { return Poll::Ready(Some(Err(e.into()))) } - None => break, + None => {} }; } - let read = buffered.len(); - let decoded = match decoder.decode(buffered.as_ref()) { + let decoded = match decoder.decode(buffer.as_ref()) { + Ok(0) => break, Ok(decoded) => decoded, Err(e) => return Poll::Ready(Some(Err(e))), }; - buffered.advance(decoded); - if decoded != read { - break; - } + buffer.advance(decoded); } Poll::Ready(decoder.flush().transpose()) @@ -358,9 +412,9 @@ mod tests { ) .unwrap(); let meta = file_groups - .get(0) + .first() .unwrap() - .get(0) + .first() .unwrap() .clone() .object_meta; @@ -392,9 +446,9 @@ mod tests { ) .unwrap(); let path = file_groups - .get(0) + .first() .unwrap() - .get(0) + .first() .unwrap() .object_meta .location @@ -457,13 +511,12 @@ mod tests { FileScanConfig { object_store_url, file_groups, + statistics: Statistics::new_unknown(&file_schema), file_schema, - statistics: Statistics::default(), projection: None, limit: Some(3), table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, file_compression_type.to_owned(), ); @@ -536,13 +589,12 @@ mod tests { FileScanConfig { object_store_url, file_groups, + statistics: Statistics::new_unknown(&file_schema), file_schema, - statistics: Statistics::default(), projection: None, limit: Some(3), table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, file_compression_type.to_owned(), ); @@ -584,13 +636,12 @@ mod tests { FileScanConfig { object_store_url, file_groups, + statistics: Statistics::new_unknown(&file_schema), file_schema, - statistics: Statistics::default(), projection: Some(vec![0, 2]), limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, file_compression_type.to_owned(), ); @@ -637,13 +688,12 @@ mod tests { FileScanConfig { object_store_url, file_groups, + statistics: Statistics::new_unknown(&file_schema), file_schema, - statistics: Statistics::default(), projection: Some(vec![3, 0, 2]), limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, file_compression_type.to_owned(), ); @@ -675,8 +725,9 @@ mod tests { #[tokio::test] async fn write_json_results() -> Result<()> { // create partitioned input file and context - let ctx = - SessionContext::with_config(SessionConfig::new().with_target_partitions(8)); + let ctx = SessionContext::new_with_config( + SessionConfig::new().with_target_partitions(8), + ); let path = format!("{TEST_DATA_BASE}/1.json"); diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index d8ae6b3c04e6..d7be017a1868 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -20,37 +20,40 @@ mod arrow_file; mod avro; mod csv; +mod file_groups; +mod file_scan_config; mod file_stream; mod json; +#[cfg(feature = "parquet")] pub mod parquet; +pub use file_groups::FileGroupPartitioner; +use futures::StreamExt; pub(crate) use self::csv::plan_to_csv; pub use self::csv::{CsvConfig, CsvExec, CsvOpener}; -pub(crate) use self::parquet::plan_to_parquet; +pub(crate) use self::json::plan_to_json; +#[cfg(feature = "parquet")] pub use self::parquet::{ParquetExec, ParquetFileMetrics, ParquetFileReaderFactory}; -use arrow::{ - array::new_null_array, - compute::can_cast_types, - datatypes::{DataType, Schema, SchemaRef}, - record_batch::{RecordBatch, RecordBatchOptions}, -}; + pub use arrow_file::ArrowExec; pub use avro::AvroExec; -use datafusion_physical_expr::PhysicalSortExpr; -pub use file_stream::{FileOpenFuture, FileOpener, FileStream, OnError}; -pub(crate) use json::plan_to_json; -pub use json::{JsonOpener, NdJsonExec}; -mod file_scan_config; -pub(crate) use file_scan_config::PartitionColumnProjector; +use file_scan_config::PartitionColumnProjector; pub use file_scan_config::{ wrap_partition_type_in_dict, wrap_partition_value_in_dict, FileScanConfig, }; +pub use file_stream::{FileOpenFuture, FileOpener, FileStream, OnError}; +pub use json::{JsonOpener, NdJsonExec}; -use crate::error::{DataFusionError, Result}; -use crate::{ - datasource::file_format::write::FileWriterMode, - physical_plan::{DisplayAs, DisplayFormatType}, +use std::{ + fmt::{Debug, Formatter, Result as FmtResult}, + ops::Range, + sync::Arc, + vec, }; + +use super::listing::ListingTableUrl; +use crate::error::{DataFusionError, Result}; +use crate::physical_plan::{DisplayAs, DisplayFormatType}; use crate::{ datasource::{ listing::{FileRange, PartitionedFile}, @@ -59,21 +62,20 @@ use crate::{ physical_plan::display::{OutputOrderingDisplay, ProjectSchemaDisplay}, }; +use arrow::{ + array::new_null_array, + compute::{can_cast_types, cast}, + datatypes::{DataType, Schema, SchemaRef}, + record_batch::{RecordBatch, RecordBatchOptions}, +}; use datafusion_common::{file_options::FileTypeWriterOptions, plan_err}; use datafusion_physical_expr::expressions::Column; - -use arrow::compute::cast; +use datafusion_physical_expr::PhysicalSortExpr; use datafusion_physical_plan::ExecutionPlan; + use log::debug; -use object_store::path::Path; use object_store::ObjectMeta; -use std::{ - fmt::{Debug, Formatter, Result as FmtResult}, - sync::Arc, - vec, -}; - -use super::listing::ListingTableUrl; +use object_store::{path::Path, GetOptions, ObjectStore}; /// The base configurations to provide when creating a physical plan for /// writing to any given file format. @@ -89,14 +91,10 @@ pub struct FileSinkConfig { /// A vector of column names and their corresponding data types, /// representing the partitioning columns for the file pub table_partition_cols: Vec<(String, DataType)>, - /// A writer mode that determines how data is written to the file - pub writer_mode: FileWriterMode, /// If true, it is assumed there is a single table_path which is a file to which all data should be written /// regardless of input partitioning. Otherwise, each table path is assumed to be a directory /// to which each output partition is written to its own output file. pub single_file_output: bool, - /// If input is unbounded, tokio tasks need to yield to not block execution forever - pub unbounded_input: bool, /// Controls whether existing data should be overwritten by this sink pub overwrite: bool, /// Contains settings specific to writing a given FileType, e.g. parquet max_row_group_size @@ -135,13 +133,24 @@ impl DisplayAs for FileScanConfig { write!(f, ", limit={limit}")?; } - if self.infinite_source { - write!(f, ", infinite_source=true")?; - } - if let Some(ordering) = orderings.first() { if !ordering.is_empty() { - write!(f, ", output_ordering={}", OutputOrderingDisplay(ordering))?; + let start = if orderings.len() == 1 { + ", output_ordering=" + } else { + ", output_orderings=[" + }; + write!(f, "{}", start)?; + for (idx, ordering) in + orderings.iter().enumerate().filter(|(_, o)| !o.is_empty()) + { + match idx { + 0 => write!(f, "{}", OutputOrderingDisplay(ordering))?, + _ => write!(f, ", {}", OutputOrderingDisplay(ordering))?, + } + } + let end = if orderings.len() == 1 { "" } else { "]" }; + write!(f, "{}", end)?; } } @@ -501,9 +510,9 @@ fn get_projected_output_ordering( all_orderings } -// Get output (un)boundedness information for the given `plan`. -pub(crate) fn is_plan_streaming(plan: &Arc) -> Result { - let result = if plan.children().is_empty() { +/// Get output (un)boundedness information for the given `plan`. +pub fn is_plan_streaming(plan: &Arc) -> Result { + if plan.children().is_empty() { plan.unbounded_output(&[]) } else { let children_unbounded_output = plan @@ -512,8 +521,110 @@ pub(crate) fn is_plan_streaming(plan: &Arc) -> Result { .map(is_plan_streaming) .collect::>>(); plan.unbounded_output(&children_unbounded_output?) + } +} + +/// Represents the possible outcomes of a range calculation. +/// +/// This enum is used to encapsulate the result of calculating the range of +/// bytes to read from an object (like a file) in an object store. +/// +/// Variants: +/// - `Range(Option>)`: +/// Represents a range of bytes to be read. It contains an `Option` wrapping a +/// `Range`. `None` signifies that the entire object should be read, +/// while `Some(range)` specifies the exact byte range to read. +/// - `TerminateEarly`: +/// Indicates that the range calculation determined no further action is +/// necessary, possibly because the calculated range is empty or invalid. +enum RangeCalculation { + Range(Option>), + TerminateEarly, +} + +/// Calculates an appropriate byte range for reading from an object based on the +/// provided metadata. +/// +/// This asynchronous function examines the `FileMeta` of an object in an object store +/// and determines the range of bytes to be read. The range calculation may adjust +/// the start and end points to align with meaningful data boundaries (like newlines). +/// +/// Returns a `Result` wrapping a `RangeCalculation`, which is either a calculated byte range or an indication to terminate early. +/// +/// Returns an `Error` if any part of the range calculation fails, such as issues in reading from the object store or invalid range boundaries. +async fn calculate_range( + file_meta: &FileMeta, + store: &Arc, +) -> Result { + let location = file_meta.location(); + let file_size = file_meta.object_meta.size; + + match file_meta.range { + None => Ok(RangeCalculation::Range(None)), + Some(FileRange { start, end }) => { + let (start, end) = (start as usize, end as usize); + + let start_delta = if start != 0 { + find_first_newline(store, location, start - 1, file_size).await? + } else { + 0 + }; + + let end_delta = if end != file_size { + find_first_newline(store, location, end - 1, file_size).await? + } else { + 0 + }; + + let range = start + start_delta..end + end_delta; + + if range.start == range.end { + return Ok(RangeCalculation::TerminateEarly); + } + + Ok(RangeCalculation::Range(Some(range))) + } + } +} + +/// Asynchronously finds the position of the first newline character in a specified byte range +/// within an object, such as a file, in an object store. +/// +/// This function scans the contents of the object starting from the specified `start` position +/// up to the `end` position, looking for the first occurrence of a newline (`'\n'`) character. +/// It returns the position of the first newline relative to the start of the range. +/// +/// Returns a `Result` wrapping a `usize` that represents the position of the first newline character found within the specified range. If no newline is found, it returns the length of the scanned data, effectively indicating the end of the range. +/// +/// The function returns an `Error` if any issues arise while reading from the object store or processing the data stream. +/// +async fn find_first_newline( + object_store: &Arc, + location: &Path, + start: usize, + end: usize, +) -> Result { + let range = Some(Range { start, end }); + + let options = GetOptions { + range, + ..Default::default() }; - result + + let result = object_store.get_opts(location, options).await?; + let mut result_stream = result.into_stream(); + + let mut index = 0; + + while let Some(chunk) = result_stream.next().await.transpose()? { + if let Some(position) = chunk.iter().position(|&byte| byte == b'\n') { + return Ok(index + position); + } + + index += chunk.len(); + } + + Ok(index) } #[cfg(test)] @@ -787,6 +898,7 @@ mod tests { last_modified: Utc::now(), size: 42, e_tag: None, + version: None, }; PartitionedFile { @@ -796,349 +908,4 @@ mod tests { extensions: None, } } - - /// Unit tests for `repartition_file_groups()` - mod repartition_file_groups_test { - use datafusion_common::Statistics; - use itertools::Itertools; - - use super::*; - - /// Empty file won't get partitioned - #[tokio::test] - async fn repartition_empty_file_only() { - let partitioned_file_empty = PartitionedFile::new("empty".to_string(), 0); - let file_group = vec![vec![partitioned_file_empty]]; - - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: file_group, - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::default(), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - ); - - let partitioned_file = parquet_exec - .get_repartitioned(4, 0) - .base_config() - .file_groups - .clone(); - - assert!(partitioned_file[0][0].range.is_none()); - } - - // Repartition when there is a empty file in file groups - #[tokio::test] - async fn repartition_empty_files() { - let partitioned_file_a = PartitionedFile::new("a".to_string(), 10); - let partitioned_file_b = PartitionedFile::new("b".to_string(), 10); - let partitioned_file_empty = PartitionedFile::new("empty".to_string(), 0); - - let empty_first = vec![ - vec![partitioned_file_empty.clone()], - vec![partitioned_file_a.clone()], - vec![partitioned_file_b.clone()], - ]; - let empty_middle = vec![ - vec![partitioned_file_a.clone()], - vec![partitioned_file_empty.clone()], - vec![partitioned_file_b.clone()], - ]; - let empty_last = vec![ - vec![partitioned_file_a], - vec![partitioned_file_b], - vec![partitioned_file_empty], - ]; - - // Repartition file groups into x partitions - let expected_2 = - vec![(0, "a".to_string(), 0, 10), (1, "b".to_string(), 0, 10)]; - let expected_3 = vec![ - (0, "a".to_string(), 0, 7), - (1, "a".to_string(), 7, 10), - (1, "b".to_string(), 0, 4), - (2, "b".to_string(), 4, 10), - ]; - - //let file_groups_testset = [empty_first, empty_middle, empty_last]; - let file_groups_testset = [empty_first, empty_middle, empty_last]; - - for fg in file_groups_testset { - for (n_partition, expected) in [(2, &expected_2), (3, &expected_3)] { - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: fg.clone(), - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::default(), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - ); - - let actual = file_groups_to_vec( - parquet_exec - .get_repartitioned(n_partition, 10) - .base_config() - .file_groups - .clone(), - ); - - assert_eq!(expected, &actual); - } - } - } - - #[tokio::test] - async fn repartition_single_file() { - // Single file, single partition into multiple partitions - let partitioned_file = PartitionedFile::new("a".to_string(), 123); - let single_partition = vec![vec![partitioned_file]]; - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: single_partition, - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::default(), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - ); - - let actual = file_groups_to_vec( - parquet_exec - .get_repartitioned(4, 10) - .base_config() - .file_groups - .clone(), - ); - let expected = vec![ - (0, "a".to_string(), 0, 31), - (1, "a".to_string(), 31, 62), - (2, "a".to_string(), 62, 93), - (3, "a".to_string(), 93, 123), - ]; - assert_eq!(expected, actual); - } - - #[tokio::test] - async fn repartition_too_much_partitions() { - // Single file, single parittion into 96 partitions - let partitioned_file = PartitionedFile::new("a".to_string(), 8); - let single_partition = vec![vec![partitioned_file]]; - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: single_partition, - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::default(), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - ); - - let actual = file_groups_to_vec( - parquet_exec - .get_repartitioned(96, 5) - .base_config() - .file_groups - .clone(), - ); - let expected = vec![ - (0, "a".to_string(), 0, 1), - (1, "a".to_string(), 1, 2), - (2, "a".to_string(), 2, 3), - (3, "a".to_string(), 3, 4), - (4, "a".to_string(), 4, 5), - (5, "a".to_string(), 5, 6), - (6, "a".to_string(), 6, 7), - (7, "a".to_string(), 7, 8), - ]; - assert_eq!(expected, actual); - } - - #[tokio::test] - async fn repartition_multiple_partitions() { - // Multiple files in single partition after redistribution - let partitioned_file_1 = PartitionedFile::new("a".to_string(), 40); - let partitioned_file_2 = PartitionedFile::new("b".to_string(), 60); - let source_partitions = - vec![vec![partitioned_file_1], vec![partitioned_file_2]]; - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: source_partitions, - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::default(), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - ); - - let actual = file_groups_to_vec( - parquet_exec - .get_repartitioned(3, 10) - .base_config() - .file_groups - .clone(), - ); - let expected = vec![ - (0, "a".to_string(), 0, 34), - (1, "a".to_string(), 34, 40), - (1, "b".to_string(), 0, 28), - (2, "b".to_string(), 28, 60), - ]; - assert_eq!(expected, actual); - } - - #[tokio::test] - async fn repartition_same_num_partitions() { - // "Rebalance" files across partitions - let partitioned_file_1 = PartitionedFile::new("a".to_string(), 40); - let partitioned_file_2 = PartitionedFile::new("b".to_string(), 60); - let source_partitions = - vec![vec![partitioned_file_1], vec![partitioned_file_2]]; - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: source_partitions, - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::default(), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - ); - - let actual = file_groups_to_vec( - parquet_exec - .get_repartitioned(2, 10) - .base_config() - .file_groups - .clone(), - ); - let expected = vec![ - (0, "a".to_string(), 0, 40), - (0, "b".to_string(), 0, 10), - (1, "b".to_string(), 10, 60), - ]; - assert_eq!(expected, actual); - } - - #[tokio::test] - async fn repartition_no_action_ranges() { - // No action due to Some(range) in second file - let partitioned_file_1 = PartitionedFile::new("a".to_string(), 123); - let mut partitioned_file_2 = PartitionedFile::new("b".to_string(), 144); - partitioned_file_2.range = Some(FileRange { start: 1, end: 50 }); - - let source_partitions = - vec![vec![partitioned_file_1], vec![partitioned_file_2]]; - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: source_partitions, - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::default(), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - ); - - let actual = parquet_exec - .get_repartitioned(65, 10) - .base_config() - .file_groups - .clone(); - assert_eq!(2, actual.len()); - } - - #[tokio::test] - async fn repartition_no_action_min_size() { - // No action due to target_partition_size - let partitioned_file = PartitionedFile::new("a".to_string(), 123); - let single_partition = vec![vec![partitioned_file]]; - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: single_partition, - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::default(), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - ); - - let actual = parquet_exec - .get_repartitioned(65, 500) - .base_config() - .file_groups - .clone(); - assert_eq!(1, actual.len()); - } - - fn file_groups_to_vec( - file_groups: Vec>, - ) -> Vec<(usize, String, i64, i64)> { - file_groups - .iter() - .enumerate() - .flat_map(|(part_idx, files)| { - files - .iter() - .map(|f| { - ( - part_idx, - f.object_meta.location.to_string(), - f.range.as_ref().unwrap().start, - f.range.as_ref().unwrap().end, - ) - }) - .collect_vec() - }) - .collect_vec() - } - } } diff --git a/datafusion/core/src/datasource/physical_plan/parquet.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs similarity index 87% rename from datafusion/core/src/datasource/physical_plan/parquet.rs rename to datafusion/core/src/datasource/physical_plan/parquet/mod.rs index d16c79a9692c..9d81d8d083c2 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -17,12 +17,17 @@ //! Execution plan for reading Parquet files +use std::any::Any; +use std::fmt::Debug; +use std::ops::Range; +use std::sync::Arc; + use crate::datasource::physical_plan::file_stream::{ FileOpenFuture, FileOpener, FileStream, }; use crate::datasource::physical_plan::{ - parquet::page_filter::PagePruningPredicate, DisplayAs, FileMeta, FileScanConfig, - SchemaAdapter, + parquet::page_filter::PagePruningPredicate, DisplayAs, FileGroupPartitioner, + FileMeta, FileScanConfig, SchemaAdapter, }; use crate::{ config::ConfigOptions, @@ -36,27 +41,18 @@ use crate::{ Statistics, }, }; -use datafusion_physical_expr::{ - ordering_equivalence_properties_helper, PhysicalSortExpr, -}; -use fmt::Debug; -use object_store::path::Path; -use std::any::Any; -use std::fmt; -use std::ops::Range; -use std::sync::Arc; -use tokio::task::JoinSet; use arrow::datatypes::{DataType, SchemaRef}; use arrow::error::ArrowError; use datafusion_physical_expr::{ - LexOrdering, OrderingEquivalenceProperties, PhysicalExpr, + EquivalenceProperties, LexOrdering, PhysicalExpr, PhysicalSortExpr, }; use bytes::Bytes; use futures::future::BoxFuture; use futures::{StreamExt, TryStreamExt}; use log::debug; +use object_store::path::Path; use object_store::ObjectStore; use parquet::arrow::arrow_reader::ArrowReaderOptions; use parquet::arrow::async_reader::{AsyncFileReader, ParquetObjectReader}; @@ -64,11 +60,13 @@ use parquet::arrow::{AsyncArrowWriter, ParquetRecordBatchStreamBuilder, Projecti use parquet::basic::{ConvertedType, LogicalType}; use parquet::file::{metadata::ParquetMetaData, properties::WriterProperties}; use parquet::schema::types::ColumnDescriptor; +use tokio::task::JoinSet; mod metrics; pub mod page_filter; mod row_filter; mod row_groups; +mod statistics; pub use metrics::ParquetFileMetrics; @@ -84,6 +82,9 @@ pub struct ParquetExec { /// Override for `Self::with_enable_page_index`. If None, uses /// values from base_config enable_page_index: Option, + /// Override for `Self::with_enable_bloom_filter`. If None, uses + /// values from base_config + enable_bloom_filter: Option, /// Base configuration for this scan base_config: FileScanConfig, projected_statistics: Statistics, @@ -153,6 +154,7 @@ impl ParquetExec { pushdown_filters: None, reorder_filters: None, enable_page_index: None, + enable_bloom_filter: None, base_config, projected_schema, projected_statistics, @@ -246,24 +248,16 @@ impl ParquetExec { .unwrap_or(config_options.execution.parquet.enable_page_index) } - /// Redistribute files across partitions according to their size - /// See comments on `get_file_groups_repartitioned()` for more detail. - pub fn get_repartitioned( - &self, - target_partitions: usize, - repartition_file_min_size: usize, - ) -> Self { - let repartitioned_file_groups_option = FileScanConfig::repartition_file_groups( - self.base_config.file_groups.clone(), - target_partitions, - repartition_file_min_size, - ); + /// If enabled, the reader will read by the bloom filter + pub fn with_enable_bloom_filter(mut self, enable_bloom_filter: bool) -> Self { + self.enable_bloom_filter = Some(enable_bloom_filter); + self + } - let mut new_plan = self.clone(); - if let Some(repartitioned_file_groups) = repartitioned_file_groups_option { - new_plan.base_config.file_groups = repartitioned_file_groups; - } - new_plan + /// Return the value described in [`Self::with_enable_bloom_filter`] + fn enable_bloom_filter(&self, config_options: &ConfigOptions) -> bool { + self.enable_bloom_filter + .unwrap_or(config_options.execution.parquet.bloom_filter_enabled) } } @@ -321,8 +315,8 @@ impl ExecutionPlan for ParquetExec { .map(|ordering| ordering.as_slice()) } - fn ordering_equivalence_properties(&self) -> OrderingEquivalenceProperties { - ordering_equivalence_properties_helper( + fn equivalence_properties(&self) -> EquivalenceProperties { + EquivalenceProperties::new_with_orderings( self.schema(), &self.projected_output_ordering, ) @@ -335,6 +329,27 @@ impl ExecutionPlan for ParquetExec { Ok(self) } + /// Redistribute files across partitions according to their size + /// See comments on [`FileGroupPartitioner`] for more detail. + fn repartitioned( + &self, + target_partitions: usize, + config: &ConfigOptions, + ) -> Result>> { + let repartition_file_min_size = config.optimizer.repartition_file_min_size; + let repartitioned_file_groups_option = FileGroupPartitioner::new() + .with_target_partitions(target_partitions) + .with_repartition_file_min_size(repartition_file_min_size) + .with_preserve_order_within_groups(self.output_ordering().is_some()) + .repartition_file_groups(&self.base_config.file_groups); + + let mut new_plan = self.clone(); + if let Some(repartitioned_file_groups) = repartitioned_file_groups_option { + new_plan.base_config.file_groups = repartitioned_file_groups; + } + Ok(Some(Arc::new(new_plan))) + } + fn execute( &self, partition_index: usize, @@ -375,6 +390,7 @@ impl ExecutionPlan for ParquetExec { pushdown_filters: self.pushdown_filters(config_options), reorder_filters: self.reorder_filters(config_options), enable_page_index: self.enable_page_index(config_options), + enable_bloom_filter: self.enable_bloom_filter(config_options), }; let stream = @@ -387,8 +403,8 @@ impl ExecutionPlan for ParquetExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Statistics { - self.projected_statistics.clone() + fn statistics(&self) -> Result { + Ok(self.projected_statistics.clone()) } } @@ -408,6 +424,7 @@ struct ParquetOpener { pushdown_filters: bool, reorder_filters: bool, enable_page_index: bool, + enable_bloom_filter: bool, } impl FileOpener for ParquetOpener { @@ -442,6 +459,7 @@ impl FileOpener for ParquetOpener { self.enable_page_index, &self.page_pruning_predicate, ); + let enable_bloom_filter = self.enable_bloom_filter; let limit = self.limit; Ok(Box::pin(async move { @@ -450,8 +468,10 @@ impl FileOpener for ParquetOpener { ParquetRecordBatchStreamBuilder::new_with_options(reader, options) .await?; + let file_schema = builder.schema().clone(); + let (schema_mapping, adapted_projections) = - schema_adapter.map_schema(builder.schema())?; + schema_adapter.map_schema(&file_schema)?; // let predicate = predicate.map(|p| reassign_predicate_columns(p, builder.schema(), true)).transpose()?; let mask = ProjectionMask::roots( @@ -463,8 +483,8 @@ impl FileOpener for ParquetOpener { if let Some(predicate) = pushdown_filters.then_some(predicate).flatten() { let row_filter = row_filter::build_row_filter( &predicate, - builder.schema().as_ref(), - table_schema.as_ref(), + &file_schema, + &table_schema, builder.metadata(), reorder_predicates, &file_metrics, @@ -484,16 +504,35 @@ impl FileOpener for ParquetOpener { }; }; - // Row group pruning: attempt to skip entire row_groups + // Row group pruning by statistics: attempt to skip entire row_groups // using metadata on the row groups - let file_metadata = builder.metadata(); - let row_groups = row_groups::prune_row_groups( + let file_metadata = builder.metadata().clone(); + let predicate = pruning_predicate.as_ref().map(|p| p.as_ref()); + let mut row_groups = row_groups::prune_row_groups_by_statistics( + &file_schema, + builder.parquet_schema(), file_metadata.row_groups(), file_range, - pruning_predicate.as_ref().map(|p| p.as_ref()), + predicate, &file_metrics, ); + // Bloom filter pruning: if bloom filters are enabled and then attempt to skip entire row_groups + // using bloom filters on the row groups + if enable_bloom_filter && !row_groups.is_empty() { + if let Some(predicate) = predicate { + row_groups = row_groups::prune_row_groups_by_bloom_filters( + &file_schema, + &mut builder, + &row_groups, + file_metadata.row_groups(), + predicate, + &file_metrics, + ) + .await; + } + } + // page index pruning: if all data on individual pages can // be ruled using page metadata, rows from other columns // with that range can be skipped as well @@ -569,7 +608,7 @@ impl DefaultParquetFileReaderFactory { } /// Implements [`AsyncFileReader`] for a parquet file in object storage -struct ParquetFileReader { +pub(crate) struct ParquetFileReader { file_metrics: ParquetFileMetrics, inner: ParquetObjectReader, } @@ -685,28 +724,6 @@ pub async fn plan_to_parquet( Ok(()) } -// Copy from the arrow-rs -// https://github.com/apache/arrow-rs/blob/733b7e7fd1e8c43a404c3ce40ecf741d493c21b4/parquet/src/arrow/buffer/bit_util.rs#L55 -// Convert the byte slice to fixed length byte array with the length of 16 -fn sign_extend_be(b: &[u8]) -> [u8; 16] { - assert!(b.len() <= 16, "Array too large, expected less than 16"); - let is_negative = (b[0] & 128u8) == 128u8; - let mut result = if is_negative { [255u8; 16] } else { [0u8; 16] }; - for (d, s) in result.iter_mut().skip(16 - b.len()).zip(b) { - *d = *s; - } - result -} - -// Convert the bytes array to i128. -// The endian of the input bytes array must be big-endian. -pub(crate) fn from_bytes_to_i128(b: &[u8]) -> i128 { - // The bytes array are from parquet file and must be the big-endian. - // The endian is defined by parquet format, and the reference document - // https://github.com/apache/parquet-format/blob/54e53e5d7794d383529dd30746378f19a12afd58/src/main/thrift/parquet.thrift#L66 - i128::from_be_bytes(sign_extend_be(b)) -} - // Convert parquet column schema to arrow data type, and just consider the // decimal data type. pub(crate) fn parquet_to_arrow_decimal_type( @@ -736,7 +753,7 @@ mod tests { use crate::datasource::file_format::options::CsvReadOptions; use crate::datasource::file_format::parquet::test_util::store_parquet; use crate::datasource::file_format::test_util::scan_format; - use crate::datasource::listing::{FileRange, PartitionedFile}; + use crate::datasource::listing::{FileRange, ListingOptions, PartitionedFile}; use crate::datasource::object_store::ObjectStoreUrl; use crate::execution::context::SessionState; use crate::physical_plan::displayable; @@ -756,8 +773,8 @@ mod tests { }; use arrow_array::Date64Array; use chrono::{TimeZone, Utc}; - use datafusion_common::ScalarValue; use datafusion_common::{assert_contains, ToDFSchema}; + use datafusion_common::{FileType, GetExt, ScalarValue}; use datafusion_expr::{col, lit, when, Expr}; use datafusion_physical_expr::create_physical_expr; use datafusion_physical_expr::execution_props::ExecutionProps; @@ -860,13 +877,12 @@ mod tests { FileScanConfig { object_store_url: ObjectStoreUrl::local_filesystem(), file_groups: vec![file_groups], + statistics: Statistics::new_unknown(&file_schema), file_schema, - statistics: Statistics::default(), projection, limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, predicate, None, @@ -1517,13 +1533,12 @@ mod tests { FileScanConfig { object_store_url: ObjectStoreUrl::local_filesystem(), file_groups, + statistics: Statistics::new_unknown(&file_schema), file_schema, - statistics: Statistics::default(), projection: None, limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, None, None, @@ -1590,11 +1605,11 @@ mod tests { let partitioned_file = PartitionedFile { object_meta: meta, partition_values: vec![ - ScalarValue::Utf8(Some("2021".to_owned())), + ScalarValue::from("2021"), ScalarValue::UInt8(Some(10)), ScalarValue::Dictionary( Box::new(DataType::UInt16), - Box::new(ScalarValue::Utf8(Some("26".to_owned()))), + Box::new(ScalarValue::from("26")), ), ], range: None, @@ -1620,24 +1635,24 @@ mod tests { FileScanConfig { object_store_url, file_groups: vec![vec![partitioned_file]], - file_schema: schema, - statistics: Statistics::default(), + file_schema: schema.clone(), + statistics: Statistics::new_unknown(&schema), // file has 10 cols so index 12 should be month and 13 should be day projection: Some(vec![0, 1, 2, 12, 13]), limit: None, table_partition_cols: vec![ - ("year".to_owned(), DataType::Utf8), - ("month".to_owned(), DataType::UInt8), - ( - "day".to_owned(), + Field::new("year", DataType::Utf8, false), + Field::new("month", DataType::UInt8, false), + Field::new( + "day", DataType::Dictionary( Box::new(DataType::UInt16), Box::new(DataType::Utf8), ), + false, ), ], output_ordering: vec![], - infinite_source: false, }, None, None, @@ -1684,6 +1699,7 @@ mod tests { last_modified: Utc.timestamp_nanos(0), size: 1337, e_tag: None, + version: None, }, partition_values: vec![], range: None, @@ -1695,12 +1711,11 @@ mod tests { object_store_url: ObjectStoreUrl::local_filesystem(), file_groups: vec![vec![partitioned_file]], file_schema: Arc::new(Schema::empty()), - statistics: Statistics::default(), + statistics: Statistics::new_unknown(&Schema::empty()), projection: None, limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, None, None, @@ -1753,8 +1768,9 @@ mod tests { ); } - #[tokio::test] - async fn parquet_exec_metrics() { + /// Returns a string array with contents: + /// "[Foo, null, bar, bar, bar, bar, zzz]" + fn string_batch() -> RecordBatch { let c1: ArrayRef = Arc::new(StringArray::from(vec![ Some("Foo"), None, @@ -1766,9 +1782,15 @@ mod tests { ])); // batch1: c1(string) - let batch1 = create_batch(vec![("c1", c1.clone())]); + create_batch(vec![("c1", c1.clone())]) + } - // on + #[tokio::test] + async fn parquet_exec_metrics() { + // batch1: c1(string) + let batch1 = string_batch(); + + // c1 != 'bar' let filter = col("c1").not_eq(lit("bar")); // read/write them files: @@ -1797,20 +1819,10 @@ mod tests { #[tokio::test] async fn parquet_exec_display() { - let c1: ArrayRef = Arc::new(StringArray::from(vec![ - Some("Foo"), - None, - Some("bar"), - Some("bar"), - Some("bar"), - Some("bar"), - Some("zzz"), - ])); - // batch1: c1(string) - let batch1 = create_batch(vec![("c1", c1.clone())]); + let batch1 = string_batch(); - // on + // c1 != 'bar' let filter = col("c1").not_eq(lit("bar")); let rt = RoundTrip::new() @@ -1839,21 +1851,15 @@ mod tests { } #[tokio::test] - async fn parquet_exec_skip_empty_pruning() { - let c1: ArrayRef = Arc::new(StringArray::from(vec![ - Some("Foo"), - None, - Some("bar"), - Some("bar"), - Some("bar"), - Some("bar"), - Some("zzz"), - ])); - + async fn parquet_exec_has_no_pruning_predicate_if_can_not_prune() { // batch1: c1(string) - let batch1 = create_batch(vec![("c1", c1.clone())]); + let batch1 = string_batch(); + + // filter is too complicated for pruning (PruningPredicate code does not + // handle case expressions), so the pruning predicate will always be + // "true" - // filter is too complicated for pruning + // WHEN c1 != bar THEN true ELSE false END let filter = when(col("c1").not_eq(lit("bar")), lit(true)) .otherwise(lit(false)) .unwrap(); @@ -1864,7 +1870,7 @@ mod tests { .round_trip(vec![batch1]) .await; - // Should not contain a pruning predicate + // Should not contain a pruning predicate (since nothing can be pruned) let pruning_predicate = &rt.parquet_exec.pruning_predicate; assert!( pruning_predicate.is_none(), @@ -1877,6 +1883,33 @@ mod tests { assert_eq!(predicate.unwrap().to_string(), filter_phys.to_string()); } + #[tokio::test] + async fn parquet_exec_has_pruning_predicate_for_guarantees() { + // batch1: c1(string) + let batch1 = string_batch(); + + // part of the filter is too complicated for pruning (PruningPredicate code does not + // handle case expressions), but part (c1 = 'foo') can be used for bloom filtering, so + // should still have the pruning predicate. + + // c1 = 'foo' AND (WHEN c1 != bar THEN true ELSE false END) + let filter = col("c1").eq(lit("foo")).and( + when(col("c1").not_eq(lit("bar")), lit(true)) + .otherwise(lit(false)) + .unwrap(), + ); + + let rt = RoundTrip::new() + .with_predicate(filter.clone()) + .with_pushdown_predicate() + .round_trip(vec![batch1]) + .await; + + // Should have a pruning predicate + let pruning_predicate = &rt.parquet_exec.pruning_predicate; + assert!(pruning_predicate.is_some()); + } + /// returns the sum of all the metrics with the specified name /// the returned set. /// @@ -1924,12 +1957,13 @@ mod tests { } #[tokio::test] - async fn write_parquet_results() -> Result<()> { + async fn write_table_results() -> Result<()> { // create partitioned input file and context let tmp_dir = TempDir::new()?; // let mut ctx = create_ctx(&tmp_dir, 4).await?; - let ctx = - SessionContext::with_config(SessionConfig::new().with_target_partitions(8)); + let ctx = SessionContext::new_with_config( + SessionConfig::new().with_target_partitions(8), + ); let schema = populate_csv_partitions(&tmp_dir, 4, ".csv")?; // register csv file with the execution context ctx.register_csv( @@ -1944,13 +1978,29 @@ mod tests { let local_url = Url::parse("file://local").unwrap(); ctx.runtime_env().register_object_store(&local_url, local); + // Configure listing options + let file_format = ParquetFormat::default().with_enable_pruning(Some(true)); + let listing_options = ListingOptions::new(Arc::new(file_format)) + .with_file_extension(FileType::PARQUET.get_ext()); + // execute a simple query and write the results to parquet let out_dir = tmp_dir.as_ref().to_str().unwrap().to_string() + "/out"; - let out_dir_url = "file://local/out"; + std::fs::create_dir(&out_dir).unwrap(); let df = ctx.sql("SELECT c1, c2 FROM test").await?; - df.write_parquet(out_dir_url, DataFrameWriteOptions::new(), None) + let schema: Schema = df.schema().into(); + // Register a listing table - this will use all files in the directory as data sources + // for the query + ctx.register_listing_table( + "my_table", + &out_dir, + listing_options, + Some(Arc::new(schema)), + None, + ) + .await + .unwrap(); + df.write_table("my_table", DataFrameWriteOptions::new()) .await?; - // write_parquet(&mut ctx, "SELECT c1, c2 FROM test", &out_dir, None).await?; // create a new context and verify that the results were saved to a partitioned parquet file let ctx = SessionContext::new(); @@ -1966,7 +2016,6 @@ mod tests { .to_str() .expect("Should be a str") .to_owned(); - println!("{name}"); let (parsed_id, _) = name.split_once('_').expect("File should contain _ !"); let write_id = parsed_id.to_owned(); @@ -1977,24 +2026,81 @@ mod tests { ParquetReadOptions::default(), ) .await?; - ctx.register_parquet( - "part1", - &format!("{out_dir}/{write_id}_1.parquet"), - ParquetReadOptions::default(), - ) - .await?; - ctx.register_parquet( - "part2", - &format!("{out_dir}/{write_id}_2.parquet"), - ParquetReadOptions::default(), + + ctx.register_parquet("allparts", &out_dir, ParquetReadOptions::default()) + .await?; + + let part0 = ctx.sql("SELECT c1, c2 FROM part0").await?.collect().await?; + let allparts = ctx + .sql("SELECT c1, c2 FROM allparts") + .await? + .collect() + .await?; + + let allparts_count: usize = allparts.iter().map(|batch| batch.num_rows()).sum(); + + assert_eq!(part0[0].schema(), allparts[0].schema()); + + assert_eq!(allparts_count, 40); + + Ok(()) + } + + #[tokio::test] + async fn write_parquet_results() -> Result<()> { + // create partitioned input file and context + let tmp_dir = TempDir::new()?; + // let mut ctx = create_ctx(&tmp_dir, 4).await?; + let ctx = SessionContext::new_with_config( + SessionConfig::new().with_target_partitions(8), + ); + let schema = populate_csv_partitions(&tmp_dir, 4, ".csv")?; + // register csv file with the execution context + ctx.register_csv( + "test", + tmp_dir.path().to_str().unwrap(), + CsvReadOptions::new().schema(&schema), ) .await?; + + // register a local file system object store for /tmp directory + let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?); + let local_url = Url::parse("file://local").unwrap(); + ctx.runtime_env().register_object_store(&local_url, local); + + // execute a simple query and write the results to parquet + let out_dir = tmp_dir.as_ref().to_str().unwrap().to_string() + "/out"; + let out_dir_url = "file://local/out"; + let df = ctx.sql("SELECT c1, c2 FROM test").await?; + df.write_parquet(out_dir_url, DataFrameWriteOptions::new(), None) + .await?; + // write_parquet(&mut ctx, "SELECT c1, c2 FROM test", &out_dir, None).await?; + + // create a new context and verify that the results were saved to a partitioned parquet file + let ctx = SessionContext::new(); + + // get write_id + let mut paths = fs::read_dir(&out_dir).unwrap(); + let path = paths.next(); + let name = path + .unwrap()? + .path() + .file_name() + .expect("Should be a file name") + .to_str() + .expect("Should be a str") + .to_owned(); + let (parsed_id, _) = name.split_once('_').expect("File should contain _ !"); + let write_id = parsed_id.to_owned(); + + // register each partition as well as the top level dir ctx.register_parquet( - "part3", - &format!("{out_dir}/{write_id}_3.parquet"), + "part0", + &format!("{out_dir}/{write_id}_0.parquet"), ParquetReadOptions::default(), ) .await?; + ctx.register_parquet("allparts", &out_dir, ParquetReadOptions::default()) .await?; diff --git a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs index e5c1d8feb0ab..a0637f379610 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs @@ -23,7 +23,7 @@ use arrow::array::{ }; use arrow::datatypes::DataType; use arrow::{array::ArrayRef, datatypes::SchemaRef, error::ArrowError}; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::{split_conjunction, PhysicalExpr}; use log::{debug, trace}; @@ -37,11 +37,11 @@ use parquet::{ }, format::PageLocation, }; +use std::collections::HashSet; use std::sync::Arc; -use crate::datasource::physical_plan::parquet::{ - from_bytes_to_i128, parquet_to_arrow_decimal_type, -}; +use crate::datasource::physical_plan::parquet::parquet_to_arrow_decimal_type; +use crate::datasource::physical_plan::parquet::statistics::from_bytes_to_i128; use crate::physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; use super::metrics::ParquetFileMetrics; @@ -147,17 +147,19 @@ impl PagePruningPredicate { let file_offset_indexes = file_metadata.offset_index(); let file_page_indexes = file_metadata.column_index(); - let (file_offset_indexes, file_page_indexes) = - match (file_offset_indexes, file_page_indexes) { - (Some(o), Some(i)) => (o, i), - _ => { - trace!( - "skip page pruning due to lack of indexes. Have offset: {} file: {}", + let (file_offset_indexes, file_page_indexes) = match ( + file_offset_indexes, + file_page_indexes, + ) { + (Some(o), Some(i)) => (o, i), + _ => { + trace!( + "skip page pruning due to lack of indexes. Have offset: {}, column index: {}", file_offset_indexes.is_some(), file_page_indexes.is_some() ); - return Ok(None); - } - }; + return Ok(None); + } + }; let mut row_selections = Vec::with_capacity(page_index_predicates.len()); for predicate in page_index_predicates { @@ -370,7 +372,7 @@ fn prune_pages_in_one_row_group( } fn create_row_count_in_each_page( - location: &Vec, + location: &[PageLocation], num_rows: usize, ) -> Vec { let mut vec = Vec::with_capacity(location.len()); @@ -553,4 +555,12 @@ impl<'a> PruningStatistics for PagesPruningStatistics<'a> { ))), } } + + fn contained( + &self, + _column: &datafusion_common::Column, + _values: &HashSet, + ) -> Option { + None + } } diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs index 0f4b09caeded..151ab5f657b1 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs @@ -21,7 +21,7 @@ use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::record_batch::RecordBatch; use datafusion_common::cast::as_boolean_array; use datafusion_common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}; -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_common::{arrow_err, DataFusionError, Result, ScalarValue}; use datafusion_physical_expr::expressions::{Column, Literal}; use datafusion_physical_expr::utils::reassign_predicate_columns; use std::collections::BTreeSet; @@ -126,7 +126,7 @@ impl ArrowPredicate for DatafusionArrowPredicate { match self .physical_expr .evaluate(&batch) - .map(|v| v.into_array(batch.num_rows())) + .and_then(|v| v.into_array(batch.num_rows())) { Ok(array) => { let bool_arr = as_boolean_array(&array)?.clone(); @@ -243,7 +243,7 @@ impl<'a> TreeNodeRewriter for FilterCandidateBuilder<'a> { } Err(e) => { // If the column is not in the table schema, should throw the error - Err(DataFusionError::ArrowError(e)) + arrow_err!(e) } }; } diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index c6e2c68d0211..24c65423dd4c 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -15,28 +15,29 @@ // specific language governing permissions and limitations // under the License. -use arrow::{ - array::ArrayRef, - datatypes::{DataType, Schema}, +use arrow::{array::ArrayRef, datatypes::Schema}; +use arrow_array::BooleanArray; +use arrow_schema::FieldRef; +use datafusion_common::{Column, ScalarValue}; +use parquet::file::metadata::ColumnChunkMetaData; +use parquet::schema::types::SchemaDescriptor; +use parquet::{ + arrow::{async_reader::AsyncFileReader, ParquetRecordBatchStreamBuilder}, + bloom_filter::Sbbf, + file::metadata::RowGroupMetaData, }; -use datafusion_common::Column; -use datafusion_common::ScalarValue; -use log::debug; +use std::collections::{HashMap, HashSet}; -use parquet::file::{ - metadata::RowGroupMetaData, statistics::Statistics as ParquetStatistics, -}; - -use crate::datasource::physical_plan::parquet::{ - from_bytes_to_i128, parquet_to_arrow_decimal_type, -}; -use crate::{ - datasource::listing::FileRange, - physical_optimizer::pruning::{PruningPredicate, PruningStatistics}, +use crate::datasource::listing::FileRange; +use crate::datasource::physical_plan::parquet::statistics::{ + max_statistics, min_statistics, parquet_column, }; +use crate::physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; use super::ParquetFileMetrics; +/// Prune row groups based on statistics +/// /// Returns a vector of indexes into `groups` which should be scanned. /// /// If an index is NOT present in the returned Vec it means the @@ -44,7 +45,12 @@ use super::ParquetFileMetrics; /// /// If an index IS present in the returned Vec it means the predicate /// did not filter out that row group. -pub(crate) fn prune_row_groups( +/// +/// Note: This method currently ignores ColumnOrder +/// +pub(crate) fn prune_row_groups_by_statistics( + arrow_schema: &Schema, + parquet_schema: &SchemaDescriptor, groups: &[RowGroupMetaData], range: Option, predicate: Option<&PruningPredicate>, @@ -67,8 +73,9 @@ pub(crate) fn prune_row_groups( if let Some(predicate) = predicate { let pruning_stats = RowGroupPruningStatistics { + parquet_schema, row_group_metadata: metadata, - parquet_schema: predicate.schema().as_ref(), + arrow_schema, }; match predicate.prune(&pruning_stats) { Ok(values) => { @@ -81,7 +88,7 @@ pub(crate) fn prune_row_groups( // stats filter array could not be built // return a closure which will not filter out any row groups Err(e) => { - debug!("Error evaluating row group predicate values {e}"); + log::debug!("Error evaluating row group predicate values {e}"); metrics.predicate_evaluation_errors.add(1); } } @@ -92,146 +99,171 @@ pub(crate) fn prune_row_groups( filtered } -/// Wraps parquet statistics in a way -/// that implements [`PruningStatistics`] -struct RowGroupPruningStatistics<'a> { - row_group_metadata: &'a RowGroupMetaData, - parquet_schema: &'a Schema, -} +/// Prune row groups by bloom filters +/// +/// Returns a vector of indexes into `groups` which should be scanned. +/// +/// If an index is NOT present in the returned Vec it means the +/// predicate filtered all the row group. +/// +/// If an index IS present in the returned Vec it means the predicate +/// did not filter out that row group. +pub(crate) async fn prune_row_groups_by_bloom_filters< + T: AsyncFileReader + Send + 'static, +>( + arrow_schema: &Schema, + builder: &mut ParquetRecordBatchStreamBuilder, + row_groups: &[usize], + groups: &[RowGroupMetaData], + predicate: &PruningPredicate, + metrics: &ParquetFileMetrics, +) -> Vec { + let mut filtered = Vec::with_capacity(groups.len()); + for idx in row_groups { + // get all columns in the predicate that we could use a bloom filter with + let literal_columns = predicate.literal_columns(); + let mut column_sbbf = HashMap::with_capacity(literal_columns.len()); -/// Extract the min/max statistics from a `ParquetStatistics` object -macro_rules! get_statistic { - ($column_statistics:expr, $func:ident, $bytes_func:ident, $target_arrow_type:expr) => {{ - if !$column_statistics.has_min_max_set() { - return None; - } - match $column_statistics { - ParquetStatistics::Boolean(s) => Some(ScalarValue::Boolean(Some(*s.$func()))), - ParquetStatistics::Int32(s) => { - match $target_arrow_type { - // int32 to decimal with the precision and scale - Some(DataType::Decimal128(precision, scale)) => { - Some(ScalarValue::Decimal128( - Some(*s.$func() as i128), - precision, - scale, - )) - } - _ => Some(ScalarValue::Int32(Some(*s.$func()))), - } - } - ParquetStatistics::Int64(s) => { - match $target_arrow_type { - // int64 to decimal with the precision and scale - Some(DataType::Decimal128(precision, scale)) => { - Some(ScalarValue::Decimal128( - Some(*s.$func() as i128), - precision, - scale, - )) - } - _ => Some(ScalarValue::Int64(Some(*s.$func()))), - } - } - // 96 bit ints not supported - ParquetStatistics::Int96(_) => None, - ParquetStatistics::Float(s) => Some(ScalarValue::Float32(Some(*s.$func()))), - ParquetStatistics::Double(s) => Some(ScalarValue::Float64(Some(*s.$func()))), - ParquetStatistics::ByteArray(s) => { - match $target_arrow_type { - // decimal data type - Some(DataType::Decimal128(precision, scale)) => { - Some(ScalarValue::Decimal128( - Some(from_bytes_to_i128(s.$bytes_func())), - precision, - scale, - )) - } - _ => { - let s = std::str::from_utf8(s.$bytes_func()) - .map(|s| s.to_string()) - .ok(); - Some(ScalarValue::Utf8(s)) - } - } - } - // type not supported yet - ParquetStatistics::FixedLenByteArray(s) => { - match $target_arrow_type { - // just support the decimal data type - Some(DataType::Decimal128(precision, scale)) => { - Some(ScalarValue::Decimal128( - Some(from_bytes_to_i128(s.$bytes_func())), - precision, - scale, - )) - } - _ => None, + for column_name in literal_columns { + let Some((column_idx, _field)) = + parquet_column(builder.parquet_schema(), arrow_schema, &column_name) + else { + continue; + }; + + let bf = match builder + .get_row_group_column_bloom_filter(*idx, column_idx) + .await + { + Ok(Some(bf)) => bf, + Ok(None) => continue, // no bloom filter for this column + Err(e) => { + log::debug!("Ignoring error reading bloom filter: {e}"); + metrics.predicate_evaluation_errors.add(1); + continue; } + }; + column_sbbf.insert(column_name.to_string(), bf); + } + + let stats = BloomFilterStatistics { column_sbbf }; + + // Can this group be pruned? + let prune_group = match predicate.prune(&stats) { + Ok(values) => !values[0], + Err(e) => { + log::debug!("Error evaluating row group predicate on bloom filter: {e}"); + metrics.predicate_evaluation_errors.add(1); + false } + }; + + if prune_group { + metrics.row_groups_pruned.add(1); + } else { + filtered.push(*idx); } - }}; + } + filtered } -// Extract the min or max value calling `func` or `bytes_func` on the ParquetStatistics as appropriate -macro_rules! get_min_max_values { - ($self:expr, $column:expr, $func:ident, $bytes_func:ident) => {{ - let (_column_index, field) = - if let Some((v, f)) = $self.parquet_schema.column_with_name(&$column.name) { - (v, f) - } else { - // Named column was not present - return None; - }; +/// Implements `PruningStatistics` for Parquet Split Block Bloom Filters (SBBF) +struct BloomFilterStatistics { + /// Maps column name to the parquet bloom filter + column_sbbf: HashMap, +} - let data_type = field.data_type(); - // The result may be None, because DataFusion doesn't have support for ScalarValues of the column type - let null_scalar: ScalarValue = data_type.try_into().ok()?; +impl PruningStatistics for BloomFilterStatistics { + fn min_values(&self, _column: &Column) -> Option { + None + } - $self.row_group_metadata - .columns() + fn max_values(&self, _column: &Column) -> Option { + None + } + + fn num_containers(&self) -> usize { + 1 + } + + fn null_counts(&self, _column: &Column) -> Option { + None + } + + /// Use bloom filters to determine if we are sure this column can not + /// possibly contain `values` + /// + /// The `contained` API returns false if the bloom filters knows that *ALL* + /// of the values in a column are not present. + fn contained( + &self, + column: &Column, + values: &HashSet, + ) -> Option { + let sbbf = self.column_sbbf.get(column.name.as_str())?; + + // Bloom filters are probabilistic data structures that can return false + // positives (i.e. it might return true even if the value is not + // present) however, the bloom filter will return `false` if the value is + // definitely not present. + + let known_not_present = values .iter() - .find(|c| c.column_descr().name() == &$column.name) - .and_then(|c| if c.statistics().is_some() {Some((c.statistics().unwrap(), c.column_descr()))} else {None}) - .map(|(stats, column_descr)| - { - let target_data_type = parquet_to_arrow_decimal_type(column_descr); - get_statistic!(stats, $func, $bytes_func, target_data_type) - }) - .flatten() - // column either didn't have statistics at all or didn't have min/max values - .or_else(|| Some(null_scalar.clone())) - .map(|s| s.to_array()) - }} + .map(|value| match value { + ScalarValue::Utf8(Some(v)) => sbbf.check(&v.as_str()), + ScalarValue::Boolean(Some(v)) => sbbf.check(v), + ScalarValue::Float64(Some(v)) => sbbf.check(v), + ScalarValue::Float32(Some(v)) => sbbf.check(v), + ScalarValue::Int64(Some(v)) => sbbf.check(v), + ScalarValue::Int32(Some(v)) => sbbf.check(v), + ScalarValue::Int16(Some(v)) => sbbf.check(v), + ScalarValue::Int8(Some(v)) => sbbf.check(v), + _ => true, + }) + // The row group doesn't contain any of the values if + // all the checks are false + .all(|v| !v); + + let contains = if known_not_present { + Some(false) + } else { + // Given the bloom filter is probabilistic, we can't be sure that + // the row group actually contains the values. Return `None` to + // indicate this uncertainty + None + }; + + Some(BooleanArray::from(vec![contains])) + } } -// Extract the null count value on the ParquetStatistics -macro_rules! get_null_count_values { - ($self:expr, $column:expr) => {{ - let value = ScalarValue::UInt64( - if let Some(col) = $self - .row_group_metadata - .columns() - .iter() - .find(|c| c.column_descr().name() == &$column.name) - { - col.statistics().map(|s| s.null_count()) - } else { - Some($self.row_group_metadata.num_rows() as u64) - }, - ); +/// Wraps [`RowGroupMetaData`] in a way that implements [`PruningStatistics`] +/// +/// Note: This should be implemented for an array of [`RowGroupMetaData`] instead +/// of per row-group +struct RowGroupPruningStatistics<'a> { + parquet_schema: &'a SchemaDescriptor, + row_group_metadata: &'a RowGroupMetaData, + arrow_schema: &'a Schema, +} - Some(value.to_array()) - }}; +impl<'a> RowGroupPruningStatistics<'a> { + /// Lookups up the parquet column by name + fn column(&self, name: &str) -> Option<(&ColumnChunkMetaData, &FieldRef)> { + let (idx, field) = parquet_column(self.parquet_schema, self.arrow_schema, name)?; + Some((self.row_group_metadata.column(idx), field)) + } } impl<'a> PruningStatistics for RowGroupPruningStatistics<'a> { fn min_values(&self, column: &Column) -> Option { - get_min_max_values!(self, column, min, min_bytes) + let (column, field) = self.column(&column.name)?; + min_statistics(field.data_type(), std::iter::once(column.statistics())).ok() } fn max_values(&self, column: &Column) -> Option { - get_min_max_values!(self, column, max, max_bytes) + let (column, field) = self.column(&column.name)?; + max_statistics(field.data_type(), std::iter::once(column.statistics())).ok() } fn num_containers(&self) -> usize { @@ -239,21 +271,34 @@ impl<'a> PruningStatistics for RowGroupPruningStatistics<'a> { } fn null_counts(&self, column: &Column) -> Option { - get_null_count_values!(self, column) + let (c, _) = self.column(&column.name)?; + let scalar = ScalarValue::UInt64(Some(c.statistics()?.null_count())); + scalar.to_array().ok() + } + + fn contained( + &self, + _column: &Column, + _values: &HashSet, + ) -> Option { + None } } #[cfg(test)] mod tests { use super::*; + use crate::datasource::physical_plan::parquet::ParquetFileReader; use crate::physical_plan::metrics::ExecutionPlanMetricsSet; use arrow::datatypes::DataType::Decimal128; use arrow::datatypes::Schema; use arrow::datatypes::{DataType, Field}; - use datafusion_common::ToDFSchema; + use datafusion_common::{Result, ToDFSchema}; use datafusion_expr::{cast, col, lit, Expr}; use datafusion_physical_expr::execution_props::ExecutionProps; use datafusion_physical_expr::{create_physical_expr, PhysicalExpr}; + use parquet::arrow::arrow_to_parquet_schema; + use parquet::arrow::async_reader::ParquetObjectReader; use parquet::basic::LogicalType; use parquet::data_type::{ByteArray, FixedLenByteArray}; use parquet::{ @@ -310,11 +355,11 @@ mod tests { fn row_group_pruning_predicate_simple_expr() { use datafusion_expr::{col, lit}; // int > 1 => c1_max > 1 - let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); + let schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); let expr = col("c1").gt(lit(15)); let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let field = PrimitiveTypeField::new("c1", PhysicalType::INT32); let schema_descr = get_test_schema_descr(vec![field]); @@ -329,7 +374,14 @@ mod tests { let metrics = parquet_file_metrics(); assert_eq!( - prune_row_groups(&[rgm1, rgm2], None, Some(&pruning_predicate), &metrics), + prune_row_groups_by_statistics( + &schema, + &schema_descr, + &[rgm1, rgm2], + None, + Some(&pruning_predicate), + &metrics + ), vec![1] ); } @@ -338,11 +390,11 @@ mod tests { fn row_group_pruning_predicate_missing_stats() { use datafusion_expr::{col, lit}; // int > 1 => c1_max > 1 - let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); + let schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); let expr = col("c1").gt(lit(15)); let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let field = PrimitiveTypeField::new("c1", PhysicalType::INT32); let schema_descr = get_test_schema_descr(vec![field]); @@ -358,7 +410,14 @@ mod tests { // missing statistics for first row group mean that the result from the predicate expression // is null / undefined so the first row group can't be filtered out assert_eq!( - prune_row_groups(&[rgm1, rgm2], None, Some(&pruning_predicate), &metrics), + prune_row_groups_by_statistics( + &schema, + &schema_descr, + &[rgm1, rgm2], + None, + Some(&pruning_predicate), + &metrics + ), vec![0, 1] ); } @@ -400,7 +459,14 @@ mod tests { // the first row group is still filtered out because the predicate expression can be partially evaluated // when conditions are joined using AND assert_eq!( - prune_row_groups(groups, None, Some(&pruning_predicate), &metrics), + prune_row_groups_by_statistics( + &schema, + &schema_descr, + groups, + None, + Some(&pruning_predicate), + &metrics + ), vec![1] ); @@ -408,16 +474,81 @@ mod tests { // this bypasses the entire predicate expression and no row groups are filtered out let expr = col("c1").gt(lit(15)).or(col("c2").rem(lit(2)).eq(lit(0))); let expr = logical2physical(&expr, &schema); - let pruning_predicate = PruningPredicate::try_new(expr, schema).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); // if conditions in predicate are joined with OR and an unsupported expression is used // this bypasses the entire predicate expression and no row groups are filtered out assert_eq!( - prune_row_groups(groups, None, Some(&pruning_predicate), &metrics), + prune_row_groups_by_statistics( + &schema, + &schema_descr, + groups, + None, + Some(&pruning_predicate), + &metrics + ), vec![0, 1] ); } + #[test] + fn row_group_pruning_predicate_file_schema() { + use datafusion_expr::{col, lit}; + // test row group predicate when file schema is different than table schema + // c1 > 0 + let table_schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int32, false), + Field::new("c2", DataType::Int32, false), + ])); + let expr = col("c1").gt(lit(0)); + let expr = logical2physical(&expr, &table_schema); + let pruning_predicate = + PruningPredicate::try_new(expr, table_schema.clone()).unwrap(); + + // Model a file schema's column order c2 then c1, which is the opposite + // of the table schema + let file_schema = Arc::new(Schema::new(vec![ + Field::new("c2", DataType::Int32, false), + Field::new("c1", DataType::Int32, false), + ])); + let schema_descr = get_test_schema_descr(vec![ + PrimitiveTypeField::new("c2", PhysicalType::INT32), + PrimitiveTypeField::new("c1", PhysicalType::INT32), + ]); + // rg1 has c2 less than zero, c1 greater than zero + let rgm1 = get_row_group_meta_data( + &schema_descr, + vec![ + ParquetStatistics::int32(Some(-10), Some(-1), None, 0, false), // c2 + ParquetStatistics::int32(Some(1), Some(10), None, 0, false), + ], + ); + // rg1 has c2 greater than zero, c1 less than zero + let rgm2 = get_row_group_meta_data( + &schema_descr, + vec![ + ParquetStatistics::int32(Some(1), Some(10), None, 0, false), + ParquetStatistics::int32(Some(-10), Some(-1), None, 0, false), + ], + ); + + let metrics = parquet_file_metrics(); + let groups = &[rgm1, rgm2]; + // the first row group should be left because c1 is greater than zero + // the second should be filtered out because c1 is less than zero + assert_eq!( + prune_row_groups_by_statistics( + &file_schema, // NB must be file schema, not table_schema + &schema_descr, + groups, + None, + Some(&pruning_predicate), + &metrics + ), + vec![0] + ); + } + fn gen_row_group_meta_data_for_pruning_predicate() -> Vec { let schema_descr = get_test_schema_descr(vec![ PrimitiveTypeField::new("c1", PhysicalType::INT32), @@ -448,15 +579,23 @@ mod tests { Field::new("c1", DataType::Int32, false), Field::new("c2", DataType::Boolean, false), ])); + let schema_descr = arrow_to_parquet_schema(&schema).unwrap(); let expr = col("c1").gt(lit(15)).and(col("c2").is_null()); let expr = logical2physical(&expr, &schema); - let pruning_predicate = PruningPredicate::try_new(expr, schema).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let groups = gen_row_group_meta_data_for_pruning_predicate(); let metrics = parquet_file_metrics(); // First row group was filtered out because it contains no null value on "c2". assert_eq!( - prune_row_groups(&groups, None, Some(&pruning_predicate), &metrics), + prune_row_groups_by_statistics( + &schema, + &schema_descr, + &groups, + None, + Some(&pruning_predicate), + &metrics + ), vec![1] ); } @@ -471,18 +610,26 @@ mod tests { Field::new("c1", DataType::Int32, false), Field::new("c2", DataType::Boolean, false), ])); + let schema_descr = arrow_to_parquet_schema(&schema).unwrap(); let expr = col("c1") .gt(lit(15)) .and(col("c2").eq(lit(ScalarValue::Boolean(None)))); let expr = logical2physical(&expr, &schema); - let pruning_predicate = PruningPredicate::try_new(expr, schema).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let groups = gen_row_group_meta_data_for_pruning_predicate(); let metrics = parquet_file_metrics(); // bool = NULL always evaluates to NULL (and thus will not // pass predicates. Ideally these should both be false assert_eq!( - prune_row_groups(&groups, None, Some(&pruning_predicate), &metrics), + prune_row_groups_by_statistics( + &schema, + &schema_descr, + &groups, + None, + Some(&pruning_predicate), + &metrics + ), vec![1] ); } @@ -495,8 +642,11 @@ mod tests { // INT32: c1 > 5, the c1 is decimal(9,2) // The type of scalar value if decimal(9,2), don't need to do cast - let schema = - Schema::new(vec![Field::new("c1", DataType::Decimal128(9, 2), false)]); + let schema = Arc::new(Schema::new(vec![Field::new( + "c1", + DataType::Decimal128(9, 2), + false, + )])); let field = PrimitiveTypeField::new("c1", PhysicalType::INT32) .with_logical_type(LogicalType::Decimal { scale: 2, @@ -507,8 +657,7 @@ mod tests { let schema_descr = get_test_schema_descr(vec![field]); let expr = col("c1").gt(lit(ScalarValue::Decimal128(Some(500), 9, 2))); let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let rgm1 = get_row_group_meta_data( &schema_descr, // [1.00, 6.00] @@ -535,7 +684,9 @@ mod tests { ); let metrics = parquet_file_metrics(); assert_eq!( - prune_row_groups( + prune_row_groups_by_statistics( + &schema, + &schema_descr, &[rgm1, rgm2, rgm3], None, Some(&pruning_predicate), @@ -548,8 +699,11 @@ mod tests { // The c1 type is decimal(9,0) in the parquet file, and the type of scalar is decimal(5,2). // We should convert all type to the coercion type, which is decimal(11,2) // The decimal of arrow is decimal(5,2), the decimal of parquet is decimal(9,0) - let schema = - Schema::new(vec![Field::new("c1", DataType::Decimal128(9, 0), false)]); + let schema = Arc::new(Schema::new(vec![Field::new( + "c1", + DataType::Decimal128(9, 0), + false, + )])); let field = PrimitiveTypeField::new("c1", PhysicalType::INT32) .with_logical_type(LogicalType::Decimal { @@ -564,8 +718,7 @@ mod tests { Decimal128(11, 2), )); let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let rgm1 = get_row_group_meta_data( &schema_descr, // [100, 600] @@ -598,7 +751,9 @@ mod tests { ); let metrics = parquet_file_metrics(); assert_eq!( - prune_row_groups( + prune_row_groups_by_statistics( + &schema, + &schema_descr, &[rgm1, rgm2, rgm3, rgm4], None, Some(&pruning_predicate), @@ -608,8 +763,11 @@ mod tests { ); // INT64: c1 < 5, the c1 is decimal(18,2) - let schema = - Schema::new(vec![Field::new("c1", DataType::Decimal128(18, 2), false)]); + let schema = Arc::new(Schema::new(vec![Field::new( + "c1", + DataType::Decimal128(18, 2), + false, + )])); let field = PrimitiveTypeField::new("c1", PhysicalType::INT64) .with_logical_type(LogicalType::Decimal { scale: 2, @@ -620,8 +778,7 @@ mod tests { let schema_descr = get_test_schema_descr(vec![field]); let expr = col("c1").lt(lit(ScalarValue::Decimal128(Some(500), 18, 2))); let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let rgm1 = get_row_group_meta_data( &schema_descr, // [6.00, 8.00] @@ -645,7 +802,9 @@ mod tests { ); let metrics = parquet_file_metrics(); assert_eq!( - prune_row_groups( + prune_row_groups_by_statistics( + &schema, + &schema_descr, &[rgm1, rgm2, rgm3], None, Some(&pruning_predicate), @@ -656,8 +815,11 @@ mod tests { // FIXED_LENGTH_BYTE_ARRAY: c1 = decimal128(100000, 28, 3), the c1 is decimal(18,2) // the type of parquet is decimal(18,2) - let schema = - Schema::new(vec![Field::new("c1", DataType::Decimal128(18, 2), false)]); + let schema = Arc::new(Schema::new(vec![Field::new( + "c1", + DataType::Decimal128(18, 2), + false, + )])); let field = PrimitiveTypeField::new("c1", PhysicalType::FIXED_LEN_BYTE_ARRAY) .with_logical_type(LogicalType::Decimal { scale: 2, @@ -671,8 +833,7 @@ mod tests { let left = cast(col("c1"), DataType::Decimal128(28, 3)); let expr = left.eq(lit(ScalarValue::Decimal128(Some(100000), 28, 3))); let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); // we must use the big-endian when encode the i128 to bytes or vec[u8]. let rgm1 = get_row_group_meta_data( &schema_descr, @@ -715,7 +876,9 @@ mod tests { ); let metrics = parquet_file_metrics(); assert_eq!( - prune_row_groups( + prune_row_groups_by_statistics( + &schema, + &schema_descr, &[rgm1, rgm2, rgm3], None, Some(&pruning_predicate), @@ -726,8 +889,11 @@ mod tests { // BYTE_ARRAY: c1 = decimal128(100000, 28, 3), the c1 is decimal(18,2) // the type of parquet is decimal(18,2) - let schema = - Schema::new(vec![Field::new("c1", DataType::Decimal128(18, 2), false)]); + let schema = Arc::new(Schema::new(vec![Field::new( + "c1", + DataType::Decimal128(18, 2), + false, + )])); let field = PrimitiveTypeField::new("c1", PhysicalType::BYTE_ARRAY) .with_logical_type(LogicalType::Decimal { scale: 2, @@ -741,8 +907,7 @@ mod tests { let left = cast(col("c1"), DataType::Decimal128(28, 3)); let expr = left.eq(lit(ScalarValue::Decimal128(Some(100000), 28, 3))); let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); // we must use the big-endian when encode the i128 to bytes or vec[u8]. let rgm1 = get_row_group_meta_data( &schema_descr, @@ -774,7 +939,9 @@ mod tests { ); let metrics = parquet_file_metrics(); assert_eq!( - prune_row_groups( + prune_row_groups_by_statistics( + &schema, + &schema_descr, &[rgm1, rgm2, rgm3], None, Some(&pruning_predicate), @@ -788,7 +955,6 @@ mod tests { schema_descr: &SchemaDescPtr, column_statistics: Vec, ) -> RowGroupMetaData { - use parquet::file::metadata::ColumnChunkMetaData; let mut columns = vec![]; for (i, s) in column_statistics.iter().enumerate() { let column = ColumnChunkMetaData::builder(schema_descr.column(i)) @@ -806,7 +972,7 @@ mod tests { } fn get_test_schema_descr(fields: Vec) -> SchemaDescPtr { - use parquet::schema::types::{SchemaDescriptor, Type as SchemaType}; + use parquet::schema::types::Type as SchemaType; let schema_fields = fields .iter() .map(|field| { @@ -846,4 +1012,267 @@ mod tests { let execution_props = ExecutionProps::new(); create_physical_expr(expr, &df_schema, schema, &execution_props).unwrap() } + + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_simple_expr() { + BloomFilterTest::new_data_index_bloom_encoding_stats() + .with_expect_all_pruned() + // generate pruning predicate `(String = "Hello_Not_exists")` + .run(col(r#""String""#).eq(lit("Hello_Not_Exists"))) + .await + } + + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_mutiple_expr() { + BloomFilterTest::new_data_index_bloom_encoding_stats() + .with_expect_all_pruned() + // generate pruning predicate `(String = "Hello_Not_exists" OR String = "Hello_Not_exists2")` + .run( + lit("1").eq(lit("1")).and( + col(r#""String""#) + .eq(lit("Hello_Not_Exists")) + .or(col(r#""String""#).eq(lit("Hello_Not_Exists2"))), + ), + ) + .await + } + + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_sql_in() { + // load parquet file + let testdata = datafusion_common::test_util::parquet_test_data(); + let file_name = "data_index_bloom_encoding_stats.parquet"; + let path = format!("{testdata}/{file_name}"); + let data = bytes::Bytes::from(std::fs::read(path).unwrap()); + + // generate pruning predicate + let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); + + let expr = col(r#""String""#).in_list( + vec![ + lit("Hello_Not_Exists"), + lit("Hello_Not_Exists2"), + lit("Hello_Not_Exists3"), + lit("Hello_Not_Exist4"), + ], + false, + ); + let expr = logical2physical(&expr, &schema); + let pruning_predicate = + PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + + let row_groups = vec![0]; + let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( + file_name, + data, + &pruning_predicate, + &row_groups, + ) + .await + .unwrap(); + assert!(pruned_row_groups.is_empty()); + } + + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_with_exists_value() { + BloomFilterTest::new_data_index_bloom_encoding_stats() + .with_expect_none_pruned() + // generate pruning predicate `(String = "Hello")` + .run(col(r#""String""#).eq(lit("Hello"))) + .await + } + + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_with_exists_2_values() { + BloomFilterTest::new_data_index_bloom_encoding_stats() + .with_expect_none_pruned() + // generate pruning predicate `(String = "Hello") OR (String = "the quick")` + .run( + col(r#""String""#) + .eq(lit("Hello")) + .or(col(r#""String""#).eq(lit("the quick"))), + ) + .await + } + + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_with_exists_3_values() { + BloomFilterTest::new_data_index_bloom_encoding_stats() + .with_expect_none_pruned() + // generate pruning predicate `(String = "Hello") OR (String = "the quick") OR (String = "are you")` + .run( + col(r#""String""#) + .eq(lit("Hello")) + .or(col(r#""String""#).eq(lit("the quick"))) + .or(col(r#""String""#).eq(lit("are you"))), + ) + .await + } + + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_with_or_not_eq() { + BloomFilterTest::new_data_index_bloom_encoding_stats() + .with_expect_none_pruned() + // generate pruning predicate `(String = "foo") OR (String != "bar")` + .run( + col(r#""String""#) + .not_eq(lit("foo")) + .or(col(r#""String""#).not_eq(lit("bar"))), + ) + .await + } + + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_without_bloom_filter() { + // generate pruning predicate on a column without a bloom filter + BloomFilterTest::new_all_types() + .with_expect_none_pruned() + .run(col(r#""string_col""#).eq(lit("0"))) + .await + } + + struct BloomFilterTest { + file_name: String, + schema: Schema, + // which row groups should be attempted to prune + row_groups: Vec, + // which row groups are expected to be left after pruning. Must be set + // otherwise will panic on run() + post_pruning_row_groups: Option>, + } + + impl BloomFilterTest { + /// Return a test for data_index_bloom_encoding_stats.parquet + /// Note the values in the `String` column are: + /// ```sql + /// ❯ select * from './parquet-testing/data/data_index_bloom_encoding_stats.parquet'; + /// +-----------+ + /// | String | + /// +-----------+ + /// | Hello | + /// | This is | + /// | a | + /// | test | + /// | How | + /// | are you | + /// | doing | + /// | today | + /// | the quick | + /// | brown fox | + /// | jumps | + /// | over | + /// | the lazy | + /// | dog | + /// +-----------+ + /// ``` + fn new_data_index_bloom_encoding_stats() -> Self { + Self { + file_name: String::from("data_index_bloom_encoding_stats.parquet"), + schema: Schema::new(vec![Field::new("String", DataType::Utf8, false)]), + row_groups: vec![0], + post_pruning_row_groups: None, + } + } + + // Return a test for alltypes_plain.parquet + fn new_all_types() -> Self { + Self { + file_name: String::from("alltypes_plain.parquet"), + schema: Schema::new(vec![Field::new( + "string_col", + DataType::Utf8, + false, + )]), + row_groups: vec![0], + post_pruning_row_groups: None, + } + } + + /// Expect all row groups to be pruned + pub fn with_expect_all_pruned(mut self) -> Self { + self.post_pruning_row_groups = Some(vec![]); + self + } + + /// Expect all row groups not to be pruned + pub fn with_expect_none_pruned(mut self) -> Self { + self.post_pruning_row_groups = Some(self.row_groups.clone()); + self + } + + /// Prune this file using the specified expression and check that the expected row groups are left + async fn run(self, expr: Expr) { + let Self { + file_name, + schema, + row_groups, + post_pruning_row_groups, + } = self; + + let post_pruning_row_groups = + post_pruning_row_groups.expect("post_pruning_row_groups must be set"); + + let testdata = datafusion_common::test_util::parquet_test_data(); + let path = format!("{testdata}/{file_name}"); + let data = bytes::Bytes::from(std::fs::read(path).unwrap()); + + let expr = logical2physical(&expr, &schema); + let pruning_predicate = + PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + + let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( + &file_name, + data, + &pruning_predicate, + &row_groups, + ) + .await + .unwrap(); + assert_eq!(pruned_row_groups, post_pruning_row_groups); + } + } + + async fn test_row_group_bloom_filter_pruning_predicate( + file_name: &str, + data: bytes::Bytes, + pruning_predicate: &PruningPredicate, + row_groups: &[usize], + ) -> Result> { + use object_store::{ObjectMeta, ObjectStore}; + + let object_meta = ObjectMeta { + location: object_store::path::Path::parse(file_name).expect("creating path"), + last_modified: chrono::DateTime::from(std::time::SystemTime::now()), + size: data.len(), + e_tag: None, + version: None, + }; + let in_memory = object_store::memory::InMemory::new(); + in_memory + .put(&object_meta.location, data) + .await + .expect("put parquet file into in memory object store"); + + let metrics = ExecutionPlanMetricsSet::new(); + let file_metrics = + ParquetFileMetrics::new(0, object_meta.location.as_ref(), &metrics); + let reader = ParquetFileReader { + inner: ParquetObjectReader::new(Arc::new(in_memory), object_meta), + file_metrics: file_metrics.clone(), + }; + let mut builder = ParquetRecordBatchStreamBuilder::new(reader).await.unwrap(); + + let metadata = builder.metadata().clone(); + let pruned_row_group = prune_row_groups_by_bloom_filters( + pruning_predicate.schema(), + &mut builder, + row_groups, + metadata.row_groups(), + pruning_predicate, + &file_metrics, + ) + .await; + + Ok(pruned_row_group) + } } diff --git a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs new file mode 100644 index 000000000000..4e472606da51 --- /dev/null +++ b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs @@ -0,0 +1,899 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`min_statistics`] and [`max_statistics`] convert statistics in parquet format to arrow [`ArrayRef`]. + +// TODO: potentially move this to arrow-rs: https://github.com/apache/arrow-rs/issues/4328 + +use arrow::{array::ArrayRef, datatypes::DataType}; +use arrow_array::new_empty_array; +use arrow_schema::{FieldRef, Schema}; +use datafusion_common::{Result, ScalarValue}; +use parquet::file::statistics::Statistics as ParquetStatistics; +use parquet::schema::types::SchemaDescriptor; + +// Convert the bytes array to i128. +// The endian of the input bytes array must be big-endian. +pub(crate) fn from_bytes_to_i128(b: &[u8]) -> i128 { + // The bytes array are from parquet file and must be the big-endian. + // The endian is defined by parquet format, and the reference document + // https://github.com/apache/parquet-format/blob/54e53e5d7794d383529dd30746378f19a12afd58/src/main/thrift/parquet.thrift#L66 + i128::from_be_bytes(sign_extend_be(b)) +} + +// Copy from arrow-rs +// https://github.com/apache/arrow-rs/blob/733b7e7fd1e8c43a404c3ce40ecf741d493c21b4/parquet/src/arrow/buffer/bit_util.rs#L55 +// Convert the byte slice to fixed length byte array with the length of 16 +fn sign_extend_be(b: &[u8]) -> [u8; 16] { + assert!(b.len() <= 16, "Array too large, expected less than 16"); + let is_negative = (b[0] & 128u8) == 128u8; + let mut result = if is_negative { [255u8; 16] } else { [0u8; 16] }; + for (d, s) in result.iter_mut().skip(16 - b.len()).zip(b) { + *d = *s; + } + result +} + +/// Extract a single min/max statistics from a [`ParquetStatistics`] object +/// +/// * `$column_statistics` is the `ParquetStatistics` object +/// * `$func is the function` (`min`/`max`) to call to get the value +/// * `$bytes_func` is the function (`min_bytes`/`max_bytes`) to call to get the value as bytes +/// * `$target_arrow_type` is the [`DataType`] of the target statistics +macro_rules! get_statistic { + ($column_statistics:expr, $func:ident, $bytes_func:ident, $target_arrow_type:expr) => {{ + if !$column_statistics.has_min_max_set() { + return None; + } + match $column_statistics { + ParquetStatistics::Boolean(s) => Some(ScalarValue::Boolean(Some(*s.$func()))), + ParquetStatistics::Int32(s) => { + match $target_arrow_type { + // int32 to decimal with the precision and scale + Some(DataType::Decimal128(precision, scale)) => { + Some(ScalarValue::Decimal128( + Some(*s.$func() as i128), + *precision, + *scale, + )) + } + _ => Some(ScalarValue::Int32(Some(*s.$func()))), + } + } + ParquetStatistics::Int64(s) => { + match $target_arrow_type { + // int64 to decimal with the precision and scale + Some(DataType::Decimal128(precision, scale)) => { + Some(ScalarValue::Decimal128( + Some(*s.$func() as i128), + *precision, + *scale, + )) + } + _ => Some(ScalarValue::Int64(Some(*s.$func()))), + } + } + // 96 bit ints not supported + ParquetStatistics::Int96(_) => None, + ParquetStatistics::Float(s) => Some(ScalarValue::Float32(Some(*s.$func()))), + ParquetStatistics::Double(s) => Some(ScalarValue::Float64(Some(*s.$func()))), + ParquetStatistics::ByteArray(s) => { + match $target_arrow_type { + // decimal data type + Some(DataType::Decimal128(precision, scale)) => { + Some(ScalarValue::Decimal128( + Some(from_bytes_to_i128(s.$bytes_func())), + *precision, + *scale, + )) + } + _ => { + let s = std::str::from_utf8(s.$bytes_func()) + .map(|s| s.to_string()) + .ok(); + Some(ScalarValue::Utf8(s)) + } + } + } + // type not supported yet + ParquetStatistics::FixedLenByteArray(s) => { + match $target_arrow_type { + // just support the decimal data type + Some(DataType::Decimal128(precision, scale)) => { + Some(ScalarValue::Decimal128( + Some(from_bytes_to_i128(s.$bytes_func())), + *precision, + *scale, + )) + } + _ => None, + } + } + } + }}; +} + +/// Lookups up the parquet column by name +/// +/// Returns the parquet column index and the corresponding arrow field +pub(crate) fn parquet_column<'a>( + parquet_schema: &SchemaDescriptor, + arrow_schema: &'a Schema, + name: &str, +) -> Option<(usize, &'a FieldRef)> { + let (root_idx, field) = arrow_schema.fields.find(name)?; + if field.data_type().is_nested() { + // Nested fields are not supported and require non-trivial logic + // to correctly walk the parquet schema accounting for the + // logical type rules - + // + // For example a ListArray could correspond to anything from 1 to 3 levels + // in the parquet schema + return None; + } + + // This could be made more efficient (#TBD) + let parquet_idx = (0..parquet_schema.columns().len()) + .find(|x| parquet_schema.get_column_root_idx(*x) == root_idx)?; + Some((parquet_idx, field)) +} + +/// Extracts the min statistics from an iterator of [`ParquetStatistics`] to an [`ArrayRef`] +pub(crate) fn min_statistics<'a, I: Iterator>>( + data_type: &DataType, + iterator: I, +) -> Result { + let scalars = iterator + .map(|x| x.and_then(|s| get_statistic!(s, min, min_bytes, Some(data_type)))); + collect_scalars(data_type, scalars) +} + +/// Extracts the max statistics from an iterator of [`ParquetStatistics`] to an [`ArrayRef`] +pub(crate) fn max_statistics<'a, I: Iterator>>( + data_type: &DataType, + iterator: I, +) -> Result { + let scalars = iterator + .map(|x| x.and_then(|s| get_statistic!(s, max, max_bytes, Some(data_type)))); + collect_scalars(data_type, scalars) +} + +/// Builds an array from an iterator of ScalarValue +fn collect_scalars>>( + data_type: &DataType, + iterator: I, +) -> Result { + let mut scalars = iterator.peekable(); + match scalars.peek().is_none() { + true => Ok(new_empty_array(data_type)), + false => { + let null = ScalarValue::try_from(data_type)?; + ScalarValue::iter_to_array(scalars.map(|x| x.unwrap_or_else(|| null.clone()))) + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use arrow_array::{ + new_null_array, Array, BinaryArray, BooleanArray, Decimal128Array, Float32Array, + Float64Array, Int32Array, Int64Array, RecordBatch, StringArray, StructArray, + TimestampNanosecondArray, + }; + use arrow_schema::{Field, SchemaRef}; + use bytes::Bytes; + use datafusion_common::test_util::parquet_test_data; + use parquet::arrow::arrow_reader::ArrowReaderBuilder; + use parquet::arrow::arrow_writer::ArrowWriter; + use parquet::file::metadata::{ParquetMetaData, RowGroupMetaData}; + use parquet::file::properties::{EnabledStatistics, WriterProperties}; + use std::path::PathBuf; + use std::sync::Arc; + + // TODO error cases (with parquet statistics that are mismatched in expected type) + + #[test] + fn roundtrip_empty() { + let empty_bool_array = new_empty_array(&DataType::Boolean); + Test { + input: empty_bool_array.clone(), + expected_min: empty_bool_array.clone(), + expected_max: empty_bool_array.clone(), + } + .run() + } + + #[test] + fn roundtrip_bool() { + Test { + input: bool_array([ + // row group 1 + Some(true), + None, + Some(true), + // row group 2 + Some(true), + Some(false), + None, + // row group 3 + None, + None, + None, + ]), + expected_min: bool_array([Some(true), Some(false), None]), + expected_max: bool_array([Some(true), Some(true), None]), + } + .run() + } + + #[test] + fn roundtrip_int32() { + Test { + input: i32_array([ + // row group 1 + Some(1), + None, + Some(3), + // row group 2 + Some(0), + Some(5), + None, + // row group 3 + None, + None, + None, + ]), + expected_min: i32_array([Some(1), Some(0), None]), + expected_max: i32_array([Some(3), Some(5), None]), + } + .run() + } + + #[test] + fn roundtrip_int64() { + Test { + input: i64_array([ + // row group 1 + Some(1), + None, + Some(3), + // row group 2 + Some(0), + Some(5), + None, + // row group 3 + None, + None, + None, + ]), + expected_min: i64_array([Some(1), Some(0), None]), + expected_max: i64_array(vec![Some(3), Some(5), None]), + } + .run() + } + + #[test] + fn roundtrip_f32() { + Test { + input: f32_array([ + // row group 1 + Some(1.0), + None, + Some(3.0), + // row group 2 + Some(-1.0), + Some(5.0), + None, + // row group 3 + None, + None, + None, + ]), + expected_min: f32_array([Some(1.0), Some(-1.0), None]), + expected_max: f32_array([Some(3.0), Some(5.0), None]), + } + .run() + } + + #[test] + fn roundtrip_f64() { + Test { + input: f64_array([ + // row group 1 + Some(1.0), + None, + Some(3.0), + // row group 2 + Some(-1.0), + Some(5.0), + None, + // row group 3 + None, + None, + None, + ]), + expected_min: f64_array([Some(1.0), Some(-1.0), None]), + expected_max: f64_array([Some(3.0), Some(5.0), None]), + } + .run() + } + + #[test] + #[should_panic( + expected = "Inconsistent types in ScalarValue::iter_to_array. Expected Int64, got TimestampNanosecond(NULL, None)" + )] + // Due to https://github.com/apache/arrow-datafusion/issues/8295 + fn roundtrip_timestamp() { + Test { + input: timestamp_array([ + // row group 1 + Some(1), + None, + Some(3), + // row group 2 + Some(9), + Some(5), + None, + // row group 3 + None, + None, + None, + ]), + expected_min: timestamp_array([Some(1), Some(5), None]), + expected_max: timestamp_array([Some(3), Some(9), None]), + } + .run() + } + + #[test] + fn roundtrip_decimal() { + Test { + input: Arc::new( + Decimal128Array::from(vec![ + // row group 1 + Some(100), + None, + Some(22000), + // row group 2 + Some(500000), + Some(330000), + None, + // row group 3 + None, + None, + None, + ]) + .with_precision_and_scale(9, 2) + .unwrap(), + ), + expected_min: Arc::new( + Decimal128Array::from(vec![Some(100), Some(330000), None]) + .with_precision_and_scale(9, 2) + .unwrap(), + ), + expected_max: Arc::new( + Decimal128Array::from(vec![Some(22000), Some(500000), None]) + .with_precision_and_scale(9, 2) + .unwrap(), + ), + } + .run() + } + + #[test] + fn roundtrip_utf8() { + Test { + input: utf8_array([ + // row group 1 + Some("A"), + None, + Some("Q"), + // row group 2 + Some("ZZ"), + Some("AA"), + None, + // row group 3 + None, + None, + None, + ]), + expected_min: utf8_array([Some("A"), Some("AA"), None]), + expected_max: utf8_array([Some("Q"), Some("ZZ"), None]), + } + .run() + } + + #[test] + fn roundtrip_struct() { + let mut test = Test { + input: struct_array(vec![ + // row group 1 + (Some(true), Some(1)), + (None, None), + (Some(true), Some(3)), + // row group 2 + (Some(true), Some(0)), + (Some(false), Some(5)), + (None, None), + // row group 3 + (None, None), + (None, None), + (None, None), + ]), + expected_min: struct_array(vec![ + (Some(true), Some(1)), + (Some(true), Some(0)), + (None, None), + ]), + + expected_max: struct_array(vec![ + (Some(true), Some(3)), + (Some(true), Some(0)), + (None, None), + ]), + }; + // Due to https://github.com/apache/arrow-datafusion/issues/8334, + // statistics for struct arrays are not supported + test.expected_min = + new_null_array(test.input.data_type(), test.expected_min.len()); + test.expected_max = + new_null_array(test.input.data_type(), test.expected_min.len()); + test.run() + } + + #[test] + #[should_panic( + expected = "Inconsistent types in ScalarValue::iter_to_array. Expected Utf8, got Binary(NULL)" + )] + // Due to https://github.com/apache/arrow-datafusion/issues/8295 + fn roundtrip_binary() { + Test { + input: Arc::new(BinaryArray::from_opt_vec(vec![ + // row group 1 + Some(b"A"), + None, + Some(b"Q"), + // row group 2 + Some(b"ZZ"), + Some(b"AA"), + None, + // row group 3 + None, + None, + None, + ])), + expected_min: Arc::new(BinaryArray::from_opt_vec(vec![ + Some(b"A"), + Some(b"AA"), + None, + ])), + expected_max: Arc::new(BinaryArray::from_opt_vec(vec![ + Some(b"Q"), + Some(b"ZZ"), + None, + ])), + } + .run() + } + + #[test] + fn struct_and_non_struct() { + // Ensures that statistics for an array that appears *after* a struct + // array are not wrong + let struct_col = struct_array(vec![ + // row group 1 + (Some(true), Some(1)), + (None, None), + (Some(true), Some(3)), + ]); + let int_col = i32_array([Some(100), Some(200), Some(300)]); + let expected_min = i32_array([Some(100)]); + let expected_max = i32_array(vec![Some(300)]); + + // use a name that shadows a name in the struct column + match struct_col.data_type() { + DataType::Struct(fields) => { + assert_eq!(fields.get(1).unwrap().name(), "int_col") + } + _ => panic!("unexpected data type for struct column"), + }; + + let input_batch = RecordBatch::try_from_iter([ + ("struct_col", struct_col), + ("int_col", int_col), + ]) + .unwrap(); + + let schema = input_batch.schema(); + + let metadata = parquet_metadata(schema.clone(), input_batch); + let parquet_schema = metadata.file_metadata().schema_descr(); + + // read the int_col statistics + let (idx, _) = parquet_column(parquet_schema, &schema, "int_col").unwrap(); + assert_eq!(idx, 2); + + let row_groups = metadata.row_groups(); + let iter = row_groups.iter().map(|x| x.column(idx).statistics()); + + let min = min_statistics(&DataType::Int32, iter.clone()).unwrap(); + assert_eq!( + &min, + &expected_min, + "Min. Statistics\n\n{}\n\n", + DisplayStats(row_groups) + ); + + let max = max_statistics(&DataType::Int32, iter).unwrap(); + assert_eq!( + &max, + &expected_max, + "Max. Statistics\n\n{}\n\n", + DisplayStats(row_groups) + ); + } + + #[test] + fn nan_in_stats() { + // /parquet-testing/data/nan_in_stats.parquet + // row_groups: 1 + // "x": Double({min: Some(1.0), max: Some(NaN), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + + TestFile::new("nan_in_stats.parquet") + .with_column(ExpectedColumn { + name: "x", + expected_min: Arc::new(Float64Array::from(vec![Some(1.0)])), + expected_max: Arc::new(Float64Array::from(vec![Some(f64::NAN)])), + }) + .run(); + } + + #[test] + fn alltypes_plain() { + // /parquet-testing/data/datapage_v1-snappy-compressed-checksum.parquet + // row_groups: 1 + // (has no statistics) + TestFile::new("alltypes_plain.parquet") + // No column statistics should be read as NULL, but with the right type + .with_column(ExpectedColumn { + name: "id", + expected_min: i32_array([None]), + expected_max: i32_array([None]), + }) + .with_column(ExpectedColumn { + name: "bool_col", + expected_min: bool_array([None]), + expected_max: bool_array([None]), + }) + .run(); + } + + #[test] + fn alltypes_tiny_pages() { + // /parquet-testing/data/alltypes_tiny_pages.parquet + // row_groups: 1 + // "id": Int32({min: Some(0), max: Some(7299), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "bool_col": Boolean({min: Some(false), max: Some(true), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "tinyint_col": Int32({min: Some(0), max: Some(9), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "smallint_col": Int32({min: Some(0), max: Some(9), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "int_col": Int32({min: Some(0), max: Some(9), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "bigint_col": Int64({min: Some(0), max: Some(90), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "float_col": Float({min: Some(0.0), max: Some(9.9), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "double_col": Double({min: Some(0.0), max: Some(90.89999999999999), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "date_string_col": ByteArray({min: Some(ByteArray { data: "01/01/09" }), max: Some(ByteArray { data: "12/31/10" }), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "string_col": ByteArray({min: Some(ByteArray { data: "0" }), max: Some(ByteArray { data: "9" }), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "timestamp_col": Int96({min: None, max: None, distinct_count: None, null_count: 0, min_max_deprecated: true, min_max_backwards_compatible: true}) + // "year": Int32({min: Some(2009), max: Some(2010), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "month": Int32({min: Some(1), max: Some(12), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + TestFile::new("alltypes_tiny_pages.parquet") + .with_column(ExpectedColumn { + name: "id", + expected_min: i32_array([Some(0)]), + expected_max: i32_array([Some(7299)]), + }) + .with_column(ExpectedColumn { + name: "bool_col", + expected_min: bool_array([Some(false)]), + expected_max: bool_array([Some(true)]), + }) + .with_column(ExpectedColumn { + name: "tinyint_col", + expected_min: i32_array([Some(0)]), + expected_max: i32_array([Some(9)]), + }) + .with_column(ExpectedColumn { + name: "smallint_col", + expected_min: i32_array([Some(0)]), + expected_max: i32_array([Some(9)]), + }) + .with_column(ExpectedColumn { + name: "int_col", + expected_min: i32_array([Some(0)]), + expected_max: i32_array([Some(9)]), + }) + .with_column(ExpectedColumn { + name: "bigint_col", + expected_min: i64_array([Some(0)]), + expected_max: i64_array([Some(90)]), + }) + .with_column(ExpectedColumn { + name: "float_col", + expected_min: f32_array([Some(0.0)]), + expected_max: f32_array([Some(9.9)]), + }) + .with_column(ExpectedColumn { + name: "double_col", + expected_min: f64_array([Some(0.0)]), + expected_max: f64_array([Some(90.89999999999999)]), + }) + .with_column(ExpectedColumn { + name: "date_string_col", + expected_min: utf8_array([Some("01/01/09")]), + expected_max: utf8_array([Some("12/31/10")]), + }) + .with_column(ExpectedColumn { + name: "string_col", + expected_min: utf8_array([Some("0")]), + expected_max: utf8_array([Some("9")]), + }) + // File has no min/max for timestamp_col + .with_column(ExpectedColumn { + name: "timestamp_col", + expected_min: timestamp_array([None]), + expected_max: timestamp_array([None]), + }) + .with_column(ExpectedColumn { + name: "year", + expected_min: i32_array([Some(2009)]), + expected_max: i32_array([Some(2010)]), + }) + .with_column(ExpectedColumn { + name: "month", + expected_min: i32_array([Some(1)]), + expected_max: i32_array([Some(12)]), + }) + .run(); + } + + #[test] + fn fixed_length_decimal_legacy() { + // /parquet-testing/data/fixed_length_decimal_legacy.parquet + // row_groups: 1 + // "value": FixedLenByteArray({min: Some(FixedLenByteArray(ByteArray { data: Some(ByteBufferPtr { data: b"\0\0\0\0\0\xc8" }) })), max: Some(FixedLenByteArray(ByteArray { data: "\0\0\0\0\t`" })), distinct_count: None, null_count: 0, min_max_deprecated: true, min_max_backwards_compatible: true}) + + TestFile::new("fixed_length_decimal_legacy.parquet") + .with_column(ExpectedColumn { + name: "value", + expected_min: Arc::new( + Decimal128Array::from(vec![Some(200)]) + .with_precision_and_scale(13, 2) + .unwrap(), + ), + expected_max: Arc::new( + Decimal128Array::from(vec![Some(2400)]) + .with_precision_and_scale(13, 2) + .unwrap(), + ), + }) + .run(); + } + + const ROWS_PER_ROW_GROUP: usize = 3; + + /// Writes the input batch into a parquet file, with every every three rows as + /// their own row group, and compares the min/maxes to the expected values + struct Test { + input: ArrayRef, + expected_min: ArrayRef, + expected_max: ArrayRef, + } + + impl Test { + fn run(self) { + let Self { + input, + expected_min, + expected_max, + } = self; + + let input_batch = RecordBatch::try_from_iter([("c1", input)]).unwrap(); + + let schema = input_batch.schema(); + + let metadata = parquet_metadata(schema.clone(), input_batch); + let parquet_schema = metadata.file_metadata().schema_descr(); + + let row_groups = metadata.row_groups(); + + for field in schema.fields() { + if field.data_type().is_nested() { + let lookup = parquet_column(parquet_schema, &schema, field.name()); + assert_eq!(lookup, None); + continue; + } + + let (idx, f) = + parquet_column(parquet_schema, &schema, field.name()).unwrap(); + assert_eq!(f, field); + + let iter = row_groups.iter().map(|x| x.column(idx).statistics()); + let min = min_statistics(f.data_type(), iter.clone()).unwrap(); + assert_eq!( + &min, + &expected_min, + "Min. Statistics\n\n{}\n\n", + DisplayStats(row_groups) + ); + + let max = max_statistics(f.data_type(), iter).unwrap(); + assert_eq!( + &max, + &expected_max, + "Max. Statistics\n\n{}\n\n", + DisplayStats(row_groups) + ); + } + } + } + + /// Write the specified batches out as parquet and return the metadata + fn parquet_metadata(schema: SchemaRef, batch: RecordBatch) -> Arc { + let props = WriterProperties::builder() + .set_statistics_enabled(EnabledStatistics::Chunk) + .set_max_row_group_size(ROWS_PER_ROW_GROUP) + .build(); + + let mut buffer = Vec::new(); + let mut writer = ArrowWriter::try_new(&mut buffer, schema, Some(props)).unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + + let reader = ArrowReaderBuilder::try_new(Bytes::from(buffer)).unwrap(); + reader.metadata().clone() + } + + /// Formats the statistics nicely for display + struct DisplayStats<'a>(&'a [RowGroupMetaData]); + impl<'a> std::fmt::Display for DisplayStats<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let row_groups = self.0; + writeln!(f, " row_groups: {}", row_groups.len())?; + for rg in row_groups { + for col in rg.columns() { + if let Some(statistics) = col.statistics() { + writeln!(f, " {}: {:?}", col.column_path(), statistics)?; + } + } + } + Ok(()) + } + } + + struct ExpectedColumn { + name: &'static str, + expected_min: ArrayRef, + expected_max: ArrayRef, + } + + /// Reads statistics out of the specified, and compares them to the expected values + struct TestFile { + file_name: &'static str, + expected_columns: Vec, + } + + impl TestFile { + fn new(file_name: &'static str) -> Self { + Self { + file_name, + expected_columns: Vec::new(), + } + } + + fn with_column(mut self, column: ExpectedColumn) -> Self { + self.expected_columns.push(column); + self + } + + /// Reads the specified parquet file and validates that the exepcted min/max + /// values for the specified columns are as expected. + fn run(self) { + let path = PathBuf::from(parquet_test_data()).join(self.file_name); + let file = std::fs::File::open(path).unwrap(); + let reader = ArrowReaderBuilder::try_new(file).unwrap(); + let arrow_schema = reader.schema(); + let metadata = reader.metadata(); + let row_groups = metadata.row_groups(); + let parquet_schema = metadata.file_metadata().schema_descr(); + + for expected_column in self.expected_columns { + let ExpectedColumn { + name, + expected_min, + expected_max, + } = expected_column; + + let (idx, field) = + parquet_column(parquet_schema, arrow_schema, name).unwrap(); + + let iter = row_groups.iter().map(|x| x.column(idx).statistics()); + let actual_min = min_statistics(field.data_type(), iter.clone()).unwrap(); + assert_eq!(&expected_min, &actual_min, "column {name}"); + + let actual_max = max_statistics(field.data_type(), iter).unwrap(); + assert_eq!(&expected_max, &actual_max, "column {name}"); + } + } + } + + fn bool_array(input: impl IntoIterator>) -> ArrayRef { + let array: BooleanArray = input.into_iter().collect(); + Arc::new(array) + } + + fn i32_array(input: impl IntoIterator>) -> ArrayRef { + let array: Int32Array = input.into_iter().collect(); + Arc::new(array) + } + + fn i64_array(input: impl IntoIterator>) -> ArrayRef { + let array: Int64Array = input.into_iter().collect(); + Arc::new(array) + } + + fn f32_array(input: impl IntoIterator>) -> ArrayRef { + let array: Float32Array = input.into_iter().collect(); + Arc::new(array) + } + + fn f64_array(input: impl IntoIterator>) -> ArrayRef { + let array: Float64Array = input.into_iter().collect(); + Arc::new(array) + } + + fn timestamp_array(input: impl IntoIterator>) -> ArrayRef { + let array: TimestampNanosecondArray = input.into_iter().collect(); + Arc::new(array) + } + + fn utf8_array<'a>(input: impl IntoIterator>) -> ArrayRef { + let array: StringArray = input + .into_iter() + .map(|s| s.map(|s| s.to_string())) + .collect(); + Arc::new(array) + } + + // returns a struct array with columns "bool_col" and "int_col" with the specified values + fn struct_array(input: Vec<(Option, Option)>) -> ArrayRef { + let boolean: BooleanArray = input.iter().map(|(b, _i)| b).collect(); + let int: Int32Array = input.iter().map(|(_b, i)| i).collect(); + + let nullable = true; + let struct_array = StructArray::from(vec![ + ( + Arc::new(Field::new("bool_col", DataType::Boolean, nullable)), + Arc::new(boolean) as ArrayRef, + ), + ( + Arc::new(Field::new("int_col", DataType::Int32, nullable)), + Arc::new(int) as ArrayRef, + ), + ]); + Arc::new(struct_array) + } +} diff --git a/datafusion/core/src/datasource/provider.rs b/datafusion/core/src/datasource/provider.rs index 5ebcc45b572b..c1cee849fe5c 100644 --- a/datafusion/core/src/datasource/provider.rs +++ b/datafusion/core/src/datasource/provider.rs @@ -26,6 +26,8 @@ use datafusion_expr::{CreateExternalTable, LogicalPlan}; pub use datafusion_expr::{TableProviderFilterPushDown, TableType}; use crate::arrow::datatypes::SchemaRef; +use crate::datasource::listing_table_factory::ListingTableFactory; +use crate::datasource::stream::StreamTableFactory; use crate::error::Result; use crate::execution::context::SessionState; use crate::logical_expr::Expr; @@ -42,6 +44,11 @@ pub trait TableProvider: Sync + Send { fn schema(&self) -> SchemaRef; /// Get a reference to the constraints of the table. + /// Returns: + /// - `None` for tables that do not support constraints. + /// - `Some(&Constraints)` for tables supporting constraints. + /// Therefore, a `Some(&Constraints::empty())` return value indicates that + /// this table supports constraints, but there are no constraints. fn constraints(&self) -> Option<&Constraints> { None } @@ -54,24 +61,96 @@ pub trait TableProvider: Sync + Send { None } - /// Get the Logical Plan of this table, if available. + /// Get the [`LogicalPlan`] of this table, if available fn get_logical_plan(&self) -> Option<&LogicalPlan> { None } - /// Create an ExecutionPlan that will scan the table. - /// The table provider will be usually responsible of grouping - /// the source data into partitions that can be efficiently - /// parallelized or distributed. + /// Get the default value for a column, if available. + fn get_column_default(&self, _column: &str) -> Option<&Expr> { + None + } + + /// Create an [`ExecutionPlan`] for scanning the table with optionally + /// specified `projection`, `filter` and `limit`, described below. + /// + /// The `ExecutionPlan` is responsible scanning the datasource's + /// partitions in a streaming, parallelized fashion. + /// + /// # Projection + /// + /// If specified, only a subset of columns should be returned, in the order + /// specified. The projection is a set of indexes of the fields in + /// [`Self::schema`]. + /// + /// DataFusion provides the projection to scan only the columns actually + /// used in the query to improve performance, an optimization called + /// "Projection Pushdown". Some datasources, such as Parquet, can use this + /// information to go significantly faster when only a subset of columns is + /// required. + /// + /// # Filters + /// + /// A list of boolean filter [`Expr`]s to evaluate *during* the scan, in the + /// manner specified by [`Self::supports_filters_pushdown`]. Only rows for + /// which *all* of the `Expr`s evaluate to `true` must be returned (aka the + /// expressions are `AND`ed together). + /// + /// DataFusion pushes filtering into the scans whenever possible + /// ("Projection Pushdown"), and depending on the format and the + /// implementation of the format, evaluating the predicate during the scan + /// can increase performance significantly. + /// + /// ## Note: Some columns may appear *only* in Filters + /// + /// In certain cases, a query may only use a certain column in a Filter that + /// has been completely pushed down to the scan. In this case, the + /// projection will not contain all the columns found in the filter + /// expressions. + /// + /// For example, given the query `SELECT t.a FROM t WHERE t.b > 5`, + /// + /// ```text + /// ┌────────────────────┐ + /// │ Projection(t.a) │ + /// └────────────────────┘ + /// ▲ + /// │ + /// │ + /// ┌────────────────────┐ Filter ┌────────────────────┐ Projection ┌────────────────────┐ + /// │ Filter(t.b > 5) │────Pushdown──▶ │ Projection(t.a) │ ───Pushdown───▶ │ Projection(t.a) │ + /// └────────────────────┘ └────────────────────┘ └────────────────────┘ + /// ▲ ▲ ▲ + /// │ │ │ + /// │ │ ┌────────────────────┐ + /// ┌────────────────────┐ ┌────────────────────┐ │ Scan │ + /// │ Scan │ │ Scan │ │ filter=(t.b > 5) │ + /// └────────────────────┘ │ filter=(t.b > 5) │ │ projection=(t.a) │ + /// └────────────────────┘ └────────────────────┘ + /// + /// Initial Plan If `TableProviderFilterPushDown` Projection pushdown notes that + /// returns true, filter pushdown the scan only needs t.a + /// pushes the filter into the scan + /// BUT internally evaluating the + /// predicate still requires t.b + /// ``` + /// + /// # Limit + /// + /// If `limit` is specified, must only produce *at least* this many rows, + /// (though it may return more). Like Projection Pushdown and Filter + /// Pushdown, DataFusion pushes `LIMIT`s as far down in the plan as + /// possible, called "Limit Pushdown" as some sources can use this + /// information to improve their performance. Note that if there are any + /// Inexact filters pushed down, the LIMIT cannot be pushed down. This is + /// because inexact filters do not guarentee that every filtered row is + /// removed, so applying the limit could lead to too few rows being available + /// to return as a final result. async fn scan( &self, state: &SessionState, projection: Option<&Vec>, filters: &[Expr], - // limit can be used to reduce the amount scanned - // from the datasource as a performance optimization. - // If set, it contains the amount of rows needed by the `LogicalPlan`, - // The datasource should return *at least* this number of rows if available. limit: Option, ) -> Result>; @@ -146,3 +225,41 @@ pub trait TableProviderFactory: Sync + Send { cmd: &CreateExternalTable, ) -> Result>; } + +/// The default [`TableProviderFactory`] +/// +/// If [`CreateExternalTable`] is unbounded calls [`StreamTableFactory::create`], +/// otherwise calls [`ListingTableFactory::create`] +#[derive(Debug, Default)] +pub struct DefaultTableFactory { + stream: StreamTableFactory, + listing: ListingTableFactory, +} + +impl DefaultTableFactory { + /// Creates a new [`DefaultTableFactory`] + pub fn new() -> Self { + Self::default() + } +} + +#[async_trait] +impl TableProviderFactory for DefaultTableFactory { + async fn create( + &self, + state: &SessionState, + cmd: &CreateExternalTable, + ) -> Result> { + let mut unbounded = cmd.unbounded; + for (k, v) in &cmd.options { + if k.eq_ignore_ascii_case("unbounded") && v.eq_ignore_ascii_case("true") { + unbounded = true + } + } + + match unbounded { + true => self.stream.create(state, cmd).await, + false => self.listing.create(state, cmd).await, + } + } +} diff --git a/datafusion/core/src/datasource/statistics.rs b/datafusion/core/src/datasource/statistics.rs index 1b6a03e15c02..695e139517cf 100644 --- a/datafusion/core/src/datasource/statistics.rs +++ b/datafusion/core/src/datasource/statistics.rs @@ -15,14 +15,18 @@ // specific language governing permissions and limitations // under the License. +use super::listing::PartitionedFile; use crate::arrow::datatypes::{Schema, SchemaRef}; use crate::error::Result; use crate::physical_plan::expressions::{MaxAccumulator, MinAccumulator}; use crate::physical_plan::{Accumulator, ColumnStatistics, Statistics}; -use futures::Stream; -use futures::StreamExt; -use super::listing::PartitionedFile; +use datafusion_common::stats::Precision; +use datafusion_common::ScalarValue; + +use futures::{Stream, StreamExt}; +use itertools::izip; +use itertools::multiunzip; /// Get all files as well as the file level summary statistics (no statistic for partition columns). /// If the optional `limit` is provided, includes only sufficient files. @@ -33,101 +37,109 @@ pub async fn get_statistics_with_limit( limit: Option, ) -> Result<(Vec, Statistics)> { let mut result_files = vec![]; + // These statistics can be calculated as long as at least one file provides + // useful information. If none of the files provides any information, then + // they will end up having `Precision::Absent` values. Throughout calculations, + // missing values will be imputed as: + // - zero for summations, and + // - neutral element for extreme points. + let size = file_schema.fields().len(); + let mut null_counts: Vec> = vec![Precision::Absent; size]; + let mut max_values: Vec> = vec![Precision::Absent; size]; + let mut min_values: Vec> = vec![Precision::Absent; size]; + let mut num_rows = Precision::::Absent; + let mut total_byte_size = Precision::::Absent; - let mut null_counts = vec![0; file_schema.fields().len()]; - let mut has_statistics = false; - let (mut max_values, mut min_values) = create_max_min_accs(&file_schema); - - let mut is_exact = true; - - // The number of rows and the total byte size can be calculated as long as - // at least one file has them. If none of the files provide them, then they - // will be omitted from the statistics. The missing values will be counted - // as zero. - let mut num_rows = None; - let mut total_byte_size = None; - - // fusing the stream allows us to call next safely even once it is finished + // Fusing the stream allows us to call next safely even once it is finished. let mut all_files = Box::pin(all_files.fuse()); - while let Some(res) = all_files.next().await { - let (file, file_stats) = res?; + + if let Some(first_file) = all_files.next().await { + let (file, file_stats) = first_file?; result_files.push(file); - is_exact &= file_stats.is_exact; - num_rows = if let Some(num_rows) = num_rows { - Some(num_rows + file_stats.num_rows.unwrap_or(0)) - } else { - file_stats.num_rows - }; - total_byte_size = if let Some(total_byte_size) = total_byte_size { - Some(total_byte_size + file_stats.total_byte_size.unwrap_or(0)) - } else { - file_stats.total_byte_size - }; - if let Some(vec) = &file_stats.column_statistics { - has_statistics = true; - for (i, cs) in vec.iter().enumerate() { - null_counts[i] += cs.null_count.unwrap_or(0); - - if let Some(max_value) = &mut max_values[i] { - if let Some(file_max) = cs.max_value.clone() { - match max_value.update_batch(&[file_max.to_array()]) { - Ok(_) => {} - Err(_) => { - max_values[i] = None; - } - } - } else { - max_values[i] = None; - } - } - if let Some(min_value) = &mut min_values[i] { - if let Some(file_min) = cs.min_value.clone() { - match min_value.update_batch(&[file_min.to_array()]) { - Ok(_) => {} - Err(_) => { - min_values[i] = None; - } - } - } else { - min_values[i] = None; - } - } - } + // First file, we set them directly from the file statistics. + num_rows = file_stats.num_rows; + total_byte_size = file_stats.total_byte_size; + for (index, file_column) in file_stats.column_statistics.into_iter().enumerate() { + null_counts[index] = file_column.null_count; + max_values[index] = file_column.max_value; + min_values[index] = file_column.min_value; } // If the number of rows exceeds the limit, we can stop processing // files. This only applies when we know the number of rows. It also // currently ignores tables that have no statistics regarding the // number of rows. - if num_rows.unwrap_or(usize::MIN) > limit.unwrap_or(usize::MAX) { - break; - } - } - // if we still have files in the stream, it means that the limit kicked - // in and that the statistic could have been different if we processed - // the files in a different order. - if all_files.next().await.is_some() { - is_exact = false; - } + let conservative_num_rows = match num_rows { + Precision::Exact(nr) => nr, + _ => usize::MIN, + }; + if conservative_num_rows <= limit.unwrap_or(usize::MAX) { + while let Some(current) = all_files.next().await { + let (file, file_stats) = current?; + result_files.push(file); + + // We accumulate the number of rows, total byte size and null + // counts across all the files in question. If any file does not + // provide any information or provides an inexact value, we demote + // the statistic precision to inexact. + num_rows = add_row_stats(file_stats.num_rows, num_rows); + + total_byte_size = + add_row_stats(file_stats.total_byte_size, total_byte_size); - let column_stats = if has_statistics { - Some(get_col_stats( - &file_schema, - null_counts, - &mut max_values, - &mut min_values, - )) - } else { - None + (null_counts, max_values, min_values) = multiunzip( + izip!( + file_stats.column_statistics.into_iter(), + null_counts.into_iter(), + max_values.into_iter(), + min_values.into_iter() + ) + .map( + |( + ColumnStatistics { + null_count: file_nc, + max_value: file_max, + min_value: file_min, + distinct_count: _, + }, + null_count, + max_value, + min_value, + )| { + ( + add_row_stats(file_nc, null_count), + set_max_if_greater(file_max, max_value), + set_min_if_lesser(file_min, min_value), + ) + }, + ), + ); + + // If the number of rows exceeds the limit, we can stop processing + // files. This only applies when we know the number of rows. It also + // currently ignores tables that have no statistics regarding the + // number of rows. + if num_rows.get_value().unwrap_or(&usize::MIN) + > &limit.unwrap_or(usize::MAX) + { + break; + } + } + } }; - let statistics = Statistics { + let mut statistics = Statistics { num_rows, total_byte_size, - column_statistics: column_stats, - is_exact, + column_statistics: get_col_stats_vec(null_counts, max_values, min_values), }; + if all_files.next().await.is_some() { + // If we still have files in the stream, it means that the limit kicked + // in, and the statistic could have been different had we processed the + // files in a different order. + statistics = statistics.into_inexact() + } Ok((result_files, statistics)) } @@ -139,18 +151,44 @@ pub(crate) fn create_max_min_accs( .fields() .iter() .map(|field| MaxAccumulator::try_new(field.data_type()).ok()) - .collect::>(); + .collect(); let min_values: Vec> = schema .fields() .iter() .map(|field| MinAccumulator::try_new(field.data_type()).ok()) - .collect::>(); + .collect(); (max_values, min_values) } +fn add_row_stats( + file_num_rows: Precision, + num_rows: Precision, +) -> Precision { + match (file_num_rows, &num_rows) { + (Precision::Absent, _) => num_rows.to_inexact(), + (lhs, Precision::Absent) => lhs.to_inexact(), + (lhs, rhs) => lhs.add(rhs), + } +} + +pub(crate) fn get_col_stats_vec( + null_counts: Vec>, + max_values: Vec>, + min_values: Vec>, +) -> Vec { + izip!(null_counts, max_values, min_values) + .map(|(null_count, max_value, min_value)| ColumnStatistics { + null_count, + max_value, + min_value, + distinct_count: Precision::Absent, + }) + .collect() +} + pub(crate) fn get_col_stats( schema: &Schema, - null_counts: Vec, + null_counts: Vec>, max_values: &mut [Option], min_values: &mut [Option], ) -> Vec { @@ -165,11 +203,57 @@ pub(crate) fn get_col_stats( None => None, }; ColumnStatistics { - null_count: Some(null_counts[i]), - max_value, - min_value, - distinct_count: None, + null_count: null_counts[i].clone(), + max_value: max_value.map(Precision::Exact).unwrap_or(Precision::Absent), + min_value: min_value.map(Precision::Exact).unwrap_or(Precision::Absent), + distinct_count: Precision::Absent, } }) .collect() } + +/// If the given value is numerically greater than the original maximum value, +/// return the new maximum value with appropriate exactness information. +fn set_max_if_greater( + max_nominee: Precision, + max_values: Precision, +) -> Precision { + match (&max_values, &max_nominee) { + (Precision::Exact(val1), Precision::Exact(val2)) if val1 < val2 => max_nominee, + (Precision::Exact(val1), Precision::Inexact(val2)) + | (Precision::Inexact(val1), Precision::Inexact(val2)) + | (Precision::Inexact(val1), Precision::Exact(val2)) + if val1 < val2 => + { + max_nominee.to_inexact() + } + (Precision::Exact(_), Precision::Absent) => max_values.to_inexact(), + (Precision::Absent, Precision::Exact(_)) => max_nominee.to_inexact(), + (Precision::Absent, Precision::Inexact(_)) => max_nominee, + (Precision::Absent, Precision::Absent) => Precision::Absent, + _ => max_values, + } +} + +/// If the given value is numerically lesser than the original minimum value, +/// return the new minimum value with appropriate exactness information. +fn set_min_if_lesser( + min_nominee: Precision, + min_values: Precision, +) -> Precision { + match (&min_values, &min_nominee) { + (Precision::Exact(val1), Precision::Exact(val2)) if val1 > val2 => min_nominee, + (Precision::Exact(val1), Precision::Inexact(val2)) + | (Precision::Inexact(val1), Precision::Inexact(val2)) + | (Precision::Inexact(val1), Precision::Exact(val2)) + if val1 > val2 => + { + min_nominee.to_inexact() + } + (Precision::Exact(_), Precision::Absent) => min_values.to_inexact(), + (Precision::Absent, Precision::Exact(_)) => min_nominee.to_inexact(), + (Precision::Absent, Precision::Inexact(_)) => min_nominee, + (Precision::Absent, Precision::Absent) => Precision::Absent, + _ => min_values, + } +} diff --git a/datafusion/core/src/datasource/stream.rs b/datafusion/core/src/datasource/stream.rs new file mode 100644 index 000000000000..830cd7a07e46 --- /dev/null +++ b/datafusion/core/src/datasource/stream.rs @@ -0,0 +1,365 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! TableProvider for stream sources, such as FIFO files + +use std::any::Any; +use std::fmt::Formatter; +use std::fs::{File, OpenOptions}; +use std::io::BufReader; +use std::path::PathBuf; +use std::str::FromStr; +use std::sync::Arc; + +use arrow_array::{RecordBatch, RecordBatchReader, RecordBatchWriter}; +use arrow_schema::SchemaRef; +use async_trait::async_trait; +use futures::StreamExt; +use tokio::task::spawn_blocking; + +use datafusion_common::{plan_err, Constraints, DataFusionError, Result}; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_expr::{CreateExternalTable, Expr, TableType}; +use datafusion_physical_plan::common::AbortOnDropSingle; +use datafusion_physical_plan::insert::{DataSink, FileSinkExec}; +use datafusion_physical_plan::metrics::MetricsSet; +use datafusion_physical_plan::stream::RecordBatchReceiverStreamBuilder; +use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; +use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; + +use crate::datasource::provider::TableProviderFactory; +use crate::datasource::{create_ordering, TableProvider}; +use crate::execution::context::SessionState; + +/// A [`TableProviderFactory`] for [`StreamTable`] +#[derive(Debug, Default)] +pub struct StreamTableFactory {} + +#[async_trait] +impl TableProviderFactory for StreamTableFactory { + async fn create( + &self, + state: &SessionState, + cmd: &CreateExternalTable, + ) -> Result> { + let schema: SchemaRef = Arc::new(cmd.schema.as_ref().into()); + let location = cmd.location.clone(); + let encoding = cmd.file_type.parse()?; + + let config = StreamConfig::new_file(schema, location.into()) + .with_encoding(encoding) + .with_order(cmd.order_exprs.clone()) + .with_header(cmd.has_header) + .with_batch_size(state.config().batch_size()) + .with_constraints(cmd.constraints.clone()); + + Ok(Arc::new(StreamTable(Arc::new(config)))) + } +} + +/// The data encoding for [`StreamTable`] +#[derive(Debug, Clone)] +pub enum StreamEncoding { + /// CSV records + Csv, + /// Newline-delimited JSON records + Json, +} + +impl FromStr for StreamEncoding { + type Err = DataFusionError; + + fn from_str(s: &str) -> std::result::Result { + match s.to_ascii_lowercase().as_str() { + "csv" => Ok(Self::Csv), + "json" => Ok(Self::Json), + _ => plan_err!("Unrecognised StreamEncoding {}", s), + } + } +} + +/// The configuration for a [`StreamTable`] +#[derive(Debug)] +pub struct StreamConfig { + schema: SchemaRef, + location: PathBuf, + batch_size: usize, + encoding: StreamEncoding, + header: bool, + order: Vec>, + constraints: Constraints, +} + +impl StreamConfig { + /// Stream data from the file at `location` + /// + /// * Data will be read sequentially from the provided `location` + /// * New data will be appended to the end of the file + /// + /// The encoding can be configured with [`Self::with_encoding`] and + /// defaults to [`StreamEncoding::Csv`] + pub fn new_file(schema: SchemaRef, location: PathBuf) -> Self { + Self { + schema, + location, + batch_size: 1024, + encoding: StreamEncoding::Csv, + order: vec![], + header: false, + constraints: Constraints::empty(), + } + } + + /// Specify a sort order for the stream + pub fn with_order(mut self, order: Vec>) -> Self { + self.order = order; + self + } + + /// Specify the batch size + pub fn with_batch_size(mut self, batch_size: usize) -> Self { + self.batch_size = batch_size; + self + } + + /// Specify whether the file has a header (only applicable for [`StreamEncoding::Csv`]) + pub fn with_header(mut self, header: bool) -> Self { + self.header = header; + self + } + + /// Specify an encoding for the stream + pub fn with_encoding(mut self, encoding: StreamEncoding) -> Self { + self.encoding = encoding; + self + } + + /// Assign constraints + pub fn with_constraints(mut self, constraints: Constraints) -> Self { + self.constraints = constraints; + self + } + + fn reader(&self) -> Result> { + let file = File::open(&self.location)?; + let schema = self.schema.clone(); + match &self.encoding { + StreamEncoding::Csv => { + let reader = arrow::csv::ReaderBuilder::new(schema) + .with_header(self.header) + .with_batch_size(self.batch_size) + .build(file)?; + + Ok(Box::new(reader)) + } + StreamEncoding::Json => { + let reader = arrow::json::ReaderBuilder::new(schema) + .with_batch_size(self.batch_size) + .build(BufReader::new(file))?; + + Ok(Box::new(reader)) + } + } + } + + fn writer(&self) -> Result> { + match &self.encoding { + StreamEncoding::Csv => { + let header = self.header && !self.location.exists(); + let file = OpenOptions::new() + .create(true) + .append(true) + .open(&self.location)?; + let writer = arrow::csv::WriterBuilder::new() + .with_header(header) + .build(file); + + Ok(Box::new(writer)) + } + StreamEncoding::Json => { + let file = OpenOptions::new() + .create(true) + .append(true) + .open(&self.location)?; + Ok(Box::new(arrow::json::LineDelimitedWriter::new(file))) + } + } + } +} + +/// A [`TableProvider`] for an unbounded stream source +/// +/// Currently only reading from / appending to a single file in-place is supported, but +/// other stream sources and sinks may be added in future. +/// +/// Applications looking to read/write datasets comprising multiple files, e.g. [Hadoop]-style +/// data stored in object storage, should instead consider [`ListingTable`]. +/// +/// [Hadoop]: https://hadoop.apache.org/ +/// [`ListingTable`]: crate::datasource::listing::ListingTable +pub struct StreamTable(Arc); + +impl StreamTable { + /// Create a new [`StreamTable`] for the given [`StreamConfig`] + pub fn new(config: Arc) -> Self { + Self(config) + } +} + +#[async_trait] +impl TableProvider for StreamTable { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.0.schema.clone() + } + + fn constraints(&self) -> Option<&Constraints> { + Some(&self.0.constraints) + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + _state: &SessionState, + projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + let projected_schema = match projection { + Some(p) => { + let projected = self.0.schema.project(p)?; + create_ordering(&projected, &self.0.order)? + } + None => create_ordering(self.0.schema.as_ref(), &self.0.order)?, + }; + + Ok(Arc::new(StreamingTableExec::try_new( + self.0.schema.clone(), + vec![Arc::new(StreamRead(self.0.clone())) as _], + projection, + projected_schema, + true, + )?)) + } + + async fn insert_into( + &self, + _state: &SessionState, + input: Arc, + _overwrite: bool, + ) -> Result> { + let ordering = match self.0.order.first() { + Some(x) => { + let schema = self.0.schema.as_ref(); + let orders = create_ordering(schema, std::slice::from_ref(x))?; + let ordering = orders.into_iter().next().unwrap(); + Some(ordering.into_iter().map(Into::into).collect()) + } + None => None, + }; + + Ok(Arc::new(FileSinkExec::new( + input, + Arc::new(StreamWrite(self.0.clone())), + self.0.schema.clone(), + ordering, + ))) + } +} + +struct StreamRead(Arc); + +impl PartitionStream for StreamRead { + fn schema(&self) -> &SchemaRef { + &self.0.schema + } + + fn execute(&self, _ctx: Arc) -> SendableRecordBatchStream { + let config = self.0.clone(); + let schema = self.0.schema.clone(); + let mut builder = RecordBatchReceiverStreamBuilder::new(schema, 2); + let tx = builder.tx(); + builder.spawn_blocking(move || { + let reader = config.reader()?; + for b in reader { + if tx.blocking_send(b.map_err(Into::into)).is_err() { + break; + } + } + Ok(()) + }); + builder.build() + } +} + +#[derive(Debug)] +struct StreamWrite(Arc); + +impl DisplayAs for StreamWrite { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + f.debug_struct("StreamWrite") + .field("location", &self.0.location) + .field("batch_size", &self.0.batch_size) + .field("encoding", &self.0.encoding) + .field("header", &self.0.header) + .finish_non_exhaustive() + } +} + +#[async_trait] +impl DataSink for StreamWrite { + fn as_any(&self) -> &dyn Any { + self + } + + fn metrics(&self) -> Option { + None + } + + async fn write_all( + &self, + mut data: SendableRecordBatchStream, + _context: &Arc, + ) -> Result { + let config = self.0.clone(); + let (sender, mut receiver) = tokio::sync::mpsc::channel::(2); + // Note: FIFO Files support poll so this could use AsyncFd + let write = AbortOnDropSingle::new(spawn_blocking(move || { + let mut count = 0_u64; + let mut writer = config.writer()?; + while let Some(batch) = receiver.blocking_recv() { + count += batch.num_rows() as u64; + writer.write(&batch)?; + } + Ok(count) + })); + + while let Some(b) = data.next().await.transpose()? { + if sender.send(b).await.is_err() { + break; + } + } + drop(sender); + write.await.unwrap() + } +} diff --git a/datafusion/core/src/datasource/view.rs b/datafusion/core/src/datasource/view.rs index d58284d1bac5..85fb8939886c 100644 --- a/datafusion/core/src/datasource/view.rs +++ b/datafusion/core/src/datasource/view.rs @@ -159,7 +159,7 @@ mod tests { #[tokio::test] async fn issue_3242() -> Result<()> { // regression test for https://github.com/apache/arrow-datafusion/pull/3242 - let session_ctx = SessionContext::with_config( + let session_ctx = SessionContext::new_with_config( SessionConfig::new().with_information_schema(true), ); @@ -199,7 +199,7 @@ mod tests { #[tokio::test] async fn query_view() -> Result<()> { - let session_ctx = SessionContext::with_config( + let session_ctx = SessionContext::new_with_config( SessionConfig::new().with_information_schema(true), ); @@ -237,7 +237,7 @@ mod tests { #[tokio::test] async fn query_view_with_alias() -> Result<()> { - let session_ctx = SessionContext::with_config(SessionConfig::new()); + let session_ctx = SessionContext::new_with_config(SessionConfig::new()); session_ctx .sql("CREATE TABLE abc AS VALUES (1,2,3), (4,5,6)") @@ -270,7 +270,7 @@ mod tests { #[tokio::test] async fn query_view_with_inline_alias() -> Result<()> { - let session_ctx = SessionContext::with_config(SessionConfig::new()); + let session_ctx = SessionContext::new_with_config(SessionConfig::new()); session_ctx .sql("CREATE TABLE abc AS VALUES (1,2,3), (4,5,6)") @@ -303,7 +303,7 @@ mod tests { #[tokio::test] async fn query_view_with_projection() -> Result<()> { - let session_ctx = SessionContext::with_config( + let session_ctx = SessionContext::new_with_config( SessionConfig::new().with_information_schema(true), ); @@ -341,7 +341,7 @@ mod tests { #[tokio::test] async fn query_view_with_filter() -> Result<()> { - let session_ctx = SessionContext::with_config( + let session_ctx = SessionContext::new_with_config( SessionConfig::new().with_information_schema(true), ); @@ -378,7 +378,7 @@ mod tests { #[tokio::test] async fn query_join_views() -> Result<()> { - let session_ctx = SessionContext::with_config( + let session_ctx = SessionContext::new_with_config( SessionConfig::new().with_information_schema(true), ); @@ -481,7 +481,7 @@ mod tests { #[tokio::test] async fn create_view_plan() -> Result<()> { - let session_ctx = SessionContext::with_config( + let session_ctx = SessionContext::new_with_config( SessionConfig::new().with_information_schema(true), ); @@ -534,7 +534,7 @@ mod tests { #[tokio::test] async fn create_or_replace_view() -> Result<()> { - let session_ctx = SessionContext::with_config( + let session_ctx = SessionContext::new_with_config( SessionConfig::new().with_information_schema(true), ); diff --git a/datafusion/core/src/execution/context/avro.rs b/datafusion/core/src/execution/context/avro.rs new file mode 100644 index 000000000000..d60e79862ef2 --- /dev/null +++ b/datafusion/core/src/execution/context/avro.rs @@ -0,0 +1,83 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use super::super::options::{AvroReadOptions, ReadOptions}; +use super::{DataFilePaths, DataFrame, Result, SessionContext}; + +impl SessionContext { + /// Creates a [`DataFrame`] for reading an Avro data source. + /// + /// For more control such as reading multiple files, you can use + /// [`read_table`](Self::read_table) with a [`super::ListingTable`]. + /// + /// For an example, see [`read_csv`](Self::read_csv) + pub async fn read_avro( + &self, + table_paths: P, + options: AvroReadOptions<'_>, + ) -> Result { + self._read_type(table_paths, options).await + } + + /// Registers an Avro file as a table that can be referenced from + /// SQL statements executed against this context. + pub async fn register_avro( + &self, + name: &str, + table_path: &str, + options: AvroReadOptions<'_>, + ) -> Result<()> { + let listing_options = options.to_listing_options(&self.copied_config()); + + self.register_listing_table( + name, + table_path, + listing_options, + options.schema.map(|s| Arc::new(s.to_owned())), + None, + ) + .await?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use async_trait::async_trait; + + // Test for compilation error when calling read_* functions from an #[async_trait] function. + // See https://github.com/apache/arrow-datafusion/issues/1154 + #[async_trait] + trait CallReadTrait { + async fn call_read_avro(&self) -> DataFrame; + } + + struct CallRead {} + + #[async_trait] + impl CallReadTrait for CallRead { + async fn call_read_avro(&self) -> DataFrame { + let ctx = SessionContext::new(); + ctx.read_avro("dummy", AvroReadOptions::default()) + .await + .unwrap() + } + } +} diff --git a/datafusion/core/src/execution/context/csv.rs b/datafusion/core/src/execution/context/csv.rs new file mode 100644 index 000000000000..f3675422c7d5 --- /dev/null +++ b/datafusion/core/src/execution/context/csv.rs @@ -0,0 +1,143 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use crate::datasource::physical_plan::plan_to_csv; + +use super::super::options::{CsvReadOptions, ReadOptions}; +use super::{DataFilePaths, DataFrame, ExecutionPlan, Result, SessionContext}; + +impl SessionContext { + /// Creates a [`DataFrame`] for reading a CSV data source. + /// + /// For more control such as reading multiple files, you can use + /// [`read_table`](Self::read_table) with a [`super::ListingTable`]. + /// + /// Example usage is given below: + /// + /// ``` + /// use datafusion::prelude::*; + /// # use datafusion::error::Result; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// let ctx = SessionContext::new(); + /// // You can read a single file using `read_csv` + /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// // you can also read multiple files: + /// let df = ctx.read_csv(vec!["tests/data/example.csv", "tests/data/example.csv"], CsvReadOptions::new()).await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn read_csv( + &self, + table_paths: P, + options: CsvReadOptions<'_>, + ) -> Result { + self._read_type(table_paths, options).await + } + + /// Registers a CSV file as a table which can referenced from SQL + /// statements executed against this context. + pub async fn register_csv( + &self, + name: &str, + table_path: &str, + options: CsvReadOptions<'_>, + ) -> Result<()> { + let listing_options = options.to_listing_options(&self.copied_config()); + + self.register_listing_table( + name, + table_path, + listing_options, + options.schema.map(|s| Arc::new(s.to_owned())), + None, + ) + .await?; + + Ok(()) + } + + /// Executes a query and writes the results to a partitioned CSV file. + pub async fn write_csv( + &self, + plan: Arc, + path: impl AsRef, + ) -> Result<()> { + plan_to_csv(self.task_ctx(), plan, path).await + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::assert_batches_eq; + use crate::test_util::{plan_and_collect, populate_csv_partitions}; + use async_trait::async_trait; + use tempfile::TempDir; + + #[tokio::test] + async fn query_csv_with_custom_partition_extension() -> Result<()> { + let tmp_dir = TempDir::new()?; + + // The main stipulation of this test: use a file extension that isn't .csv. + let file_extension = ".tst"; + + let ctx = SessionContext::new(); + let schema = populate_csv_partitions(&tmp_dir, 2, file_extension)?; + ctx.register_csv( + "test", + tmp_dir.path().to_str().unwrap(), + CsvReadOptions::new() + .schema(&schema) + .file_extension(file_extension), + ) + .await?; + let results = + plan_and_collect(&ctx, "SELECT SUM(c1), SUM(c2), COUNT(*) FROM test").await?; + + assert_eq!(results.len(), 1); + let expected = [ + "+--------------+--------------+----------+", + "| SUM(test.c1) | SUM(test.c2) | COUNT(*) |", + "+--------------+--------------+----------+", + "| 10 | 110 | 20 |", + "+--------------+--------------+----------+", + ]; + assert_batches_eq!(expected, &results); + + Ok(()) + } + + // Test for compilation error when calling read_* functions from an #[async_trait] function. + // See https://github.com/apache/arrow-datafusion/issues/1154 + #[async_trait] + trait CallReadTrait { + async fn call_read_csv(&self) -> DataFrame; + } + + struct CallRead {} + + #[async_trait] + impl CallReadTrait for CallRead { + async fn call_read_csv(&self) -> DataFrame { + let ctx = SessionContext::new(); + ctx.read_csv("dummy", CsvReadOptions::new()).await.unwrap() + } + } +} diff --git a/datafusion/core/src/execution/context/json.rs b/datafusion/core/src/execution/context/json.rs new file mode 100644 index 000000000000..f67693aa8f31 --- /dev/null +++ b/datafusion/core/src/execution/context/json.rs @@ -0,0 +1,69 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use crate::datasource::physical_plan::plan_to_json; + +use super::super::options::{NdJsonReadOptions, ReadOptions}; +use super::{DataFilePaths, DataFrame, ExecutionPlan, Result, SessionContext}; + +impl SessionContext { + /// Creates a [`DataFrame`] for reading an JSON data source. + /// + /// For more control such as reading multiple files, you can use + /// [`read_table`](Self::read_table) with a [`super::ListingTable`]. + /// + /// For an example, see [`read_csv`](Self::read_csv) + pub async fn read_json( + &self, + table_paths: P, + options: NdJsonReadOptions<'_>, + ) -> Result { + self._read_type(table_paths, options).await + } + + /// Registers a JSON file as a table that it can be referenced + /// from SQL statements executed against this context. + pub async fn register_json( + &self, + name: &str, + table_path: &str, + options: NdJsonReadOptions<'_>, + ) -> Result<()> { + let listing_options = options.to_listing_options(&self.copied_config()); + + self.register_listing_table( + name, + table_path, + listing_options, + options.schema.map(|s| Arc::new(s.to_owned())), + None, + ) + .await?; + Ok(()) + } + + /// Executes a query and writes the results to a partitioned JSON file. + pub async fn write_json( + &self, + plan: Arc, + path: impl AsRef, + ) -> Result<()> { + plan_to_json(self.task_ctx(), plan, path).await + } +} diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context/mod.rs similarity index 81% rename from datafusion/core/src/execution/context.rs rename to datafusion/core/src/execution/context/mod.rs index 6cfb73a5109a..d6b7f046f3e3 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -16,11 +16,18 @@ // under the License. //! [`SessionContext`] contains methods for registering data sources and executing queries + +mod avro; +mod csv; +mod json; +#[cfg(feature = "parquet")] +mod parquet; + use crate::{ catalog::{CatalogList, MemoryCatalogList}, datasource::{ + function::{TableFunction, TableFunctionImpl}, listing::{ListingOptions, ListingTable}, - listing_table_factory::ListingTableFactory, provider::TableProviderFactory, }, datasource::{MemTable, ViewTable}, @@ -30,13 +37,13 @@ use crate::{ }; use datafusion_common::{ alias::AliasGenerator, - exec_err, not_impl_err, plan_err, + exec_err, not_impl_err, plan_datafusion_err, plan_err, tree_node::{TreeNode, TreeNodeVisitor, VisitRecursion}, }; use datafusion_execution::registry::SerializerRegistry; use datafusion_expr::{ logical_plan::{DdlStatement, Statement}, - StringifiedPlan, UserDefinedLogicalNode, WindowUDF, + Expr, StringifiedPlan, UserDefinedLogicalNode, WindowUDF, }; pub use datafusion_physical_expr::execution_props::ExecutionProps; use datafusion_physical_expr::var_provider::is_system_variables; @@ -77,7 +84,6 @@ use datafusion_sql::{ use sqlparser::dialect::dialect_from_str; use crate::config::ConfigOptions; -use crate::datasource::physical_plan::{plan_to_csv, plan_to_json, plan_to_parquet}; use crate::execution::{runtime_env::RuntimeEnv, FunctionRegistry}; use crate::physical_plan::udaf::AggregateUDF; use crate::physical_plan::udf::ScalarUDF; @@ -92,7 +98,6 @@ use datafusion_sql::{ parser::DFParser, planner::{ContextProvider, SqlToRel}, }; -use parquet::file::properties::WriterProperties; use url::Url; use crate::catalog::information_schema::{InformationSchemaProvider, INFORMATION_SCHEMA}; @@ -106,13 +111,12 @@ use datafusion_sql::planner::object_name_to_table_reference; use uuid::Uuid; // backwards compatibility +use crate::datasource::provider::DefaultTableFactory; use crate::execution::options::ArrowReadOptions; pub use datafusion_execution::config::SessionConfig; pub use datafusion_execution::TaskContext; -use super::options::{ - AvroReadOptions, CsvReadOptions, NdJsonReadOptions, ParquetReadOptions, ReadOptions, -}; +use super::options::ReadOptions; /// DataFilePaths adds a method to convert strings and vector of strings to vector of [`ListingTableUrl`] URLs. /// This allows methods such [`SessionContext::read_csv`] and [`SessionContext::read_avro`] @@ -258,7 +262,7 @@ impl Default for SessionContext { impl SessionContext { /// Creates a new `SessionContext` using the default [`SessionConfig`]. pub fn new() -> Self { - Self::with_config(SessionConfig::new()) + Self::new_with_config(SessionConfig::new()) } /// Finds any [`ListingSchemaProvider`]s and instructs them to reload tables from "disk" @@ -284,11 +288,18 @@ impl SessionContext { /// Creates a new `SessionContext` using the provided /// [`SessionConfig`] and a new [`RuntimeEnv`]. /// - /// See [`Self::with_config_rt`] for more details on resource + /// See [`Self::new_with_config_rt`] for more details on resource /// limits. - pub fn with_config(config: SessionConfig) -> Self { + pub fn new_with_config(config: SessionConfig) -> Self { let runtime = Arc::new(RuntimeEnv::default()); - Self::with_config_rt(config, runtime) + Self::new_with_config_rt(config, runtime) + } + + /// Creates a new `SessionContext` using the provided + /// [`SessionConfig`] and a new [`RuntimeEnv`]. + #[deprecated(since = "32.0.0", note = "Use SessionContext::new_with_config")] + pub fn with_config(config: SessionConfig) -> Self { + Self::new_with_config(config) } /// Creates a new `SessionContext` using the provided @@ -304,13 +315,20 @@ impl SessionContext { /// memory used) across all DataFusion queries in a process, /// all `SessionContext`'s should be configured with the /// same `RuntimeEnv`. + pub fn new_with_config_rt(config: SessionConfig, runtime: Arc) -> Self { + let state = SessionState::new_with_config_rt(config, runtime); + Self::new_with_state(state) + } + + /// Creates a new `SessionContext` using the provided + /// [`SessionConfig`] and a [`RuntimeEnv`]. + #[deprecated(since = "32.0.0", note = "Use SessionState::new_with_config_rt")] pub fn with_config_rt(config: SessionConfig, runtime: Arc) -> Self { - let state = SessionState::with_config_rt(config, runtime); - Self::with_state(state) + Self::new_with_config_rt(config, runtime) } /// Creates a new `SessionContext` using the provided [`SessionState`] - pub fn with_state(state: SessionState) -> Self { + pub fn new_with_state(state: SessionState) -> Self { Self { session_id: state.session_id.clone(), session_start_time: Utc::now(), @@ -318,6 +336,11 @@ impl SessionContext { } } + /// Creates a new `SessionContext` using the provided [`SessionState`] + #[deprecated(since = "32.0.0", note = "Use SessionState::new_with_state")] + pub fn with_state(state: SessionState) -> Self { + Self::new_with_state(state) + } /// Returns the time this `SessionContext` was created pub fn session_start_time(&self) -> DateTime { self.session_start_time @@ -507,6 +530,7 @@ impl SessionContext { if_not_exists, or_replace, constraints, + column_defaults, } = cmd; let input = Arc::try_unwrap(input).unwrap_or_else(|e| e.as_ref().clone()); @@ -520,7 +544,12 @@ impl SessionContext { let physical = DataFrame::new(self.state(), input); let batches: Vec<_> = physical.collect_partitioned().await?; - let table = Arc::new(MemTable::try_new(schema, batches)?); + let table = Arc::new( + // pass constraints and column defaults to the mem table. + MemTable::try_new(schema, batches)? + .with_constraints(constraints) + .with_column_defaults(column_defaults.into_iter().collect()), + ); self.register_table(&name, table)?; self.return_empty_dataframe() @@ -535,8 +564,10 @@ impl SessionContext { let batches: Vec<_> = physical.collect_partitioned().await?; let table = Arc::new( - // pass constraints to the mem table. - MemTable::try_new(schema, batches)?.with_constraints(constraints), + // pass constraints and column defaults to the mem table. + MemTable::try_new(schema, batches)? + .with_constraints(constraints) + .with_column_defaults(column_defaults.into_iter().collect()), ); self.register_table(&name, table)?; @@ -773,6 +804,14 @@ impl SessionContext { .add_var_provider(variable_type, provider); } + /// Register a table UDF with this context + pub fn register_udtf(&self, name: &str, fun: Arc) { + self.state.write().table_functions.insert( + name.to_owned(), + Arc::new(TableFunction::new(name.to_owned(), fun)), + ); + } + /// Registers a scalar UDF within this context. /// /// Note in SQL queries, function names are looked up using @@ -780,11 +819,18 @@ impl SessionContext { /// /// - `SELECT MY_FUNC(x)...` will look for a function named `"my_func"` /// - `SELECT "my_FUNC"(x)` will look for a function named `"my_FUNC"` + /// Any functions registered with the udf name or its aliases will be overwritten with this new function pub fn register_udf(&self, f: ScalarUDF) { - self.state - .write() + let mut state = self.state.write(); + let aliases = f.aliases(); + for alias in aliases { + state + .scalar_functions + .insert(alias.to_string(), Arc::new(f.clone())); + } + state .scalar_functions - .insert(f.name.clone(), Arc::new(f)); + .insert(f.name().to_string(), Arc::new(f)); } /// Registers an aggregate UDF within this context. @@ -798,7 +844,7 @@ impl SessionContext { self.state .write() .aggregate_functions - .insert(f.name.clone(), Arc::new(f)); + .insert(f.name().to_string(), Arc::new(f)); } /// Registers a window UDF within this context. @@ -812,7 +858,7 @@ impl SessionContext { self.state .write() .window_functions - .insert(f.name.clone(), Arc::new(f)); + .insert(f.name().to_string(), Arc::new(f)); } /// Creates a [`DataFrame`] for reading a data source. @@ -827,6 +873,25 @@ impl SessionContext { let table_paths = table_paths.to_urls()?; let session_config = self.copied_config(); let listing_options = options.to_listing_options(&session_config); + + let option_extension = listing_options.file_extension.clone(); + + if table_paths.is_empty() { + return exec_err!("No table paths were provided"); + } + + // check if the file extension matches the expected extension + for path in &table_paths { + let file_path = path.as_str(); + if !file_path.ends_with(option_extension.clone().as_str()) + && !path.is_collection() + { + return exec_err!( + "File path '{file_path}' does not match the expected extension '{option_extension}'" + ); + } + } + let resolved_schema = options .get_resolved_schema(&session_config, self.state(), table_paths[0].clone()) .await?; @@ -837,34 +902,6 @@ impl SessionContext { self.read_table(Arc::new(provider)) } - /// Creates a [`DataFrame`] for reading an Avro data source. - /// - /// For more control such as reading multiple files, you can use - /// [`read_table`](Self::read_table) with a [`ListingTable`]. - /// - /// For an example, see [`read_csv`](Self::read_csv) - pub async fn read_avro( - &self, - table_paths: P, - options: AvroReadOptions<'_>, - ) -> Result { - self._read_type(table_paths, options).await - } - - /// Creates a [`DataFrame`] for reading an JSON data source. - /// - /// For more control such as reading multiple files, you can use - /// [`read_table`](Self::read_table) with a [`ListingTable`]. - /// - /// For an example, see [`read_csv`](Self::read_csv) - pub async fn read_json( - &self, - table_paths: P, - options: NdJsonReadOptions<'_>, - ) -> Result { - self._read_type(table_paths, options).await - } - /// Creates a [`DataFrame`] for reading an Arrow data source. /// /// For more control such as reading multiple files, you can use @@ -887,48 +924,6 @@ impl SessionContext { )) } - /// Creates a [`DataFrame`] for reading a CSV data source. - /// - /// For more control such as reading multiple files, you can use - /// [`read_table`](Self::read_table) with a [`ListingTable`]. - /// - /// Example usage is given below: - /// - /// ``` - /// use datafusion::prelude::*; - /// # use datafusion::error::Result; - /// # #[tokio::main] - /// # async fn main() -> Result<()> { - /// let ctx = SessionContext::new(); - /// // You can read a single file using `read_csv` - /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; - /// // you can also read multiple files: - /// let df = ctx.read_csv(vec!["tests/data/example.csv", "tests/data/example.csv"], CsvReadOptions::new()).await?; - /// # Ok(()) - /// # } - /// ``` - pub async fn read_csv( - &self, - table_paths: P, - options: CsvReadOptions<'_>, - ) -> Result { - self._read_type(table_paths, options).await - } - - /// Creates a [`DataFrame`] for reading a Parquet data source. - /// - /// For more control such as reading multiple files, you can use - /// [`read_table`](Self::read_table) with a [`ListingTable`]. - /// - /// For an example, see [`read_csv`](Self::read_csv) - pub async fn read_parquet( - &self, - table_paths: P, - options: ParquetReadOptions<'_>, - ) -> Result { - self._read_type(table_paths, options).await - } - /// Creates a [`DataFrame`] for a [`TableProvider`] such as a /// [`ListingTable`] or a custom user defined provider. pub fn read_table(&self, provider: Arc) -> Result { @@ -969,14 +964,9 @@ impl SessionContext { sql_definition: Option, ) -> Result<()> { let table_path = ListingTableUrl::parse(table_path)?; - let resolved_schema = match (provided_schema, options.infinite_source) { - (Some(s), _) => s, - (None, false) => options.infer_schema(&self.state(), &table_path).await?, - (None, true) => { - return plan_err!( - "Schema inference for infinite data sources is not supported." - ) - } + let resolved_schema = match provided_schema { + Some(s) => s, + None => options.infer_schema(&self.state(), &table_path).await?, }; let config = ListingTableConfig::new(table_path) .with_listing_options(options) @@ -989,91 +979,6 @@ impl SessionContext { Ok(()) } - /// Registers a CSV file as a table which can referenced from SQL - /// statements executed against this context. - pub async fn register_csv( - &self, - name: &str, - table_path: &str, - options: CsvReadOptions<'_>, - ) -> Result<()> { - let listing_options = options.to_listing_options(&self.copied_config()); - - self.register_listing_table( - name, - table_path, - listing_options, - options.schema.map(|s| Arc::new(s.to_owned())), - None, - ) - .await?; - - Ok(()) - } - - /// Registers a JSON file as a table that it can be referenced - /// from SQL statements executed against this context. - pub async fn register_json( - &self, - name: &str, - table_path: &str, - options: NdJsonReadOptions<'_>, - ) -> Result<()> { - let listing_options = options.to_listing_options(&self.copied_config()); - - self.register_listing_table( - name, - table_path, - listing_options, - options.schema.map(|s| Arc::new(s.to_owned())), - None, - ) - .await?; - Ok(()) - } - - /// Registers a Parquet file as a table that can be referenced from SQL - /// statements executed against this context. - pub async fn register_parquet( - &self, - name: &str, - table_path: &str, - options: ParquetReadOptions<'_>, - ) -> Result<()> { - let listing_options = options.to_listing_options(&self.state.read().config); - - self.register_listing_table( - name, - table_path, - listing_options, - options.schema.map(|s| Arc::new(s.to_owned())), - None, - ) - .await?; - Ok(()) - } - - /// Registers an Avro file as a table that can be referenced from - /// SQL statements executed against this context. - pub async fn register_avro( - &self, - name: &str, - table_path: &str, - options: AvroReadOptions<'_>, - ) -> Result<()> { - let listing_options = options.to_listing_options(&self.copied_config()); - - self.register_listing_table( - name, - table_path, - listing_options, - options.schema.map(|s| Arc::new(s.to_owned())), - None, - ) - .await?; - Ok(()) - } - /// Registers an Arrow file as a table that can be referenced from /// SQL statements executed against this context. pub async fn register_arrow( @@ -1249,34 +1154,6 @@ impl SessionContext { self.state().create_physical_plan(logical_plan).await } - /// Executes a query and writes the results to a partitioned CSV file. - pub async fn write_csv( - &self, - plan: Arc, - path: impl AsRef, - ) -> Result<()> { - plan_to_csv(self.task_ctx(), plan, path).await - } - - /// Executes a query and writes the results to a partitioned JSON file. - pub async fn write_json( - &self, - plan: Arc, - path: impl AsRef, - ) -> Result<()> { - plan_to_json(self.task_ctx(), plan, path).await - } - - /// Executes a query and writes the results to a partitioned Parquet file. - pub async fn write_parquet( - &self, - plan: Arc, - path: impl AsRef, - writer_properties: Option, - ) -> Result<()> { - plan_to_parquet(self.task_ctx(), plan, path, writer_properties).await - } - /// Get a new TaskContext to run in this session pub fn task_ctx(&self) -> Arc { Arc::new(TaskContext::from(self)) @@ -1368,6 +1245,8 @@ pub struct SessionState { query_planner: Arc, /// Collection of catalogs containing schemas and ultimately TableProviders catalog_list: Arc, + /// Table Functions + table_functions: HashMap>, /// Scalar functions that are registered with the context scalar_functions: HashMap>, /// Aggregate functions registered in the context @@ -1401,25 +1280,24 @@ impl Debug for SessionState { } } -/// Default session builder using the provided configuration -#[deprecated( - since = "23.0.0", - note = "See SessionContext::with_config() or SessionState::with_config_rt" -)] -pub fn default_session_builder(config: SessionConfig) -> SessionState { - SessionState::with_config_rt(config, Arc::new(RuntimeEnv::default())) -} - impl SessionState { /// Returns new [`SessionState`] using the provided /// [`SessionConfig`] and [`RuntimeEnv`]. - pub fn with_config_rt(config: SessionConfig, runtime: Arc) -> Self { + pub fn new_with_config_rt(config: SessionConfig, runtime: Arc) -> Self { let catalog_list = Arc::new(MemoryCatalogList::new()) as Arc; - Self::with_config_rt_and_catalog_list(config, runtime, catalog_list) + Self::new_with_config_rt_and_catalog_list(config, runtime, catalog_list) } - /// Returns new SessionState using the provided configuration, runtime and catalog list. - pub fn with_config_rt_and_catalog_list( + /// Returns new [`SessionState`] using the provided + /// [`SessionConfig`] and [`RuntimeEnv`]. + #[deprecated(since = "32.0.0", note = "Use SessionState::new_with_config_rt")] + pub fn with_config_rt(config: SessionConfig, runtime: Arc) -> Self { + Self::new_with_config_rt(config, runtime) + } + + /// Returns new [`SessionState`] using the provided + /// [`SessionConfig`], [`RuntimeEnv`], and [`CatalogList`] + pub fn new_with_config_rt_and_catalog_list( config: SessionConfig, runtime: Arc, catalog_list: Arc, @@ -1429,12 +1307,13 @@ impl SessionState { // Create table_factories for all default formats let mut table_factories: HashMap> = HashMap::new(); - table_factories.insert("PARQUET".into(), Arc::new(ListingTableFactory::new())); - table_factories.insert("CSV".into(), Arc::new(ListingTableFactory::new())); - table_factories.insert("JSON".into(), Arc::new(ListingTableFactory::new())); - table_factories.insert("NDJSON".into(), Arc::new(ListingTableFactory::new())); - table_factories.insert("AVRO".into(), Arc::new(ListingTableFactory::new())); - table_factories.insert("ARROW".into(), Arc::new(ListingTableFactory::new())); + #[cfg(feature = "parquet")] + table_factories.insert("PARQUET".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("CSV".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("JSON".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("NDJSON".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("AVRO".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("ARROW".into(), Arc::new(DefaultTableFactory::new())); if config.create_default_catalog_and_schema() { let default_catalog = MemoryCatalogProvider::new(); @@ -1466,6 +1345,7 @@ impl SessionState { physical_optimizers: PhysicalOptimizer::new(), query_planner: Arc::new(DefaultQueryPlanner {}), catalog_list, + table_functions: HashMap::new(), scalar_functions: HashMap::new(), aggregate_functions: HashMap::new(), window_functions: HashMap::new(), @@ -1476,7 +1356,19 @@ impl SessionState { table_factories, } } - + /// Returns new [`SessionState`] using the provided + /// [`SessionConfig`] and [`RuntimeEnv`]. + #[deprecated( + since = "32.0.0", + note = "Use SessionState::new_with_config_rt_and_catalog_list" + )] + pub fn with_config_rt_and_catalog_list( + config: SessionConfig, + runtime: Arc, + catalog_list: Arc, + ) -> Self { + Self::new_with_config_rt_and_catalog_list(config, runtime, catalog_list) + } fn register_default_schema( config: &SessionConfig, table_factories: &HashMap>, @@ -1547,17 +1439,14 @@ impl SessionState { self.catalog_list .catalog(&resolved_ref.catalog) .ok_or_else(|| { - DataFusionError::Plan(format!( + plan_datafusion_err!( "failed to resolve catalog: {}", resolved_ref.catalog - )) + ) })? .schema(&resolved_ref.schema) .ok_or_else(|| { - DataFusionError::Plan(format!( - "failed to resolve schema: {}", - resolved_ref.schema - )) + plan_datafusion_err!("failed to resolve schema: {}", resolved_ref.schema) }) } @@ -1659,11 +1548,11 @@ impl SessionState { dialect: &str, ) -> Result { let dialect = dialect_from_str(dialect).ok_or_else(|| { - DataFusionError::Plan(format!( + plan_datafusion_err!( "Unsupported SQL dialect: {dialect}. Available dialects: \ Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \ MsSQL, ClickHouse, BigQuery, Ansi." - )) + ) })?; let mut statements = DFParser::parse_sql_with_dialect(sql, dialect.as_ref())?; if statements.len() > 1 { @@ -1732,9 +1621,6 @@ impl SessionState { .0 .insert(ObjectName(vec![Ident::from(table.name.as_str())])); } - DFStatement::DescribeTableStmt(table) => { - visitor.insert(&table.table_name) - } DFStatement::CopyTo(CopyToStatement { source, target: _, @@ -1833,7 +1719,7 @@ impl SessionState { let mut stringified_plans = e.stringified_plans.clone(); // analyze & capture output of each rule - let analyzed_plan = match self.analyzer.execute_and_check( + let analyzer_result = self.analyzer.execute_and_check( e.plan.as_ref(), self.options(), |analyzed_plan, analyzer| { @@ -1841,7 +1727,8 @@ impl SessionState { let plan_type = PlanType::AnalyzedLogicalPlan { analyzer_name }; stringified_plans.push(analyzed_plan.to_stringified(plan_type)); }, - ) { + ); + let analyzed_plan = match analyzer_result { Ok(plan) => plan, Err(DataFusionError::Context(analyzer_name, err)) => { let plan_type = PlanType::AnalyzedLogicalPlan { analyzer_name }; @@ -1864,7 +1751,7 @@ impl SessionState { .push(analyzed_plan.to_stringified(PlanType::FinalAnalyzedLogicalPlan)); // optimize the child plan, capturing the output of each optimizer - let (plan, logical_optimization_succeeded) = match self.optimizer.optimize( + let optimized_plan = self.optimizer.optimize( &analyzed_plan, self, |optimized_plan, optimizer| { @@ -1872,7 +1759,8 @@ impl SessionState { let plan_type = PlanType::OptimizedLogicalPlan { optimizer_name }; stringified_plans.push(optimized_plan.to_stringified(plan_type)); }, - ) { + ); + let (plan, logical_optimization_succeeded) = match optimized_plan { Ok(plan) => (Arc::new(plan), true), Err(DataFusionError::Context(optimizer_name, err)) => { let plan_type = PlanType::OptimizedLogicalPlan { optimizer_name }; @@ -1987,12 +1875,28 @@ struct SessionContextProvider<'a> { } impl<'a> ContextProvider for SessionContextProvider<'a> { - fn get_table_provider(&self, name: TableReference) -> Result> { + fn get_table_source(&self, name: TableReference) -> Result> { let name = self.state.resolve_table_ref(name).to_string(); self.tables .get(&name) .cloned() - .ok_or_else(|| DataFusionError::Plan(format!("table '{name}' not found"))) + .ok_or_else(|| plan_datafusion_err!("table '{name}' not found")) + } + + fn get_table_function_source( + &self, + name: &str, + args: Vec, + ) -> Result> { + let tbl_func = self + .state + .table_functions + .get(name) + .cloned() + .ok_or_else(|| plan_datafusion_err!("table function '{name}' not found"))?; + let provider = tbl_func.create_table_provider(&args)?; + + Ok(provider_as_source(provider)) } fn get_function_meta(&self, name: &str) -> Option> { @@ -2039,9 +1943,7 @@ impl FunctionRegistry for SessionState { let result = self.scalar_functions.get(name); result.cloned().ok_or_else(|| { - DataFusionError::Plan(format!( - "There is no UDF named \"{name}\" in the registry" - )) + plan_datafusion_err!("There is no UDF named \"{name}\" in the registry") }) } @@ -2049,9 +1951,7 @@ impl FunctionRegistry for SessionState { let result = self.aggregate_functions.get(name); result.cloned().ok_or_else(|| { - DataFusionError::Plan(format!( - "There is no UDAF named \"{name}\" in the registry" - )) + plan_datafusion_err!("There is no UDAF named \"{name}\" in the registry") }) } @@ -2059,9 +1959,7 @@ impl FunctionRegistry for SessionState { let result = self.window_functions.get(name); result.cloned().ok_or_else(|| { - DataFusionError::Plan(format!( - "There is no UDWF named \"{name}\" in the registry" - )) + plan_datafusion_err!("There is no UDWF named \"{name}\" in the registry") }) } } @@ -2217,25 +2115,21 @@ impl<'a> TreeNodeVisitor for BadPlanVisitor<'a> { #[cfg(test)] mod tests { + use super::super::options::CsvReadOptions; use super::*; use crate::assert_batches_eq; use crate::execution::context::QueryPlanner; use crate::execution::memory_pool::MemoryConsumer; use crate::execution::runtime_env::RuntimeConfig; - use crate::physical_plan::expressions::AvgAccumulator; use crate::test; - use crate::test_util::parquet_test_data; + use crate::test_util::{plan_and_collect, populate_csv_partitions}; use crate::variable::VarType; - use arrow::array::ArrayRef; - use arrow::record_batch::RecordBatch; - use arrow_schema::{Field, Schema}; + use arrow_schema::Schema; use async_trait::async_trait; - use datafusion_expr::{create_udaf, create_udf, Expr, Volatility}; - use datafusion_physical_expr::functions::make_scalar_function; - use std::fs::File; + use datafusion_expr::Expr; + use std::env; use std::path::PathBuf; use std::sync::Weak; - use std::{env, io::prelude::*}; use tempfile::TempDir; #[tokio::test] @@ -2253,7 +2147,7 @@ mod tests { let disk_manager = ctx1.runtime_env().disk_manager.clone(); let ctx2 = - SessionContext::with_config_rt(SessionConfig::new(), ctx1.runtime_env()); + SessionContext::new_with_config_rt(SessionConfig::new(), ctx1.runtime_env()); assert_eq!(ctx1.runtime_env().memory_pool.reserved(), 100); assert_eq!(ctx2.runtime_env().memory_pool.reserved(), 100); @@ -2330,120 +2224,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> { - let ctx = SessionContext::new(); - ctx.register_table("t", test::table_with_sequence(1, 1).unwrap()) - .unwrap(); - - let myfunc = |args: &[ArrayRef]| Ok(Arc::clone(&args[0])); - let myfunc = make_scalar_function(myfunc); - - ctx.register_udf(create_udf( - "MY_FUNC", - vec![DataType::Int32], - Arc::new(DataType::Int32), - Volatility::Immutable, - myfunc, - )); - - // doesn't work as it was registered with non lowercase - let err = plan_and_collect(&ctx, "SELECT MY_FUNC(i) FROM t") - .await - .unwrap_err(); - assert!(err - .to_string() - .contains("Error during planning: Invalid function \'my_func\'")); - - // Can call it if you put quotes - let result = plan_and_collect(&ctx, "SELECT \"MY_FUNC\"(i) FROM t").await?; - - let expected = [ - "+--------------+", - "| MY_FUNC(t.i) |", - "+--------------+", - "| 1 |", - "+--------------+", - ]; - assert_batches_eq!(expected, &result); - - Ok(()) - } - - #[tokio::test] - async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> { - let ctx = SessionContext::new(); - ctx.register_table("t", test::table_with_sequence(1, 1).unwrap()) - .unwrap(); - - // Note capitalization - let my_avg = create_udaf( - "MY_AVG", - vec![DataType::Float64], - Arc::new(DataType::Float64), - Volatility::Immutable, - Arc::new(|_| Ok(Box::::default())), - Arc::new(vec![DataType::UInt64, DataType::Float64]), - ); - - ctx.register_udaf(my_avg); - - // doesn't work as it was registered as non lowercase - let err = plan_and_collect(&ctx, "SELECT MY_AVG(i) FROM t") - .await - .unwrap_err(); - assert!(err - .to_string() - .contains("Error during planning: Invalid function \'my_avg\'")); - - // Can call it if you put quotes - let result = plan_and_collect(&ctx, "SELECT \"MY_AVG\"(i) FROM t").await?; - - let expected = [ - "+-------------+", - "| MY_AVG(t.i) |", - "+-------------+", - "| 1.0 |", - "+-------------+", - ]; - assert_batches_eq!(expected, &result); - - Ok(()) - } - - #[tokio::test] - async fn query_csv_with_custom_partition_extension() -> Result<()> { - let tmp_dir = TempDir::new()?; - - // The main stipulation of this test: use a file extension that isn't .csv. - let file_extension = ".tst"; - - let ctx = SessionContext::new(); - let schema = populate_csv_partitions(&tmp_dir, 2, file_extension)?; - ctx.register_csv( - "test", - tmp_dir.path().to_str().unwrap(), - CsvReadOptions::new() - .schema(&schema) - .file_extension(file_extension), - ) - .await?; - let results = - plan_and_collect(&ctx, "SELECT SUM(c1), SUM(c2), COUNT(*) FROM test").await?; - - assert_eq!(results.len(), 1); - let expected = [ - "+--------------+--------------+----------+", - "| SUM(test.c1) | SUM(test.c2) | COUNT(*) |", - "+--------------+--------------+----------+", - "| 10 | 110 | 20 |", - "+--------------+--------------+----------+", - ]; - assert_batches_eq!(expected, &results); - - Ok(()) - } - #[tokio::test] async fn send_context_to_threads() -> Result<()> { // ensure SessionContexts can be used in a multi-threaded @@ -2481,8 +2261,8 @@ mod tests { .set_str("datafusion.catalog.location", url.as_str()) .set_str("datafusion.catalog.format", "CSV") .set_str("datafusion.catalog.has_header", "true"); - let session_state = SessionState::with_config_rt(cfg, runtime); - let ctx = SessionContext::with_state(session_state); + let session_state = SessionState::new_with_config_rt(cfg, runtime); + let ctx = SessionContext::new_with_state(session_state); ctx.refresh_catalogs().await?; let result = @@ -2507,9 +2287,10 @@ mod tests { #[tokio::test] async fn custom_query_planner() -> Result<()> { let runtime = Arc::new(RuntimeEnv::default()); - let session_state = SessionState::with_config_rt(SessionConfig::new(), runtime) - .with_query_planner(Arc::new(MyQueryPlanner {})); - let ctx = SessionContext::with_state(session_state); + let session_state = + SessionState::new_with_config_rt(SessionConfig::new(), runtime) + .with_query_planner(Arc::new(MyQueryPlanner {})); + let ctx = SessionContext::new_with_state(session_state); let df = ctx.sql("SELECT 1").await?; df.collect().await.expect_err("query not supported"); @@ -2518,7 +2299,7 @@ mod tests { #[tokio::test] async fn disabled_default_catalog_and_schema() -> Result<()> { - let ctx = SessionContext::with_config( + let ctx = SessionContext::new_with_config( SessionConfig::new().with_create_default_catalog_and_schema(false), ); @@ -2561,7 +2342,7 @@ mod tests { } async fn catalog_and_schema_test(config: SessionConfig) { - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); let catalog = MemoryCatalogProvider::new(); let schema = MemorySchemaProvider::new(); schema @@ -2638,7 +2419,7 @@ mod tests { #[tokio::test] async fn catalogs_not_leaked() { // the information schema used to introduce cyclic Arcs - let ctx = SessionContext::with_config( + let ctx = SessionContext::new_with_config( SessionConfig::new().with_information_schema(true), ); @@ -2661,7 +2442,7 @@ mod tests { #[tokio::test] async fn sql_create_schema() -> Result<()> { // the information schema used to introduce cyclic Arcs - let ctx = SessionContext::with_config( + let ctx = SessionContext::new_with_config( SessionConfig::new().with_information_schema(true), ); @@ -2684,7 +2465,7 @@ mod tests { #[tokio::test] async fn sql_create_catalog() -> Result<()> { // the information schema used to introduce cyclic Arcs - let ctx = SessionContext::with_config( + let ctx = SessionContext::new_with_config( SessionConfig::new().with_information_schema(true), ); @@ -2707,60 +2488,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn read_with_glob_path() -> Result<()> { - let ctx = SessionContext::new(); - - let df = ctx - .read_parquet( - format!("{}/alltypes_plain*.parquet", parquet_test_data()), - ParquetReadOptions::default(), - ) - .await?; - let results = df.collect().await?; - let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum(); - // alltypes_plain.parquet = 8 rows, alltypes_plain.snappy.parquet = 2 rows, alltypes_dictionary.parquet = 2 rows - assert_eq!(total_rows, 10); - Ok(()) - } - - #[tokio::test] - async fn read_with_glob_path_issue_2465() -> Result<()> { - let ctx = SessionContext::new(); - - let df = ctx - .read_parquet( - // it was reported that when a path contains // (two consecutive separator) no files were found - // in this test, regardless of parquet_test_data() value, our path now contains a // - format!("{}/..//*/alltypes_plain*.parquet", parquet_test_data()), - ParquetReadOptions::default(), - ) - .await?; - let results = df.collect().await?; - let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum(); - // alltypes_plain.parquet = 8 rows, alltypes_plain.snappy.parquet = 2 rows, alltypes_dictionary.parquet = 2 rows - assert_eq!(total_rows, 10); - Ok(()) - } - - #[tokio::test] - async fn read_from_registered_table_with_glob_path() -> Result<()> { - let ctx = SessionContext::new(); - - ctx.register_parquet( - "test", - &format!("{}/alltypes_plain*.parquet", parquet_test_data()), - ParquetReadOptions::default(), - ) - .await?; - let df = ctx.sql("SELECT * FROM test").await?; - let results = df.collect().await?; - let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum(); - // alltypes_plain.parquet = 8 rows, alltypes_plain.snappy.parquet = 2 rows, alltypes_dictionary.parquet = 2 rows - assert_eq!(total_rows, 10); - Ok(()) - } - struct MyPhysicalPlanner {} #[async_trait] @@ -2800,50 +2527,14 @@ mod tests { } } - /// Execute SQL and return results - async fn plan_and_collect( - ctx: &SessionContext, - sql: &str, - ) -> Result> { - ctx.sql(sql).await?.collect().await - } - - /// Generate CSV partitions within the supplied directory - fn populate_csv_partitions( - tmp_dir: &TempDir, - partition_count: usize, - file_extension: &str, - ) -> Result { - // define schema for data source (csv file) - let schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::UInt32, false), - Field::new("c2", DataType::UInt64, false), - Field::new("c3", DataType::Boolean, false), - ])); - - // generate a partitioned file - for partition in 0..partition_count { - let filename = format!("partition-{partition}.{file_extension}"); - let file_path = tmp_dir.path().join(filename); - let mut file = File::create(file_path)?; - - // generate some data - for i in 0..=10 { - let data = format!("{},{},{}\n", partition, i, i % 2 == 0); - file.write_all(data.as_bytes())?; - } - } - - Ok(schema) - } - /// Generate a partitioned CSV file and register it with an execution context async fn create_ctx( tmp_dir: &TempDir, partition_count: usize, ) -> Result { - let ctx = - SessionContext::with_config(SessionConfig::new().with_target_partitions(8)); + let ctx = SessionContext::new_with_config( + SessionConfig::new().with_target_partitions(8), + ); let schema = populate_csv_partitions(tmp_dir, partition_count, ".csv")?; @@ -2857,37 +2548,4 @@ mod tests { Ok(ctx) } - - // Test for compilation error when calling read_* functions from an #[async_trait] function. - // See https://github.com/apache/arrow-datafusion/issues/1154 - #[async_trait] - trait CallReadTrait { - async fn call_read_csv(&self) -> DataFrame; - async fn call_read_avro(&self) -> DataFrame; - async fn call_read_parquet(&self) -> DataFrame; - } - - struct CallRead {} - - #[async_trait] - impl CallReadTrait for CallRead { - async fn call_read_csv(&self) -> DataFrame { - let ctx = SessionContext::new(); - ctx.read_csv("dummy", CsvReadOptions::new()).await.unwrap() - } - - async fn call_read_avro(&self) -> DataFrame { - let ctx = SessionContext::new(); - ctx.read_avro("dummy", AvroReadOptions::default()) - .await - .unwrap() - } - - async fn call_read_parquet(&self) -> DataFrame { - let ctx = SessionContext::new(); - ctx.read_parquet("dummy", ParquetReadOptions::default()) - .await - .unwrap() - } - } } diff --git a/datafusion/core/src/execution/context/parquet.rs b/datafusion/core/src/execution/context/parquet.rs new file mode 100644 index 000000000000..7825d9b88297 --- /dev/null +++ b/datafusion/core/src/execution/context/parquet.rs @@ -0,0 +1,364 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use crate::datasource::physical_plan::parquet::plan_to_parquet; +use parquet::file::properties::WriterProperties; + +use super::super::options::{ParquetReadOptions, ReadOptions}; +use super::{DataFilePaths, DataFrame, ExecutionPlan, Result, SessionContext}; + +impl SessionContext { + /// Creates a [`DataFrame`] for reading a Parquet data source. + /// + /// For more control such as reading multiple files, you can use + /// [`read_table`](Self::read_table) with a [`super::ListingTable`]. + /// + /// For an example, see [`read_csv`](Self::read_csv) + pub async fn read_parquet( + &self, + table_paths: P, + options: ParquetReadOptions<'_>, + ) -> Result { + self._read_type(table_paths, options).await + } + + /// Registers a Parquet file as a table that can be referenced from SQL + /// statements executed against this context. + pub async fn register_parquet( + &self, + name: &str, + table_path: &str, + options: ParquetReadOptions<'_>, + ) -> Result<()> { + let listing_options = options.to_listing_options(&self.state.read().config); + + self.register_listing_table( + name, + table_path, + listing_options, + options.schema.map(|s| Arc::new(s.to_owned())), + None, + ) + .await?; + Ok(()) + } + + /// Executes a query and writes the results to a partitioned Parquet file. + pub async fn write_parquet( + &self, + plan: Arc, + path: impl AsRef, + writer_properties: Option, + ) -> Result<()> { + plan_to_parquet(self.task_ctx(), plan, path, writer_properties).await + } +} + +#[cfg(test)] +mod tests { + use async_trait::async_trait; + + use crate::arrow::array::{Float32Array, Int32Array}; + use crate::arrow::datatypes::{DataType, Field, Schema}; + use crate::arrow::record_batch::RecordBatch; + use crate::dataframe::DataFrameWriteOptions; + use crate::parquet::basic::Compression; + use crate::test_util::parquet_test_data; + use datafusion_execution::config::SessionConfig; + use tempfile::tempdir; + + use super::*; + + #[tokio::test] + async fn read_with_glob_path() -> Result<()> { + let ctx = SessionContext::new(); + + let df = ctx + .read_parquet( + format!("{}/alltypes_plain*.parquet", parquet_test_data()), + ParquetReadOptions::default(), + ) + .await?; + let results = df.collect().await?; + let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum(); + // alltypes_plain.parquet = 8 rows, alltypes_plain.snappy.parquet = 2 rows, alltypes_dictionary.parquet = 2 rows + assert_eq!(total_rows, 10); + Ok(()) + } + + #[tokio::test] + async fn read_with_glob_path_issue_2465() -> Result<()> { + let config = + SessionConfig::from_string_hash_map(std::collections::HashMap::from([( + "datafusion.execution.listing_table_ignore_subdirectory".to_owned(), + "false".to_owned(), + )]))?; + let ctx = SessionContext::new_with_config(config); + let df = ctx + .read_parquet( + // it was reported that when a path contains // (two consecutive separator) no files were found + // in this test, regardless of parquet_test_data() value, our path now contains a // + format!("{}/..//*/alltypes_plain*.parquet", parquet_test_data()), + ParquetReadOptions::default(), + ) + .await?; + let results = df.collect().await?; + let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum(); + // alltypes_plain.parquet = 8 rows, alltypes_plain.snappy.parquet = 2 rows, alltypes_dictionary.parquet = 2 rows + assert_eq!(total_rows, 10); + Ok(()) + } + + #[tokio::test] + async fn read_from_registered_table_with_glob_path() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_parquet( + "test", + &format!("{}/alltypes_plain*.parquet", parquet_test_data()), + ParquetReadOptions::default(), + ) + .await?; + let df = ctx.sql("SELECT * FROM test").await?; + let results = df.collect().await?; + let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum(); + // alltypes_plain.parquet = 8 rows, alltypes_plain.snappy.parquet = 2 rows, alltypes_dictionary.parquet = 2 rows + assert_eq!(total_rows, 10); + Ok(()) + } + + #[tokio::test] + async fn read_from_different_file_extension() -> Result<()> { + let ctx = SessionContext::new(); + let sep = std::path::MAIN_SEPARATOR.to_string(); + + // Make up a new dataframe. + let write_df = ctx.read_batch(RecordBatch::try_new( + Arc::new(Schema::new(vec![ + Field::new("purchase_id", DataType::Int32, false), + Field::new("price", DataType::Float32, false), + Field::new("quantity", DataType::Int32, false), + ])), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])), + Arc::new(Float32Array::from(vec![1.12, 3.40, 2.33, 9.10, 6.66])), + Arc::new(Int32Array::from(vec![1, 3, 2, 4, 3])), + ], + )?)?; + + let temp_dir = tempdir()?; + let temp_dir_path = temp_dir.path(); + let path1 = temp_dir_path + .join("output1.parquet") + .to_str() + .unwrap() + .to_string(); + let path2 = temp_dir_path + .join("output2.parquet.snappy") + .to_str() + .unwrap() + .to_string(); + let path3 = temp_dir_path + .join("output3.parquet.snappy.parquet") + .to_str() + .unwrap() + .to_string(); + + let path4 = temp_dir_path + .join("output4.parquet".to_owned() + &sep) + .to_str() + .unwrap() + .to_string(); + + let path5 = temp_dir_path + .join("bbb..bbb") + .join("filename.parquet") + .to_str() + .unwrap() + .to_string(); + let dir = temp_dir_path + .join("bbb..bbb".to_owned() + &sep) + .to_str() + .unwrap() + .to_string(); + std::fs::create_dir(dir).expect("create dir failed"); + + // Write the dataframe to a parquet file named 'output1.parquet' + write_df + .clone() + .write_parquet( + &path1, + DataFrameWriteOptions::new().with_single_file_output(true), + Some( + WriterProperties::builder() + .set_compression(Compression::SNAPPY) + .build(), + ), + ) + .await?; + + // Write the dataframe to a parquet file named 'output2.parquet.snappy' + write_df + .clone() + .write_parquet( + &path2, + DataFrameWriteOptions::new().with_single_file_output(true), + Some( + WriterProperties::builder() + .set_compression(Compression::SNAPPY) + .build(), + ), + ) + .await?; + + // Write the dataframe to a parquet file named 'output3.parquet.snappy.parquet' + write_df + .clone() + .write_parquet( + &path3, + DataFrameWriteOptions::new().with_single_file_output(true), + Some( + WriterProperties::builder() + .set_compression(Compression::SNAPPY) + .build(), + ), + ) + .await?; + + // Write the dataframe to a parquet file named 'bbb..bbb/filename.parquet' + write_df + .write_parquet( + &path5, + DataFrameWriteOptions::new().with_single_file_output(true), + Some( + WriterProperties::builder() + .set_compression(Compression::SNAPPY) + .build(), + ), + ) + .await?; + + // Read the dataframe from 'output1.parquet' with the default file extension. + let read_df = ctx + .read_parquet( + &path1, + ParquetReadOptions { + ..Default::default() + }, + ) + .await?; + + let results = read_df.collect().await?; + let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum(); + assert_eq!(total_rows, 5); + + // Read the dataframe from 'output2.parquet.snappy' with the correct file extension. + let read_df = ctx + .read_parquet( + &path2, + ParquetReadOptions { + file_extension: "snappy", + ..Default::default() + }, + ) + .await?; + let results = read_df.collect().await?; + let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum(); + assert_eq!(total_rows, 5); + + // Read the dataframe from 'output3.parquet.snappy.parquet' with the wrong file extension. + let read_df = ctx + .read_parquet( + &path2, + ParquetReadOptions { + ..Default::default() + }, + ) + .await; + let binding = DataFilePaths::to_urls(&path2).unwrap(); + let expexted_path = binding[0].as_str(); + assert_eq!( + read_df.unwrap_err().strip_backtrace(), + format!("Execution error: File path '{}' does not match the expected extension '.parquet'", expexted_path) + ); + + // Read the dataframe from 'output3.parquet.snappy.parquet' with the correct file extension. + let read_df = ctx + .read_parquet( + &path3, + ParquetReadOptions { + ..Default::default() + }, + ) + .await?; + + let results = read_df.collect().await?; + let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum(); + assert_eq!(total_rows, 5); + + // Read the dataframe from 'output4/' + std::fs::create_dir(&path4)?; + let read_df = ctx + .read_parquet( + &path4, + ParquetReadOptions { + ..Default::default() + }, + ) + .await?; + + let results = read_df.collect().await?; + let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum(); + assert_eq!(total_rows, 0); + + // Read the datafram from doule dot folder; + let read_df = ctx + .read_parquet( + &path5, + ParquetReadOptions { + ..Default::default() + }, + ) + .await?; + + let results = read_df.collect().await?; + let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum(); + assert_eq!(total_rows, 5); + Ok(()) + } + + // Test for compilation error when calling read_* functions from an #[async_trait] function. + // See https://github.com/apache/arrow-datafusion/issues/1154 + #[async_trait] + trait CallReadTrait { + async fn call_read_parquet(&self) -> DataFrame; + } + + struct CallRead {} + + #[async_trait] + impl CallReadTrait for CallRead { + async fn call_read_parquet(&self) -> DataFrame { + let ctx = SessionContext::new(); + ctx.read_parquet("dummy", ParquetReadOptions::default()) + .await + .unwrap() + } + } +} diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index 576f66a5ed7c..8fc724a22443 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -132,18 +132,19 @@ //! //! ## Customization and Extension //! -//! DataFusion is designed to be a "disaggregated" query engine. This -//! means that developers can mix and extend the parts of DataFusion -//! they need for their usecase. For example, just the -//! [`ExecutionPlan`] operators, or the [`SqlToRel`] SQL planner and -//! optimizer. +//! DataFusion is a "disaggregated" query engine. This +//! means developers can start with a working, full featured engine, and then +//! extend the parts of DataFusion they need to specialize for their usecase. For example, +//! some projects may add custom [`ExecutionPlan`] operators, or create their own +//! query language that directly creates [`LogicalPlan`] rather than using the +//! built in SQL planner, [`SqlToRel`]. //! //! In order to achieve this, DataFusion supports extension at many points: //! //! * read from any datasource ([`TableProvider`]) //! * define your own catalogs, schemas, and table lists ([`CatalogProvider`]) -//! * build your own query langue or plans using the ([`LogicalPlanBuilder`]) -//! * declare and use user-defined functions: ([`ScalarUDF`], and [`AggregateUDF`]) +//! * build your own query language or plans ([`LogicalPlanBuilder`]) +//! * declare and use user-defined functions ([`ScalarUDF`], and [`AggregateUDF`], [`WindowUDF`]) //! * add custom optimizer rewrite passes ([`OptimizerRule`] and [`PhysicalOptimizerRule`]) //! * extend the planner to use user-defined logical and physical nodes ([`QueryPlanner`]) //! @@ -152,8 +153,9 @@ //! [`TableProvider`]: crate::datasource::TableProvider //! [`CatalogProvider`]: crate::catalog::CatalogProvider //! [`LogicalPlanBuilder`]: datafusion_expr::logical_plan::builder::LogicalPlanBuilder -//! [`ScalarUDF`]: physical_plan::udf::ScalarUDF -//! [`AggregateUDF`]: physical_plan::udaf::AggregateUDF +//! [`ScalarUDF`]: crate::logical_expr::ScalarUDF +//! [`AggregateUDF`]: crate::logical_expr::AggregateUDF +//! [`WindowUDF`]: crate::logical_expr::WindowUDF //! [`QueryPlanner`]: execution::context::QueryPlanner //! [`OptimizerRule`]: datafusion_optimizer::optimizer::OptimizerRule //! [`PhysicalOptimizerRule`]: crate::physical_optimizer::optimizer::PhysicalOptimizerRule @@ -279,33 +281,42 @@ //! [`MemTable`]: crate::datasource::memory::MemTable //! [`StreamingTable`]: crate::datasource::streaming::StreamingTable //! -//! ## Plans +//! ## Plan Representations //! -//! Logical planning yields [`LogicalPlan`]s nodes and [`Expr`] -//! expressions which are [`Schema`] aware and represent statements +//! ### Logical Plans +//! Logical planning yields [`LogicalPlan`] nodes and [`Expr`] +//! representing expressions which are [`Schema`] aware and represent statements //! independent of how they are physically executed. //! A [`LogicalPlan`] is a Directed Acyclic Graph (DAG) of other //! [`LogicalPlan`]s, each potentially containing embedded [`Expr`]s. //! -//! An [`ExecutionPlan`] (sometimes referred to as a "physical plan") -//! is a plan that can be executed against data. It a DAG of other -//! [`ExecutionPlan`]s each potentially containing expressions of the -//! following types: +//! [`Expr`]s can be rewritten using the [`TreeNode`] API and simplified using +//! [`ExprSimplifier`]. Examples of working with and executing `Expr`s can be found in the +//! [`expr_api`.rs] example //! -//! 1. [`PhysicalExpr`]: Scalar functions +//! [`TreeNode`]: datafusion_common::tree_node::TreeNode +//! [`ExprSimplifier`]: crate::optimizer::simplify_expressions::ExprSimplifier +//! [`expr_api`.rs]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/expr_api.rs //! -//! 2. [`AggregateExpr`]: Aggregate functions +//! ### Physical Plans //! -//! 2. [`WindowExpr`]: Window functions +//! An [`ExecutionPlan`] (sometimes referred to as a "physical plan") +//! is a plan that can be executed against data. It a DAG of other +//! [`ExecutionPlan`]s each potentially containing expressions that implement the +//! [`PhysicalExpr`] trait. //! -//! Compared to a [`LogicalPlan`], an [`ExecutionPlan`] has concrete +//! Compared to a [`LogicalPlan`], an [`ExecutionPlan`] has additional concrete //! information about how to perform calculations (e.g. hash vs merge //! join), and how data flows during execution (e.g. partitioning and //! sortedness). //! +//! [cp_solver] performs range propagation analysis on [`PhysicalExpr`]s and +//! [`PruningPredicate`] can prove certain boolean [`PhysicalExpr`]s used for +//! filtering can never be `true` using additional statistical information. +//! +//! [cp_solver]: crate::physical_expr::intervals::cp_solver +//! [`PruningPredicate`]: crate::physical_optimizer::pruning::PruningPredicate //! [`PhysicalExpr`]: crate::physical_plan::PhysicalExpr -//! [`AggregateExpr`]: crate::physical_plan::AggregateExpr -//! [`WindowExpr`]: crate::physical_plan::WindowExpr //! //! ## Execution //! @@ -435,6 +446,7 @@ pub mod variable; // re-export dependencies from arrow-rs to minimize version maintenance for crate users pub use arrow; +#[cfg(feature = "parquet")] pub use parquet; // re-export DataFusion sub-crates at the top level. Use `pub use *` diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index 396e66972f30..86a8cdb7b3d4 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -18,27 +18,25 @@ //! Utilizing exact statistics from sources to avoid scanning data use std::sync::Arc; +use super::optimizer::PhysicalOptimizerRule; use crate::config::ConfigOptions; -use datafusion_common::tree_node::TreeNode; -use datafusion_expr::utils::COUNT_STAR_EXPANSION; - -use crate::physical_plan::aggregates::{AggregateExec, AggregateMode}; -use crate::physical_plan::empty::EmptyExec; +use crate::error::Result; +use crate::physical_plan::aggregates::AggregateExec; use crate::physical_plan::projection::ProjectionExec; -use crate::physical_plan::{ - expressions, AggregateExpr, ColumnStatistics, ExecutionPlan, Statistics, -}; +use crate::physical_plan::{expressions, AggregateExpr, ExecutionPlan, Statistics}; use crate::scalar::ScalarValue; -use super::optimizer::PhysicalOptimizerRule; -use crate::error::Result; +use datafusion_common::stats::Precision; +use datafusion_common::tree_node::TreeNode; +use datafusion_expr::utils::COUNT_STAR_EXPANSION; +use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; /// Optimizer that uses available statistics for aggregate functions #[derive(Default)] pub struct AggregateStatistics {} /// The name of the column corresponding to [`COUNT_STAR_EXPANSION`] -const COUNT_STAR_NAME: &str = "COUNT(UInt8(1))"; +const COUNT_STAR_NAME: &str = "COUNT(*)"; impl AggregateStatistics { #[allow(missing_docs)] @@ -58,7 +56,7 @@ impl PhysicalOptimizerRule for AggregateStatistics { .as_any() .downcast_ref::() .expect("take_optimizable() ensures that this is a AggregateExec"); - let stats = partial_agg_exec.input().statistics(); + let stats = partial_agg_exec.input().statistics()?; let mut projections = vec![]; for expr in partial_agg_exec.aggr_expr() { if let Some((non_null_rows, name)) = @@ -84,7 +82,7 @@ impl PhysicalOptimizerRule for AggregateStatistics { // input can be entirely removed Ok(Arc::new(ProjectionExec::try_new( projections, - Arc::new(EmptyExec::new(true, plan.schema())), + Arc::new(PlaceholderRowExec::new(plan.schema())), )?)) } else { plan.map_children(|child| self.optimize(child, _config)) @@ -107,13 +105,12 @@ impl PhysicalOptimizerRule for AggregateStatistics { /// assert if the node passed as argument is a final `AggregateExec` node that can be optimized: /// - its child (with possible intermediate layers) is a partial `AggregateExec` node /// - they both have no grouping expression -/// - the statistics are exact /// If this is the case, return a ref to the partial `AggregateExec`, else `None`. /// We would have preferred to return a casted ref to AggregateExec but the recursion requires /// the `ExecutionPlan.children()` method that returns an owned reference. fn take_optimizable(node: &dyn ExecutionPlan) -> Option> { if let Some(final_agg_exec) = node.as_any().downcast_ref::() { - if final_agg_exec.mode() == &AggregateMode::Final + if !final_agg_exec.mode().is_first_stage() && final_agg_exec.group_expr().is_empty() { let mut child = Arc::clone(final_agg_exec.input()); @@ -121,14 +118,11 @@ fn take_optimizable(node: &dyn ExecutionPlan) -> Option> if let Some(partial_agg_exec) = child.as_any().downcast_ref::() { - if partial_agg_exec.mode() == &AggregateMode::Partial + if partial_agg_exec.mode().is_first_stage() && partial_agg_exec.group_expr().is_empty() && partial_agg_exec.filter_expr().iter().all(|e| e.is_none()) { - let stats = partial_agg_exec.input().statistics(); - if stats.is_exact { - return Some(child); - } + return Some(child); } } if let [ref childrens_child] = child.children().as_slice() { @@ -142,13 +136,13 @@ fn take_optimizable(node: &dyn ExecutionPlan) -> Option> None } -/// If this agg_expr is a count that is defined in the statistics, return it +/// If this agg_expr is a count that is exactly defined in the statistics, return it. fn take_optimizable_table_count( agg_expr: &dyn AggregateExpr, stats: &Statistics, ) -> Option<(ScalarValue, &'static str)> { - if let (Some(num_rows), Some(casted_expr)) = ( - stats.num_rows, + if let (&Precision::Exact(num_rows), Some(casted_expr)) = ( + &stats.num_rows, agg_expr.as_any().downcast_ref::(), ) { // TODO implementing Eq on PhysicalExpr would help a lot here @@ -169,14 +163,14 @@ fn take_optimizable_table_count( None } -/// If this agg_expr is a count that can be derived from the statistics, return it +/// If this agg_expr is a count that can be exactly derived from the statistics, return it. fn take_optimizable_column_count( agg_expr: &dyn AggregateExpr, stats: &Statistics, ) -> Option<(ScalarValue, String)> { - if let (Some(num_rows), Some(col_stats), Some(casted_expr)) = ( - stats.num_rows, - &stats.column_statistics, + let col_stats = &stats.column_statistics; + if let (&Precision::Exact(num_rows), Some(casted_expr)) = ( + &stats.num_rows, agg_expr.as_any().downcast_ref::(), ) { if casted_expr.expressions().len() == 1 { @@ -185,11 +179,8 @@ fn take_optimizable_column_count( .as_any() .downcast_ref::() { - if let ColumnStatistics { - null_count: Some(val), - .. - } = &col_stats[col_expr.index()] - { + let current_val = &col_stats[col_expr.index()].null_count; + if let &Precision::Exact(val) = current_val { return Some(( ScalarValue::Int64(Some((num_rows - val) as i64)), casted_expr.name().to_string(), @@ -201,27 +192,23 @@ fn take_optimizable_column_count( None } -/// If this agg_expr is a min that is defined in the statistics, return it +/// If this agg_expr is a min that is exactly defined in the statistics, return it. fn take_optimizable_min( agg_expr: &dyn AggregateExpr, stats: &Statistics, ) -> Option<(ScalarValue, String)> { - if let (Some(col_stats), Some(casted_expr)) = ( - &stats.column_statistics, - agg_expr.as_any().downcast_ref::(), - ) { + let col_stats = &stats.column_statistics; + if let Some(casted_expr) = agg_expr.as_any().downcast_ref::() { if casted_expr.expressions().len() == 1 { // TODO optimize with exprs other than Column if let Some(col_expr) = casted_expr.expressions()[0] .as_any() .downcast_ref::() { - if let ColumnStatistics { - min_value: Some(val), - .. - } = &col_stats[col_expr.index()] - { - return Some((val.clone(), casted_expr.name().to_string())); + if let Precision::Exact(val) = &col_stats[col_expr.index()].min_value { + if !val.is_null() { + return Some((val.clone(), casted_expr.name().to_string())); + } } } } @@ -229,27 +216,23 @@ fn take_optimizable_min( None } -/// If this agg_expr is a max that is defined in the statistics, return it +/// If this agg_expr is a max that is exactly defined in the statistics, return it. fn take_optimizable_max( agg_expr: &dyn AggregateExpr, stats: &Statistics, ) -> Option<(ScalarValue, String)> { - if let (Some(col_stats), Some(casted_expr)) = ( - &stats.column_statistics, - agg_expr.as_any().downcast_ref::(), - ) { + let col_stats = &stats.column_statistics; + if let Some(casted_expr) = agg_expr.as_any().downcast_ref::() { if casted_expr.expressions().len() == 1 { // TODO optimize with exprs other than Column if let Some(col_expr) = casted_expr.expressions()[0] .as_any() .downcast_ref::() { - if let ColumnStatistics { - max_value: Some(val), - .. - } = &col_stats[col_expr.index()] - { - return Some((val.clone(), casted_expr.name().to_string())); + if let Precision::Exact(val) = &col_stats[col_expr.index()].max_value { + if !val.is_null() { + return Some((val.clone(), casted_expr.name().to_string())); + } } } } @@ -258,17 +241,10 @@ fn take_optimizable_max( } #[cfg(test)] -mod tests { - use super::*; +pub(crate) mod tests { use std::sync::Arc; - use arrow::array::Int32Array; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow::record_batch::RecordBatch; - use datafusion_common::cast::as_int64_array; - use datafusion_physical_expr::expressions::cast; - use datafusion_physical_expr::PhysicalExpr; - + use super::*; use crate::error::Result; use crate::logical_expr::Operator; use crate::physical_plan::aggregates::{AggregateExec, PhysicalGroupBy}; @@ -279,6 +255,14 @@ mod tests { use crate::physical_plan::memory::MemoryExec; use crate::prelude::SessionContext; + use arrow::array::Int32Array; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use datafusion_common::cast::as_int64_array; + use datafusion_physical_expr::expressions::cast; + use datafusion_physical_expr::PhysicalExpr; + use datafusion_physical_plan::aggregates::AggregateMode; + /// Mock data using a MemoryExec which has an exact count statistic fn mock_data() -> Result> { let schema = Arc::new(Schema::new(vec![ @@ -308,7 +292,8 @@ mod tests { ) -> Result<()> { let session_ctx = SessionContext::new(); let state = session_ctx.state(); - let plan = Arc::new(plan) as _; + let plan: Arc = Arc::new(plan); + let optimized = AggregateStatistics::new() .optimize(Arc::clone(&plan), state.config_options())?; @@ -349,7 +334,7 @@ mod tests { } /// Describe the type of aggregate being tested - enum TestAggregate { + pub(crate) enum TestAggregate { /// Testing COUNT(*) type aggregates CountStar, @@ -358,7 +343,7 @@ mod tests { } impl TestAggregate { - fn new_count_star() -> Self { + pub(crate) fn new_count_star() -> Self { Self::CountStar } @@ -367,7 +352,7 @@ mod tests { } /// Return appropriate expr depending if COUNT is for col or table (*) - fn count_expr(&self) -> Arc { + pub(crate) fn count_expr(&self) -> Arc { Arc::new(Count::new( self.column(), self.column_name(), @@ -412,7 +397,6 @@ mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], - vec![None], source, Arc::clone(&schema), )?; @@ -422,7 +406,6 @@ mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], - vec![None], Arc::new(partial_agg), Arc::clone(&schema), )?; @@ -444,7 +427,6 @@ mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], - vec![None], source, Arc::clone(&schema), )?; @@ -454,7 +436,6 @@ mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], - vec![None], Arc::new(partial_agg), Arc::clone(&schema), )?; @@ -475,7 +456,6 @@ mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], - vec![None], source, Arc::clone(&schema), )?; @@ -488,7 +468,6 @@ mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], - vec![None], Arc::new(coalesce), Arc::clone(&schema), )?; @@ -509,7 +488,6 @@ mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], - vec![None], source, Arc::clone(&schema), )?; @@ -522,7 +500,6 @@ mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], - vec![None], Arc::new(coalesce), Arc::clone(&schema), )?; @@ -554,7 +531,6 @@ mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], - vec![None], filter, Arc::clone(&schema), )?; @@ -564,7 +540,6 @@ mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], - vec![None], Arc::new(partial_agg), Arc::clone(&schema), )?; @@ -601,7 +576,6 @@ mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], - vec![None], filter, Arc::clone(&schema), )?; @@ -611,7 +585,6 @@ mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], - vec![None], Arc::new(partial_agg), Arc::clone(&schema), )?; diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index 40b2bcc3e140..7359a6463059 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -17,13 +17,15 @@ //! CombinePartialFinalAggregate optimizer rule checks the adjacent Partial and Final AggregateExecs //! and try to combine them if necessary + +use std::sync::Arc; + use crate::error::Result; use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use crate::physical_plan::ExecutionPlan; -use datafusion_common::config::ConfigOptions; -use std::sync::Arc; +use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::{AggregateExpr, PhysicalExpr}; @@ -89,10 +91,12 @@ impl PhysicalOptimizerRule for CombinePartialFinalAggregate { input_agg_exec.group_by().clone(), input_agg_exec.aggr_expr().to_vec(), input_agg_exec.filter_expr().to_vec(), - input_agg_exec.order_by_expr().to_vec(), input_agg_exec.input().clone(), - input_agg_exec.input_schema().clone(), + input_agg_exec.input_schema(), ) + .map(|combined_agg| { + combined_agg.with_limit(agg_exec.limit()) + }) .ok() .map(Arc::new) } else { @@ -191,10 +195,6 @@ fn discard_column_index(group_expr: Arc) -> Arc { @@ -248,12 +252,11 @@ mod tests { object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), file_schema: schema.clone(), file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], - statistics: Statistics::default(), + statistics: Statistics::new_unknown(schema), projection: None, limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, None, None, @@ -272,7 +275,6 @@ mod tests { group_by, aggr_expr, vec![], - vec![], input, schema, ) @@ -292,7 +294,6 @@ mod tests { group_by, aggr_expr, vec![], - vec![], input, schema, ) @@ -426,4 +427,48 @@ mod tests { assert_optimized!(expected, plan); Ok(()) } + + #[test] + fn aggregations_with_limit_combined() -> Result<()> { + let schema = schema(); + let aggr_expr = vec![]; + + let groups: Vec<(Arc, String)> = + vec![(col("c", &schema)?, "c".to_string())]; + + let partial_group_by = PhysicalGroupBy::new_single(groups); + let partial_agg = partial_aggregate_exec( + parquet_exec(&schema), + partial_group_by, + aggr_expr.clone(), + ); + + let groups: Vec<(Arc, String)> = + vec![(col("c", &partial_agg.schema())?, "c".to_string())]; + let final_group_by = PhysicalGroupBy::new_single(groups); + + let schema = partial_agg.schema(); + let final_agg = Arc::new( + AggregateExec::try_new( + AggregateMode::Final, + final_group_by, + aggr_expr, + vec![], + partial_agg, + schema, + ) + .unwrap() + .with_limit(Some(5)), + ); + let plan: Arc = final_agg; + // should combine the Partial/Final AggregateExecs to a Single AggregateExec + // with the final limit preserved + let expected = &[ + "AggregateExec: mode=Single, gby=[c@2 as c], aggr=[], lim=[5]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c]", + ]; + + assert_optimized!(expected, plan); + Ok(()) + } } diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index b2a1a0338384..bf3f9ef0f3e6 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -21,14 +21,17 @@ //! according to the configuration), this rule increases partition counts in //! the physical plan. +use std::borrow::Cow; use std::fmt; use std::fmt::Formatter; use std::sync::Arc; +use super::output_requirements::OutputRequirementExec; use crate::config::ConfigOptions; -use crate::datasource::physical_plan::{CsvExec, ParquetExec}; -use crate::error::{DataFusionError, Result}; -use crate::physical_optimizer::utils::{add_sort_above, get_plan_string, ExecTree}; +use crate::error::Result; +use crate::physical_optimizer::utils::{ + is_coalesce_partitions, is_repartition, is_sort_preserving_merge, +}; use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; @@ -37,25 +40,25 @@ use crate::physical_plan::joins::{ }; use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::repartition::RepartitionExec; -use crate::physical_plan::sorts::sort::SortOptions; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use crate::physical_plan::union::{can_interleave, InterleaveExec, UnionExec}; use crate::physical_plan::windows::WindowAggExec; -use crate::physical_plan::Partitioning; -use crate::physical_plan::{with_new_children_if_necessary, Distribution, ExecutionPlan}; +use crate::physical_plan::{ + with_new_children_if_necessary, Distribution, ExecutionPlan, Partitioning, +}; -use datafusion_common::internal_err; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use arrow::compute::SortOptions; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_expr::logical_plan::JoinType; -use datafusion_physical_expr::equivalence::EquivalenceProperties; use datafusion_physical_expr::expressions::{Column, NoOp}; -use datafusion_physical_expr::utils::{ - map_columns_before_projection, ordering_satisfy_requirement_concrete, -}; +use datafusion_physical_expr::utils::map_columns_before_projection; use datafusion_physical_expr::{ - expr_list_eq_strict_order, PhysicalExpr, PhysicalSortRequirement, + physical_exprs_equal, EquivalenceProperties, LexRequirementRef, PhysicalExpr, + PhysicalSortRequirement, }; -use datafusion_physical_plan::unbounded_output; +use datafusion_physical_plan::sorts::sort::SortExec; +use datafusion_physical_plan::windows::{get_best_fitting_window, BoundedWindowAggExec}; +use datafusion_physical_plan::{get_plan_string, unbounded_output}; use itertools::izip; @@ -213,9 +216,7 @@ impl PhysicalOptimizerRule for EnforceDistribution { distribution_context.transform_up(&|distribution_context| { ensure_distribution(distribution_context, config) })?; - - // If output ordering is not necessary, removes it - update_plan_to_remove_unnecessary_final_order(distribution_context) + Ok(distribution_context.plan) } fn name(&self) -> &str { @@ -258,7 +259,7 @@ impl PhysicalOptimizerRule for EnforceDistribution { /// 1) If the current plan is Partitioned HashJoin, SortMergeJoin, check whether the requirements can be satisfied by adjusting join keys ordering: /// Requirements can not be satisfied, clear the current requirements, generate new requirements(to pushdown) based on the current join keys, return the unchanged plan. /// Requirements is already satisfied, clear the current requirements, generate new requirements(to pushdown) based on the current join keys, return the unchanged plan. -/// Requirements can be satisfied by adjusting keys ordering, clear the current requiements, generate new requirements(to pushdown) based on the adjusted join keys, return the changed plan. +/// Requirements can be satisfied by adjusting keys ordering, clear the current requirements, generate new requirements(to pushdown) based on the adjusted join keys, return the changed plan. /// /// 2) If the current plan is Aggregation, check whether the requirements can be satisfied by adjusting group by keys ordering: /// Requirements can not be satisfied, clear all the requirements, return the unchanged plan. @@ -270,11 +271,12 @@ impl PhysicalOptimizerRule for EnforceDistribution { /// 5) For other types of operators, by default, pushdown the parent requirements to children. /// fn adjust_input_keys_ordering( - requirements: PlanWithKeyRequirements, + mut requirements: PlanWithKeyRequirements, ) -> Result> { let parent_required = requirements.required_key_ordering.clone(); let plan_any = requirements.plan.as_any(); - let transformed = if let Some(HashJoinExec { + + if let Some(HashJoinExec { left, right, on, @@ -289,7 +291,7 @@ fn adjust_input_keys_ordering( PartitionMode::Partitioned => { let join_constructor = |new_conditions: (Vec<(Column, Column)>, Vec)| { - Ok(Arc::new(HashJoinExec::try_new( + HashJoinExec::try_new( left.clone(), right.clone(), new_conditions.0, @@ -297,15 +299,17 @@ fn adjust_input_keys_ordering( join_type, PartitionMode::Partitioned, *null_equals_null, - )?) as Arc) + ) + .map(|e| Arc::new(e) as _) }; - Some(reorder_partitioned_join_keys( + reorder_partitioned_join_keys( requirements.plan.clone(), &parent_required, on, vec![], &join_constructor, - )?) + ) + .map(Transformed::Yes) } PartitionMode::CollectLeft => { let new_right_request = match join_type { @@ -323,15 +327,15 @@ fn adjust_input_keys_ordering( }; // Push down requirements to the right side - Some(PlanWithKeyRequirements { - plan: requirements.plan.clone(), - required_key_ordering: vec![], - request_key_ordering: vec![None, new_right_request], - }) + requirements.children[1].required_key_ordering = + new_right_request.unwrap_or(vec![]); + Ok(Transformed::Yes(requirements)) } PartitionMode::Auto => { // Can not satisfy, clear the current requirements and generate new empty requirements - Some(PlanWithKeyRequirements::new(requirements.plan.clone())) + Ok(Transformed::Yes(PlanWithKeyRequirements::new( + requirements.plan, + ))) } } } else if let Some(CrossJoinExec { left, .. }) = @@ -339,14 +343,9 @@ fn adjust_input_keys_ordering( { let left_columns_len = left.schema().fields().len(); // Push down requirements to the right side - Some(PlanWithKeyRequirements { - plan: requirements.plan.clone(), - required_key_ordering: vec![], - request_key_ordering: vec![ - None, - shift_right_required(&parent_required, left_columns_len), - ], - }) + requirements.children[1].required_key_ordering = + shift_right_required(&parent_required, left_columns_len).unwrap_or_default(); + Ok(Transformed::Yes(requirements)) } else if let Some(SortMergeJoinExec { left, right, @@ -359,35 +358,40 @@ fn adjust_input_keys_ordering( { let join_constructor = |new_conditions: (Vec<(Column, Column)>, Vec)| { - Ok(Arc::new(SortMergeJoinExec::try_new( + SortMergeJoinExec::try_new( left.clone(), right.clone(), new_conditions.0, *join_type, new_conditions.1, *null_equals_null, - )?) as Arc) + ) + .map(|e| Arc::new(e) as _) }; - Some(reorder_partitioned_join_keys( + reorder_partitioned_join_keys( requirements.plan.clone(), &parent_required, on, sort_options.clone(), &join_constructor, - )?) + ) + .map(Transformed::Yes) } else if let Some(aggregate_exec) = plan_any.downcast_ref::() { if !parent_required.is_empty() { match aggregate_exec.mode() { - AggregateMode::FinalPartitioned => Some(reorder_aggregate_keys( + AggregateMode::FinalPartitioned => reorder_aggregate_keys( requirements.plan.clone(), &parent_required, aggregate_exec, - )?), - _ => Some(PlanWithKeyRequirements::new(requirements.plan.clone())), + ) + .map(Transformed::Yes), + _ => Ok(Transformed::Yes(PlanWithKeyRequirements::new( + requirements.plan, + ))), } } else { // Keep everything unchanged - None + Ok(Transformed::No(requirements)) } } else if let Some(proj) = plan_any.downcast_ref::() { let expr = proj.expr(); @@ -396,34 +400,28 @@ fn adjust_input_keys_ordering( // Construct a mapping from new name to the the orginal Column let new_required = map_columns_before_projection(&parent_required, expr); if new_required.len() == parent_required.len() { - Some(PlanWithKeyRequirements { - plan: requirements.plan.clone(), - required_key_ordering: vec![], - request_key_ordering: vec![Some(new_required.clone())], - }) + requirements.children[0].required_key_ordering = new_required; + Ok(Transformed::Yes(requirements)) } else { // Can not satisfy, clear the current requirements and generate new empty requirements - Some(PlanWithKeyRequirements::new(requirements.plan.clone())) + Ok(Transformed::Yes(PlanWithKeyRequirements::new( + requirements.plan, + ))) } } else if plan_any.downcast_ref::().is_some() || plan_any.downcast_ref::().is_some() || plan_any.downcast_ref::().is_some() { - Some(PlanWithKeyRequirements::new(requirements.plan.clone())) + Ok(Transformed::Yes(PlanWithKeyRequirements::new( + requirements.plan, + ))) } else { // By default, push down the parent requirements to children - let children_len = requirements.plan.children().len(); - Some(PlanWithKeyRequirements { - plan: requirements.plan.clone(), - required_key_ordering: vec![], - request_key_ordering: vec![Some(parent_required.clone()); children_len], - }) - }; - Ok(if let Some(transformed) = transformed { - Transformed::Yes(transformed) - } else { - Transformed::No(requirements) - }) + requirements.children.iter_mut().for_each(|child| { + child.required_key_ordering = parent_required.clone(); + }); + Ok(Transformed::Yes(requirements)) + } } fn reorder_partitioned_join_keys( @@ -454,28 +452,24 @@ where for idx in 0..sort_options.len() { new_sort_options.push(sort_options[new_positions[idx]]) } - - Ok(PlanWithKeyRequirements { - plan: join_constructor((new_join_on, new_sort_options))?, - required_key_ordering: vec![], - request_key_ordering: vec![Some(left_keys), Some(right_keys)], - }) + let mut requirement_tree = PlanWithKeyRequirements::new(join_constructor(( + new_join_on, + new_sort_options, + ))?); + requirement_tree.children[0].required_key_ordering = left_keys; + requirement_tree.children[1].required_key_ordering = right_keys; + Ok(requirement_tree) } else { - Ok(PlanWithKeyRequirements { - plan: join_plan, - required_key_ordering: vec![], - request_key_ordering: vec![Some(left_keys), Some(right_keys)], - }) + let mut requirement_tree = PlanWithKeyRequirements::new(join_plan); + requirement_tree.children[0].required_key_ordering = left_keys; + requirement_tree.children[1].required_key_ordering = right_keys; + Ok(requirement_tree) } } else { - Ok(PlanWithKeyRequirements { - plan: join_plan, - required_key_ordering: vec![], - request_key_ordering: vec![ - Some(join_key_pairs.left_keys), - Some(join_key_pairs.right_keys), - ], - }) + let mut requirement_tree = PlanWithKeyRequirements::new(join_plan); + requirement_tree.children[0].required_key_ordering = join_key_pairs.left_keys; + requirement_tree.children[1].required_key_ordering = join_key_pairs.right_keys; + Ok(requirement_tree) } } @@ -484,7 +478,7 @@ fn reorder_aggregate_keys( parent_required: &[Arc], agg_exec: &AggregateExec, ) -> Result { - let out_put_columns = agg_exec + let output_columns = agg_exec .group_by() .expr() .iter() @@ -492,50 +486,37 @@ fn reorder_aggregate_keys( .map(|(index, (_col, name))| Column::new(name, index)) .collect::>(); - let out_put_exprs = out_put_columns + let output_exprs = output_columns .iter() - .map(|c| Arc::new(c.clone()) as Arc) + .map(|c| Arc::new(c.clone()) as _) .collect::>(); - if parent_required.len() != out_put_exprs.len() + if parent_required.len() != output_exprs.len() || !agg_exec.group_by().null_expr().is_empty() - || expr_list_eq_strict_order(&out_put_exprs, parent_required) + || physical_exprs_equal(&output_exprs, parent_required) { Ok(PlanWithKeyRequirements::new(agg_plan)) } else { - let new_positions = expected_expr_positions(&out_put_exprs, parent_required); + let new_positions = expected_expr_positions(&output_exprs, parent_required); match new_positions { None => Ok(PlanWithKeyRequirements::new(agg_plan)), Some(positions) => { let new_partial_agg = if let Some(agg_exec) = agg_exec.input().as_any().downcast_ref::() - /*AggregateExec { - mode, - group_by, - aggr_expr, - filter_expr, - order_by_expr, - input, - input_schema, - .. - }) = - */ { if matches!(agg_exec.mode(), &AggregateMode::Partial) { - let mut new_group_exprs = vec![]; - for idx in positions.iter() { - new_group_exprs - .push(agg_exec.group_by().expr()[*idx].clone()); - } + let group_exprs = agg_exec.group_by().expr(); + let new_group_exprs = positions + .into_iter() + .map(|idx| group_exprs[idx].clone()) + .collect(); let new_partial_group_by = PhysicalGroupBy::new_single(new_group_exprs); - // new Partial AggregateExec Some(Arc::new(AggregateExec::try_new( AggregateMode::Partial, new_partial_group_by, agg_exec.aggr_expr().to_vec(), agg_exec.filter_expr().to_vec(), - agg_exec.order_by_expr().to_vec(), agg_exec.input().clone(), agg_exec.input_schema.clone(), )?)) @@ -547,18 +528,13 @@ fn reorder_aggregate_keys( }; if let Some(partial_agg) = new_partial_agg { // Build new group expressions that correspond to the output of partial_agg - let new_final_group: Vec> = - partial_agg.output_group_expr(); + let group_exprs = partial_agg.group_expr().expr(); + let new_final_group = partial_agg.output_group_expr(); let new_group_by = PhysicalGroupBy::new_single( new_final_group .iter() .enumerate() - .map(|(i, expr)| { - ( - expr.clone(), - partial_agg.group_expr().expr()[i].1.clone(), - ) - }) + .map(|(idx, expr)| (expr.clone(), group_exprs[idx].1.clone())) .collect(), ); @@ -567,35 +543,32 @@ fn reorder_aggregate_keys( new_group_by, agg_exec.aggr_expr().to_vec(), agg_exec.filter_expr().to_vec(), - agg_exec.order_by_expr().to_vec(), partial_agg, - agg_exec.input_schema().clone(), + agg_exec.input_schema(), )?); // Need to create a new projection to change the expr ordering back - let mut proj_exprs = out_put_columns + let agg_schema = new_final_agg.schema(); + let mut proj_exprs = output_columns .iter() .map(|col| { + let name = col.name(); ( Arc::new(Column::new( - col.name(), - new_final_agg.schema().index_of(col.name()).unwrap(), - )) - as Arc, - col.name().to_owned(), + name, + agg_schema.index_of(name).unwrap(), + )) as _, + name.to_owned(), ) }) .collect::>(); - let agg_schema = new_final_agg.schema(); let agg_fields = agg_schema.fields(); for (idx, field) in - agg_fields.iter().enumerate().skip(out_put_columns.len()) + agg_fields.iter().enumerate().skip(output_columns.len()) { - proj_exprs.push(( - Arc::new(Column::new(field.name().as_str(), idx)) - as Arc, - field.name().clone(), - )) + let name = field.name(); + proj_exprs + .push((Arc::new(Column::new(name, idx)) as _, name.clone())) } // TODO merge adjacent Projections if there are Ok(PlanWithKeyRequirements::new(Arc::new( @@ -613,15 +586,14 @@ fn shift_right_required( parent_required: &[Arc], left_columns_len: usize, ) -> Option>> { - let new_right_required: Vec> = parent_required + let new_right_required = parent_required .iter() .filter_map(|r| { if let Some(col) = r.as_any().downcast_ref::() { - if col.index() >= left_columns_len { - Some( - Arc::new(Column::new(col.name(), col.index() - left_columns_len)) - as Arc, - ) + let idx = col.index(); + if idx >= left_columns_len { + let result = Column::new(col.name(), idx - left_columns_len); + Some(Arc::new(result) as _) } else { None } @@ -632,11 +604,7 @@ fn shift_right_required( .collect::>(); // if the parent required are all comming from the right side, the requirements can be pushdown - if new_right_required.len() != parent_required.len() { - None - } else { - Some(new_right_required) - } + (new_right_required.len() == parent_required.len()).then_some(new_right_required) } /// When the physical planner creates the Joins, the ordering of join keys is from the original query. @@ -660,8 +628,8 @@ fn shift_right_required( /// In that case, the datasources/tables might be pre-partitioned and we can't adjust the key ordering of the datasources /// and then can't apply the Top-Down reordering process. pub(crate) fn reorder_join_keys_to_inputs( - plan: Arc, -) -> Result> { + plan: Arc, +) -> Result> { let plan_any = plan.as_any(); if let Some(HashJoinExec { left, @@ -674,41 +642,34 @@ pub(crate) fn reorder_join_keys_to_inputs( .. }) = plan_any.downcast_ref::() { - match mode { - PartitionMode::Partitioned => { - let join_key_pairs = extract_join_keys(on); - if let Some(( - JoinKeyPairs { - left_keys, - right_keys, - }, - new_positions, - )) = reorder_current_join_keys( - join_key_pairs, - Some(left.output_partitioning()), - Some(right.output_partitioning()), - &left.equivalence_properties(), - &right.equivalence_properties(), - ) { - if !new_positions.is_empty() { - let new_join_on = new_join_conditions(&left_keys, &right_keys); - Ok(Arc::new(HashJoinExec::try_new( - left.clone(), - right.clone(), - new_join_on, - filter.clone(), - join_type, - PartitionMode::Partitioned, - *null_equals_null, - )?)) - } else { - Ok(plan) - } - } else { - Ok(plan) + if matches!(mode, PartitionMode::Partitioned) { + let join_key_pairs = extract_join_keys(on); + if let Some(( + JoinKeyPairs { + left_keys, + right_keys, + }, + new_positions, + )) = reorder_current_join_keys( + join_key_pairs, + Some(left.output_partitioning()), + Some(right.output_partitioning()), + &left.equivalence_properties(), + &right.equivalence_properties(), + ) { + if !new_positions.is_empty() { + let new_join_on = new_join_conditions(&left_keys, &right_keys); + return Ok(Arc::new(HashJoinExec::try_new( + left.clone(), + right.clone(), + new_join_on, + filter.clone(), + join_type, + PartitionMode::Partitioned, + *null_equals_null, + )?)); } } - _ => Ok(plan), } } else if let Some(SortMergeJoinExec { left, @@ -736,27 +697,21 @@ pub(crate) fn reorder_join_keys_to_inputs( ) { if !new_positions.is_empty() { let new_join_on = new_join_conditions(&left_keys, &right_keys); - let mut new_sort_options = vec![]; - for idx in 0..sort_options.len() { - new_sort_options.push(sort_options[new_positions[idx]]) - } - Ok(Arc::new(SortMergeJoinExec::try_new( + let new_sort_options = (0..sort_options.len()) + .map(|idx| sort_options[new_positions[idx]]) + .collect(); + return Ok(Arc::new(SortMergeJoinExec::try_new( left.clone(), right.clone(), new_join_on, *join_type, new_sort_options, *null_equals_null, - )?)) - } else { - Ok(plan) + )?)); } - } else { - Ok(plan) } - } else { - Ok(plan) } + Ok(plan) } /// Reorder the current join keys ordering based on either left partition or right partition @@ -792,39 +747,40 @@ fn try_reorder( expected: &[Arc], equivalence_properties: &EquivalenceProperties, ) -> Option<(JoinKeyPairs, Vec)> { + let eq_groups = equivalence_properties.eq_group(); let mut normalized_expected = vec![]; let mut normalized_left_keys = vec![]; let mut normalized_right_keys = vec![]; if join_keys.left_keys.len() != expected.len() { return None; } - if expr_list_eq_strict_order(expected, &join_keys.left_keys) - || expr_list_eq_strict_order(expected, &join_keys.right_keys) + if physical_exprs_equal(expected, &join_keys.left_keys) + || physical_exprs_equal(expected, &join_keys.right_keys) { return Some((join_keys, vec![])); - } else if !equivalence_properties.classes().is_empty() { + } else if !equivalence_properties.eq_group().is_empty() { normalized_expected = expected .iter() - .map(|e| equivalence_properties.normalize_expr(e.clone())) + .map(|e| eq_groups.normalize_expr(e.clone())) .collect::>(); assert_eq!(normalized_expected.len(), expected.len()); normalized_left_keys = join_keys .left_keys .iter() - .map(|e| equivalence_properties.normalize_expr(e.clone())) + .map(|e| eq_groups.normalize_expr(e.clone())) .collect::>(); assert_eq!(join_keys.left_keys.len(), normalized_left_keys.len()); normalized_right_keys = join_keys .right_keys .iter() - .map(|e| equivalence_properties.normalize_expr(e.clone())) + .map(|e| eq_groups.normalize_expr(e.clone())) .collect::>(); assert_eq!(join_keys.right_keys.len(), normalized_right_keys.len()); - if expr_list_eq_strict_order(&normalized_expected, &normalized_left_keys) - || expr_list_eq_strict_order(&normalized_expected, &normalized_right_keys) + if physical_exprs_equal(&normalized_expected, &normalized_left_keys) + || physical_exprs_equal(&normalized_expected, &normalized_right_keys) { return Some((join_keys, vec![])); } @@ -884,12 +840,7 @@ fn expected_expr_positions( fn extract_join_keys(on: &[(Column, Column)]) -> JoinKeyPairs { let (left_keys, right_keys) = on .iter() - .map(|(l, r)| { - ( - Arc::new(l.clone()) as Arc, - Arc::new(r.clone()) as Arc, - ) - }) + .map(|(l, r)| (Arc::new(l.clone()) as _, Arc::new(r.clone()) as _)) .unzip(); JoinKeyPairs { left_keys, @@ -901,7 +852,7 @@ fn new_join_conditions( new_left_keys: &[Arc], new_right_keys: &[Arc], ) -> Vec<(Column, Column)> { - let new_join_on = new_left_keys + new_left_keys .iter() .zip(new_right_keys.iter()) .map(|(l_key, r_key)| { @@ -910,85 +861,44 @@ fn new_join_conditions( r_key.as_any().downcast_ref::().unwrap().clone(), ) }) - .collect::>(); - new_join_on -} - -/// Updates `dist_onward` such that, to keep track of -/// `input` in the `exec_tree`. -/// -/// # Arguments -/// -/// * `input`: Current execution plan -/// * `dist_onward`: It keeps track of executors starting from a distribution -/// changing operator (e.g Repartition, SortPreservingMergeExec, etc.) -/// until child of `input` (`input` should have single child). -/// * `input_idx`: index of the `input`, for its parent. -/// -fn update_distribution_onward( - input: Arc, - dist_onward: &mut Option, - input_idx: usize, -) { - // Update the onward tree if there is an active branch - if let Some(exec_tree) = dist_onward { - // When we add a new operator to change distribution - // we add RepartitionExec, SortPreservingMergeExec, CoalescePartitionsExec - // in this case, we need to update exec tree idx such that exec tree is now child of these - // operators (change the 0, since all of the operators have single child). - exec_tree.idx = 0; - *exec_tree = ExecTree::new(input, input_idx, vec![exec_tree.clone()]); - } else { - *dist_onward = Some(ExecTree::new(input, input_idx, vec![])); - } + .collect() } /// Adds RoundRobin repartition operator to the plan increase parallelism. /// /// # Arguments /// -/// * `input`: Current execution plan +/// * `input`: Current node. /// * `n_target`: desired target partition number, if partition number of the /// current executor is less than this value. Partition number will be increased. -/// * `dist_onward`: It keeps track of executors starting from a distribution -/// changing operator (e.g Repartition, SortPreservingMergeExec, etc.) -/// until `input` plan. -/// * `input_idx`: index of the `input`, for its parent. /// /// # Returns /// -/// A [Result] object that contains new execution plan, where desired partition number -/// is achieved by adding RoundRobin Repartition. +/// A [`Result`] object that contains new execution plan where the desired +/// partition number is achieved by adding a RoundRobin repartition. fn add_roundrobin_on_top( - input: Arc, + input: DistributionContext, n_target: usize, - dist_onward: &mut Option, - input_idx: usize, -) -> Result> { - // Adding repartition is helpful - if input.output_partitioning().partition_count() < n_target { +) -> Result { + // Adding repartition is helpful: + if input.plan.output_partitioning().partition_count() < n_target { // When there is an existing ordering, we preserve ordering // during repartition. This will be un-done in the future // If any of the following conditions is true // - Preserving ordering is not helpful in terms of satisfying ordering requirements // - Usage of order preserving variants is not desirable - // (determined by flag `config.optimizer.bounded_order_preserving_variants`) - let should_preserve_ordering = input.output_ordering().is_some(); - - let new_plan = Arc::new( - RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(n_target))? - .with_preserve_order(should_preserve_ordering), - ) as Arc; - if let Some(exec_tree) = dist_onward { - return internal_err!( - "ExecTree should have been empty, but got:{:?}", - exec_tree - ); - } + // (determined by flag `config.optimizer.prefer_existing_sort`) + let partitioning = Partitioning::RoundRobinBatch(n_target); + let repartition = RepartitionExec::try_new(input.plan.clone(), partitioning)? + .with_preserve_order(); + + let new_plan = Arc::new(repartition) as _; - // update distribution onward with new operator - update_distribution_onward(new_plan.clone(), dist_onward, input_idx); - Ok(new_plan) + Ok(DistributionContext { + plan: new_plan, + distribution_connection: true, + children_nodes: vec![input], + }) } else { // Partition is not helpful, we already have desired number of partitions. Ok(input) @@ -1002,118 +912,106 @@ fn add_roundrobin_on_top( /// /// # Arguments /// -/// * `input`: Current execution plan +/// * `input`: Current node. /// * `hash_exprs`: Stores Physical Exprs that are used during hashing. /// * `n_target`: desired target partition number, if partition number of the /// current executor is less than this value. Partition number will be increased. -/// * `dist_onward`: It keeps track of executors starting from a distribution -/// changing operator (e.g Repartition, SortPreservingMergeExec, etc.) -/// until `input` plan. -/// * `input_idx`: index of the `input`, for its parent. /// /// # Returns /// -/// A [Result] object that contains new execution plan, where desired distribution is -/// satisfied by adding Hash Repartition. +/// A [`Result`] object that contains new execution plan where the desired +/// distribution is satisfied by adding a Hash repartition. fn add_hash_on_top( - input: Arc, + mut input: DistributionContext, hash_exprs: Vec>, - // Repartition(Hash) will have `n_target` partitions at the output. n_target: usize, - // Stores executors starting from Repartition(RoundRobin) until - // current executor. When Repartition(Hash) is added, `dist_onward` - // is updated such that it stores connection from Repartition(RoundRobin) - // until Repartition(Hash). - dist_onward: &mut Option, - input_idx: usize, -) -> Result> { - if n_target == input.output_partitioning().partition_count() && n_target == 1 { - // In this case adding a hash repartition is unnecessary as the hash - // requirement is implicitly satisfied. + repartition_beneficial_stats: bool, +) -> Result { + // Early return if hash repartition is unnecessary + if n_target == 1 { return Ok(input); } + let satisfied = input + .plan .output_partitioning() .satisfy(Distribution::HashPartitioned(hash_exprs.clone()), || { - input.equivalence_properties() + input.plan.equivalence_properties() }); + // Add hash repartitioning when: // - The hash distribution requirement is not satisfied, or // - We can increase parallelism by adding hash partitioning. - if !satisfied || n_target > input.output_partitioning().partition_count() { + if !satisfied || n_target > input.plan.output_partitioning().partition_count() { // When there is an existing ordering, we preserve ordering during // repartition. This will be rolled back in the future if any of the // following conditions is true: // - Preserving ordering is not helpful in terms of satisfying ordering // requirements. // - Usage of order preserving variants is not desirable (per the flag - // `config.optimizer.bounded_order_preserving_variants`). - let should_preserve_ordering = input.output_ordering().is_some(); - // Since hashing benefits from partitioning, add a round-robin repartition - // before it: - let mut new_plan = add_roundrobin_on_top(input, n_target, dist_onward, 0)?; - new_plan = Arc::new( - RepartitionExec::try_new(new_plan, Partitioning::Hash(hash_exprs, n_target))? - .with_preserve_order(should_preserve_ordering), - ) as _; - - // update distribution onward with new operator - update_distribution_onward(new_plan.clone(), dist_onward, input_idx); - Ok(new_plan) - } else { - Ok(input) + // `config.optimizer.prefer_existing_sort`). + if repartition_beneficial_stats { + // Since hashing benefits from partitioning, add a round-robin repartition + // before it: + input = add_roundrobin_on_top(input, n_target)?; + } + + let partitioning = Partitioning::Hash(hash_exprs, n_target); + let repartition = RepartitionExec::try_new(input.plan.clone(), partitioning)? + .with_preserve_order(); + + input.children_nodes = vec![input.clone()]; + input.distribution_connection = true; + input.plan = Arc::new(repartition) as _; } + + Ok(input) } -/// Adds a `SortPreservingMergeExec` operator on top of input executor: -/// - to satisfy single distribution requirement. +/// Adds a [`SortPreservingMergeExec`] operator on top of input executor +/// to satisfy single distribution requirement. /// /// # Arguments /// -/// * `input`: Current execution plan -/// * `dist_onward`: It keeps track of executors starting from a distribution -/// changing operator (e.g Repartition, SortPreservingMergeExec, etc.) -/// until `input` plan. -/// * `input_idx`: index of the `input`, for its parent. +/// * `input`: Current node. /// /// # Returns /// -/// New execution plan, where desired single -/// distribution is satisfied by adding `SortPreservingMergeExec`. -fn add_spm_on_top( - input: Arc, - dist_onward: &mut Option, - input_idx: usize, -) -> Arc { +/// Updated node with an execution plan, where desired single +/// distribution is satisfied by adding [`SortPreservingMergeExec`]. +fn add_spm_on_top(input: DistributionContext) -> DistributionContext { // Add SortPreservingMerge only when partition count is larger than 1. - if input.output_partitioning().partition_count() > 1 { + if input.plan.output_partitioning().partition_count() > 1 { // When there is an existing ordering, we preserve ordering - // during decreasıng partıtıons. This will be un-done in the future - // If any of the following conditions is true + // when decreasing partitions. This will be un-done in the future + // if any of the following conditions is true // - Preserving ordering is not helpful in terms of satisfying ordering requirements // - Usage of order preserving variants is not desirable // (determined by flag `config.optimizer.bounded_order_preserving_variants`) - let should_preserve_ordering = input.output_ordering().is_some(); - let new_plan: Arc = if should_preserve_ordering { - let existing_ordering = input.output_ordering().unwrap_or(&[]); + let should_preserve_ordering = input.plan.output_ordering().is_some(); + + let new_plan = if should_preserve_ordering { Arc::new(SortPreservingMergeExec::new( - existing_ordering.to_vec(), - input, + input.plan.output_ordering().unwrap_or(&[]).to_vec(), + input.plan.clone(), )) as _ } else { - Arc::new(CoalescePartitionsExec::new(input)) as _ + Arc::new(CoalescePartitionsExec::new(input.plan.clone())) as _ }; - // update repartition onward with new operator - update_distribution_onward(new_plan.clone(), dist_onward, input_idx); - new_plan + DistributionContext { + plan: new_plan, + distribution_connection: true, + children_nodes: vec![input], + } } else { input } } -/// Updates the physical plan inside `distribution_context` if having a -/// `RepartitionExec(RoundRobin)` is not helpful. +/// Updates the physical plan inside [`DistributionContext`] so that distribution +/// changing operators are removed from the top. If they are necessary, they will +/// be added in subsequent stages. /// /// Assume that following plan is given: /// ```text @@ -1122,94 +1020,40 @@ fn add_spm_on_top( /// " ParquetExec: file_groups={2 groups: \[\[x], \[y]]}, projection=\[a, b, c, d, e], output_ordering=\[a@0 ASC]", /// ``` /// -/// `RepartitionExec` at the top is unnecessary. Since it doesn't help with increasing parallelism. -/// This function removes top repartition, and returns following plan. +/// Since `RepartitionExec`s change the distribution, this function removes +/// them and returns following plan: /// /// ```text -/// "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", -/// " ParquetExec: file_groups={2 groups: \[\[x], \[y]]}, projection=\[a, b, c, d, e], output_ordering=\[a@0 ASC]", +/// "ParquetExec: file_groups={2 groups: \[\[x], \[y]]}, projection=\[a, b, c, d, e], output_ordering=\[a@0 ASC]", /// ``` -fn remove_unnecessary_repartition( - distribution_context: DistributionContext, +fn remove_dist_changing_operators( + mut distribution_context: DistributionContext, ) -> Result { - let DistributionContext { - mut plan, - mut distribution_onwards, - } = distribution_context; - - // Remove any redundant RoundRobin at the start: - if let Some(repartition) = plan.as_any().downcast_ref::() { - if let Partitioning::RoundRobinBatch(n_out) = repartition.partitioning() { - // Repartition is useless: - if *n_out <= repartition.input().output_partitioning().partition_count() { - let mut new_distribution_onwards = - vec![None; repartition.input().children().len()]; - if let Some(exec_tree) = &distribution_onwards[0] { - for child in &exec_tree.children { - new_distribution_onwards[child.idx] = Some(child.clone()); - } - } - plan = repartition.input().clone(); - distribution_onwards = new_distribution_onwards; - } - } + while is_repartition(&distribution_context.plan) + || is_coalesce_partitions(&distribution_context.plan) + || is_sort_preserving_merge(&distribution_context.plan) + { + // All of above operators have a single child. First child is only child. + let child = distribution_context.children_nodes.swap_remove(0); + // Remove any distribution changing operators at the beginning: + // Note that they will be re-inserted later on if necessary or helpful. + distribution_context = child; } - // Create a plan with the updated children: - Ok(DistributionContext { - plan, - distribution_onwards, - }) + Ok(distribution_context) } -/// Changes each child of the `dist_context.plan` such that they no longer -/// use order preserving variants, if no ordering is required at the output -/// of the physical plan (there is no global ordering requirement by the query). -fn update_plan_to_remove_unnecessary_final_order( - dist_context: DistributionContext, -) -> Result> { - let DistributionContext { - plan, - distribution_onwards, - } = dist_context; - let new_children = izip!(plan.children(), distribution_onwards) - .map(|(mut child, mut dist_onward)| { - replace_order_preserving_variants(&mut child, &mut dist_onward)?; - Ok(child) - }) - .collect::>>()?; - if !new_children.is_empty() { - plan.with_new_children(new_children) - } else { - Ok(plan) - } -} - -/// Updates the physical plan `input` by using `dist_onward` replace order preserving operator variants -/// with their corresponding operators that do not preserve order. It is a wrapper for `replace_order_preserving_variants_helper` -fn replace_order_preserving_variants( - input: &mut Arc, - dist_onward: &mut Option, -) -> Result<()> { - if let Some(dist_onward) = dist_onward { - *input = replace_order_preserving_variants_helper(dist_onward)?; - } - *dist_onward = None; - Ok(()) -} - -/// Updates the physical plan inside `ExecTree` if preserving ordering while changing partitioning -/// is not helpful or desirable. +/// Updates the [`DistributionContext`] if preserving ordering while changing partitioning is not helpful or desirable. /// /// Assume that following plan is given: /// ```text /// "SortPreservingMergeExec: \[a@0 ASC]" -/// " SortPreservingRepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10", -/// " SortPreservingRepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", +/// " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10, preserve_order=true", +/// " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true", /// " ParquetExec: file_groups={2 groups: \[\[x], \[y]]}, projection=\[a, b, c, d, e], output_ordering=\[a@0 ASC]", /// ``` /// -/// This function converts plan above (inside `ExecTree`) to the following: +/// This function converts plan above to the following: /// /// ```text /// "CoalescePartitionsExec" @@ -1217,32 +1061,75 @@ fn replace_order_preserving_variants( /// " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", /// " ParquetExec: file_groups={2 groups: \[\[x], \[y]]}, projection=\[a, b, c, d, e], output_ordering=\[a@0 ASC]", /// ``` -fn replace_order_preserving_variants_helper( - exec_tree: &ExecTree, -) -> Result> { - let mut updated_children = exec_tree.plan.children(); - for child in &exec_tree.children { - updated_children[child.idx] = replace_order_preserving_variants_helper(child)?; +fn replace_order_preserving_variants( + mut context: DistributionContext, +) -> Result { + let mut updated_children = context + .children_nodes + .iter() + .map(|child| { + if child.distribution_connection { + replace_order_preserving_variants(child.clone()) + } else { + Ok(child.clone()) + } + }) + .collect::>>()?; + + if is_sort_preserving_merge(&context.plan) { + let child = updated_children.swap_remove(0); + context.plan = Arc::new(CoalescePartitionsExec::new(child.plan.clone())); + context.children_nodes = vec![child]; + return Ok(context); + } else if let Some(repartition) = + context.plan.as_any().downcast_ref::() + { + if repartition.preserve_order() { + let child = updated_children.swap_remove(0); + context.plan = Arc::new(RepartitionExec::try_new( + child.plan.clone(), + repartition.partitioning().clone(), + )?); + context.children_nodes = vec![child]; + return Ok(context); + } } - if let Some(spm) = exec_tree + + context.plan = context + .plan + .clone() + .with_new_children(updated_children.into_iter().map(|c| c.plan).collect())?; + Ok(context) +} + +/// This utility function adds a [`SortExec`] above an operator according to the +/// given ordering requirements while preserving the original partitioning. +fn add_sort_preserving_partitions( + node: DistributionContext, + sort_requirement: LexRequirementRef, + fetch: Option, +) -> DistributionContext { + // If the ordering requirement is already satisfied, do not add a sort. + if !node .plan - .as_any() - .downcast_ref::() + .equivalence_properties() + .ordering_satisfy_requirement(sort_requirement) { - return Ok(Arc::new(CoalescePartitionsExec::new(spm.input().clone()))); - } - if let Some(repartition) = exec_tree.plan.as_any().downcast_ref::() { - if repartition.preserve_order() { - return Ok(Arc::new( - RepartitionExec::try_new( - repartition.input().clone(), - repartition.partitioning().clone(), - )? - .with_preserve_order(false), - )); + let sort_expr = PhysicalSortRequirement::to_sort_exprs(sort_requirement.to_vec()); + let new_sort = SortExec::new(sort_expr, node.plan.clone()).with_fetch(fetch); + + DistributionContext { + plan: Arc::new(if node.plan.output_partitioning().partition_count() > 1 { + new_sort.with_preserve_partitioning(true) + } else { + new_sort + }), + distribution_connection: false, + children_nodes: vec![node], } + } else { + node } - exec_tree.plan.clone().with_new_children(updated_children) } /// This function checks whether we need to add additional data exchange @@ -1253,111 +1140,108 @@ fn ensure_distribution( dist_context: DistributionContext, config: &ConfigOptions, ) -> Result> { + let dist_context = dist_context.update_children()?; + + if dist_context.plan.children().is_empty() { + return Ok(Transformed::No(dist_context)); + } + let target_partitions = config.execution.target_partitions; // When `false`, round robin repartition will not be added to increase parallelism let enable_round_robin = config.optimizer.enable_round_robin_repartition; let repartition_file_scans = config.optimizer.repartition_file_scans; - let repartition_file_min_size = config.optimizer.repartition_file_min_size; + let batch_size = config.execution.batch_size; let is_unbounded = unbounded_output(&dist_context.plan); // Use order preserving variants either of the conditions true // - it is desired according to config // - when plan is unbounded let order_preserving_variants_desirable = - is_unbounded || config.optimizer.bounded_order_preserving_variants; - - if dist_context.plan.children().is_empty() { - return Ok(Transformed::No(dist_context)); - } - // Don't need to apply when the returned row count is not greater than 1: - let stats = dist_context.plan.statistics(); - let mut repartition_beneficial_stat = true; - if stats.is_exact { - repartition_beneficial_stat = - stats.num_rows.map(|num_rows| num_rows > 1).unwrap_or(true); - } + is_unbounded || config.optimizer.prefer_existing_sort; // Remove unnecessary repartition from the physical plan if any let DistributionContext { - plan, - mut distribution_onwards, - } = remove_unnecessary_repartition(dist_context)?; + mut plan, + distribution_connection, + children_nodes, + } = remove_dist_changing_operators(dist_context)?; + + if let Some(exec) = plan.as_any().downcast_ref::() { + if let Some(updated_window) = get_best_fitting_window( + exec.window_expr(), + exec.input(), + &exec.partition_keys, + )? { + plan = updated_window; + } + } else if let Some(exec) = plan.as_any().downcast_ref::() { + if let Some(updated_window) = get_best_fitting_window( + exec.window_expr(), + exec.input(), + &exec.partition_keys, + )? { + plan = updated_window; + } + }; - let n_children = plan.children().len(); // This loop iterates over all the children to: // - Increase parallelism for every child if it is beneficial. // - Satisfy the distribution requirements of every child, if it is not // already satisfied. // We store the updated children in `new_children`. - let new_children = izip!( - plan.children().into_iter(), + let children_nodes = izip!( + children_nodes.into_iter(), plan.required_input_distribution().iter(), plan.required_input_ordering().iter(), - distribution_onwards.iter_mut(), plan.benefits_from_input_partitioning(), - plan.maintains_input_order(), - 0..n_children + plan.maintains_input_order() ) .map( - |( - mut child, - requirement, - required_input_ordering, - dist_onward, - would_benefit, - maintains, - child_idx, - )| { + |(mut child, requirement, required_input_ordering, would_benefit, maintains)| { + // Don't need to apply when the returned row count is not greater than batch size + let num_rows = child.plan.statistics()?.num_rows; + let repartition_beneficial_stats = if num_rows.is_exact().unwrap_or(false) { + num_rows + .get_value() + .map(|value| value > &batch_size) + .unwrap() // safe to unwrap since is_exact() is true + } else { + true + }; + + // When `repartition_file_scans` is set, attempt to increase + // parallelism at the source. + if repartition_file_scans && repartition_beneficial_stats { + if let Some(new_child) = + child.plan.repartitioned(target_partitions, config)? + { + child.plan = new_child; + } + } + if enable_round_robin // Operator benefits from partitioning (e.g. filter): - && (would_benefit && repartition_beneficial_stat) + && (would_benefit && repartition_beneficial_stats) // Unless partitioning doesn't increase the partition count, it is not beneficial: - && child.output_partitioning().partition_count() < target_partitions + && child.plan.output_partitioning().partition_count() < target_partitions { - // When `repartition_file_scans` is set, leverage source operators - // (`ParquetExec`, `CsvExec` etc.) to increase parallelism at the source. - if repartition_file_scans { - if let Some(parquet_exec) = - child.as_any().downcast_ref::() - { - child = Arc::new(parquet_exec.get_repartitioned( - target_partitions, - repartition_file_min_size, - )); - } else if let Some(csv_exec) = - child.as_any().downcast_ref::() - { - if let Some(csv_exec) = csv_exec.get_repartitioned( - target_partitions, - repartition_file_min_size, - ) { - child = Arc::new(csv_exec); - } - } - } // Increase parallelism by adding round-robin repartitioning // on top of the operator. Note that we only do this if the // partition count is not already equal to the desired partition // count. - child = add_roundrobin_on_top( - child, - target_partitions, - dist_onward, - child_idx, - )?; + child = add_roundrobin_on_top(child, target_partitions)?; } // Satisfy the distribution requirement if it is unmet. match requirement { Distribution::SinglePartition => { - child = add_spm_on_top(child, dist_onward, child_idx); + child = add_spm_on_top(child); } Distribution::HashPartitioned(exprs) => { child = add_hash_on_top( child, exprs.to_vec(), target_partitions, - dist_onward, - child_idx, + repartition_beneficial_stats, )?; } Distribution::UnspecifiedDistribution => {} @@ -1365,38 +1249,42 @@ fn ensure_distribution( // There is an ordering requirement of the operator: if let Some(required_input_ordering) = required_input_ordering { - let existing_ordering = child.output_ordering().unwrap_or(&[]); // Either: // - Ordering requirement cannot be satisfied by preserving ordering through repartitions, or // - using order preserving variant is not desirable. - if !ordering_satisfy_requirement_concrete( - existing_ordering, - required_input_ordering, - || child.equivalence_properties(), - || child.ordering_equivalence_properties(), - ) || !order_preserving_variants_desirable + let ordering_satisfied = child + .plan + .equivalence_properties() + .ordering_satisfy_requirement(required_input_ordering); + if (!ordering_satisfied || !order_preserving_variants_desirable) + && child.distribution_connection { - replace_order_preserving_variants(&mut child, dist_onward)?; - let sort_expr = PhysicalSortRequirement::to_sort_exprs( - required_input_ordering.clone(), - ); - // Make sure to satisfy ordering requirement - add_sort_above(&mut child, sort_expr, None)?; + child = replace_order_preserving_variants(child)?; + // If ordering requirements were satisfied before repartitioning, + // make sure ordering requirements are still satisfied after. + if ordering_satisfied { + // Make sure to satisfy ordering requirement: + child = add_sort_preserving_partitions( + child, + required_input_ordering, + None, + ); + } } // Stop tracking distribution changing operators - *dist_onward = None; + child.distribution_connection = false; } else { // no ordering requirement match requirement { // Operator requires specific distribution. Distribution::SinglePartition | Distribution::HashPartitioned(_) => { // Since there is no ordering requirement, preserving ordering is pointless - replace_order_preserving_variants(&mut child, dist_onward)?; + child = replace_order_preserving_variants(child)?; } Distribution::UnspecifiedDistribution => { // Since ordering is lost, trying to preserve ordering is pointless - if !maintains { - replace_order_preserving_variants(&mut child, dist_onward)?; + if !maintains || plan.as_any().is::() { + child = replace_order_preserving_variants(child)?; } } } @@ -1407,7 +1295,9 @@ fn ensure_distribution( .collect::>>()?; let new_distribution_context = DistributionContext { - plan: if plan.as_any().is::() && can_interleave(&new_children) { + plan: if plan.as_any().is::() + && can_interleave(children_nodes.iter().map(|c| c.plan.clone())) + { // Add a special case for [`UnionExec`] since we want to "bubble up" // hash-partitioned data. So instead of // @@ -1431,152 +1321,109 @@ fn ensure_distribution( // - Agg: // Repartition (hash): // Data - Arc::new(InterleaveExec::try_new(new_children)?) + Arc::new(InterleaveExec::try_new( + children_nodes.iter().map(|c| c.plan.clone()).collect(), + )?) } else { - plan.clone().with_new_children(new_children)? + plan.with_new_children( + children_nodes.iter().map(|c| c.plan.clone()).collect(), + )? }, - distribution_onwards, + distribution_connection, + children_nodes, }; + Ok(Transformed::Yes(new_distribution_context)) } -/// A struct to keep track of distribution changing executors +/// A struct to keep track of distribution changing operators /// (`RepartitionExec`, `SortPreservingMergeExec`, `CoalescePartitionsExec`), /// and their associated parents inside `plan`. Using this information, /// we can optimize distribution of the plan if/when necessary. #[derive(Debug, Clone)] struct DistributionContext { plan: Arc, - /// Keep track of associations for each child of the plan. If `None`, - /// there is no distribution changing operator in its descendants. - distribution_onwards: Vec>, + /// Indicates whether this plan is connected to a distribution-changing + /// operator. + distribution_connection: bool, + children_nodes: Vec, } impl DistributionContext { - /// Creates an empty context. + /// Creates a tree according to the plan with empty states. fn new(plan: Arc) -> Self { - let length = plan.children().len(); - DistributionContext { + let children = plan.children(); + Self { plan, - distribution_onwards: vec![None; length], + distribution_connection: false, + children_nodes: children.into_iter().map(Self::new).collect(), } } - /// Constructs a new context from children contexts. - fn new_from_children_nodes( - children_nodes: Vec, - parent_plan: Arc, - ) -> Result { - let children_plans = children_nodes - .iter() - .map(|item| item.plan.clone()) - .collect(); - let distribution_onwards = children_nodes - .into_iter() - .enumerate() - .map(|(idx, context)| { - let DistributionContext { - plan, - // The `distribution_onwards` tree keeps track of operators - // that change distribution, or preserves the existing - // distribution (starting from an operator that change distribution). - distribution_onwards, - } = context; - if plan.children().is_empty() { - // Plan has no children, there is nothing to propagate. - None - } else if distribution_onwards[0].is_none() { - if let Some(repartition) = - plan.as_any().downcast_ref::() - { - match repartition.partitioning() { - Partitioning::RoundRobinBatch(_) - | Partitioning::Hash(_, _) => { - // Start tracking operators starting from this repartition (either roundrobin or hash): - return Some(ExecTree::new(plan, idx, vec![])); - } - _ => {} - } - } else if plan.as_any().is::() - || plan.as_any().is::() - { - // Start tracking operators starting from this sort preserving merge: - return Some(ExecTree::new(plan, idx, vec![])); - } - None - } else { - // Propagate children distribution tracking to the above - let new_distribution_onwards = izip!( - plan.required_input_distribution().iter(), - distribution_onwards.into_iter() - ) - .flat_map(|(required_dist, distribution_onwards)| { - if let Some(distribution_onwards) = distribution_onwards { - // Operator can safely propagate the distribution above. - // This is similar to maintaining order in the EnforceSorting rule. - if let Distribution::UnspecifiedDistribution = required_dist { - return Some(distribution_onwards); - } - } - None - }) - .collect::>(); - // Either: - // - None of the children has a connection to an operator that modifies distribution, or - // - The current operator requires distribution at its input so doesn't propagate it above. - if new_distribution_onwards.is_empty() { - None - } else { - Some(ExecTree::new(plan, idx, new_distribution_onwards)) - } + fn update_children(mut self) -> Result { + for child_context in self.children_nodes.iter_mut() { + child_context.distribution_connection = match &child_context.plan { + plan if is_repartition(plan) + || is_coalesce_partitions(plan) + || is_sort_preserving_merge(plan) => + { + true } - }) - .collect(); - Ok(DistributionContext { - plan: with_new_children_if_necessary(parent_plan, children_plans)?.into(), - distribution_onwards, - }) - } + _ => { + child_context.plan.children().is_empty() + || child_context.children_nodes[0].distribution_connection + || child_context + .plan + .required_input_distribution() + .iter() + .zip(child_context.children_nodes.iter()) + .any(|(required_dist, child_context)| { + child_context.distribution_connection + && matches!( + required_dist, + Distribution::UnspecifiedDistribution + ) + }) + } + }; + } - /// Computes distribution tracking contexts for every child of the plan. - fn children(&self) -> Vec { - self.plan - .children() - .into_iter() - .map(|child| DistributionContext::new(child)) - .collect() + let children_plans = self + .children_nodes + .iter() + .map(|context| context.plan.clone()) + .collect::>(); + + Ok(Self { + plan: with_new_children_if_necessary(self.plan, children_plans)?.into(), + distribution_connection: false, + children_nodes: self.children_nodes, + }) } } impl TreeNode for DistributionContext { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in self.children() { - match op(&child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.children_nodes.iter().map(Cow::Borrowed).collect() } - fn map_children(self, transform: F) -> Result + fn map_children(mut self, transform: F) -> Result where F: FnMut(Self) -> Result, { - let children = self.children(); - if children.is_empty() { - Ok(self) - } else { - let children_nodes = children + if !self.children_nodes.is_empty() { + self.children_nodes = self + .children_nodes .into_iter() .map(transform) - .collect::>>()?; - DistributionContext::new_from_children_nodes(children_nodes, self.plan) + .collect::>()?; + self.plan = with_new_children_if_necessary( + self.plan, + self.children_nodes.iter().map(|c| c.plan.clone()).collect(), + )? + .into(); } + Ok(self) } } @@ -1585,11 +1432,11 @@ impl fmt::Display for DistributionContext { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { let plan_string = get_plan_string(&self.plan); write!(f, "plan: {:?}", plan_string)?; - for (idx, child) in self.distribution_onwards.iter().enumerate() { - if let Some(child) = child { - write!(f, "idx:{:?}, exec_tree:{}", idx, child)?; - } - } + write!( + f, + "distribution_connection:{}", + self.distribution_connection, + )?; write!(f, "") } } @@ -1605,91 +1452,59 @@ struct PlanWithKeyRequirements { plan: Arc, /// Parent required key ordering required_key_ordering: Vec>, - /// The request key ordering to children - request_key_ordering: Vec>>>, + children: Vec, } impl PlanWithKeyRequirements { fn new(plan: Arc) -> Self { - let children_len = plan.children().len(); - PlanWithKeyRequirements { + let children = plan.children(); + Self { plan, required_key_ordering: vec![], - request_key_ordering: vec![None; children_len], + children: children.into_iter().map(Self::new).collect(), } } - - fn children(&self) -> Vec { - let plan_children = self.plan.children(); - assert_eq!(plan_children.len(), self.request_key_ordering.len()); - plan_children - .into_iter() - .zip(self.request_key_ordering.clone()) - .map(|(child, required)| { - let from_parent = required.unwrap_or_default(); - let length = child.children().len(); - PlanWithKeyRequirements { - plan: child, - required_key_ordering: from_parent.clone(), - request_key_ordering: vec![None; length], - } - }) - .collect() - } } impl TreeNode for PlanWithKeyRequirements { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - let children = self.children(); - for child in children { - match op(&child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.children.iter().map(Cow::Borrowed).collect() } - fn map_children(self, transform: F) -> Result + fn map_children(mut self, transform: F) -> Result where F: FnMut(Self) -> Result, { - let children = self.children(); - if !children.is_empty() { - let new_children: Result> = - children.into_iter().map(transform).collect(); - - let children_plans = new_children? + if !self.children.is_empty() { + self.children = self + .children .into_iter() - .map(|child| child.plan) - .collect::>(); - let new_plan = with_new_children_if_necessary(self.plan, children_plans)?; - Ok(PlanWithKeyRequirements { - plan: new_plan.into(), - required_key_ordering: self.required_key_ordering, - request_key_ordering: self.request_key_ordering, - }) - } else { - Ok(self) + .map(transform) + .collect::>()?; + self.plan = with_new_children_if_necessary( + self.plan, + self.children.iter().map(|c| c.plan.clone()).collect(), + )? + .into(); } + Ok(self) } } +/// Since almost all of these tests explicitly use `ParquetExec` they only run with the parquet feature flag on +#[cfg(feature = "parquet")] #[cfg(test)] -mod tests { +pub(crate) mod tests { use std::ops::Deref; use super::*; use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::datasource::listing::PartitionedFile; use crate::datasource::object_store::ObjectStoreUrl; - use crate::datasource::physical_plan::{FileScanConfig, ParquetExec}; + use crate::datasource::physical_plan::ParquetExec; + use crate::datasource::physical_plan::{CsvExec, FileScanConfig}; use crate::physical_optimizer::enforce_sorting::EnforceSorting; + use crate::physical_optimizer::output_requirements::OutputRequirements; use crate::physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, }; @@ -1703,9 +1518,12 @@ mod tests { use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use crate::physical_plan::{displayable, DisplayAs, DisplayFormatType, Statistics}; - use crate::physical_optimizer::test_utils::repartition_exec; + use crate::physical_optimizer::test_utils::{ + coalesce_partitions_exec, repartition_exec, + }; use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use crate::physical_plan::sorts::sort::SortExec; + use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::ScalarValue; @@ -1714,7 +1532,7 @@ mod tests { use datafusion_physical_expr::expressions::{BinaryExpr, Literal}; use datafusion_physical_expr::{ expressions, expressions::binary, expressions::lit, expressions::Column, - PhysicalExpr, PhysicalSortExpr, + LexOrdering, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, }; /// Models operators like BoundedWindowExec that require an input @@ -1722,11 +1540,23 @@ mod tests { #[derive(Debug)] struct SortRequiredExec { input: Arc, + expr: LexOrdering, } impl SortRequiredExec { fn new(input: Arc) -> Self { - Self { input } + let expr = input.output_ordering().unwrap_or(&[]).to_vec(); + Self { input, expr } + } + + fn new_with_requirement( + input: Arc, + requirement: Vec, + ) -> Self { + Self { + input, + expr: requirement, + } } } @@ -1736,7 +1566,11 @@ mod tests { _t: DisplayFormatType, f: &mut std::fmt::Formatter, ) -> std::fmt::Result { - write!(f, "SortRequiredExec") + write!( + f, + "SortRequiredExec: [{}]", + PhysicalSortExpr::format_list(&self.expr) + ) } } @@ -1778,7 +1612,10 @@ mod tests { ) -> Result> { assert_eq!(children.len(), 1); let child = children.pop().unwrap(); - Ok(Arc::new(Self::new(child))) + Ok(Arc::new(Self::new_with_requirement( + child, + self.expr.clone(), + ))) } fn execute( @@ -1789,12 +1626,12 @@ mod tests { unreachable!(); } - fn statistics(&self) -> Statistics { + fn statistics(&self) -> Result { self.input.statistics() } } - fn schema() -> SchemaRef { + pub(crate) fn schema() -> SchemaRef { Arc::new(Schema::new(vec![ Field::new("a", DataType::Int64, true), Field::new("b", DataType::Int64, true), @@ -1808,7 +1645,8 @@ mod tests { parquet_exec_with_sort(vec![]) } - fn parquet_exec_with_sort( + /// create a single parquet file that is sorted + pub(crate) fn parquet_exec_with_sort( output_ordering: Vec>, ) -> Arc { Arc::new(ParquetExec::new( @@ -1816,12 +1654,11 @@ mod tests { object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), file_schema: schema(), file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], - statistics: Statistics::default(), + statistics: Statistics::new_unknown(&schema()), projection: None, limit: None, table_partition_cols: vec![], output_ordering, - infinite_source: false, }, None, None, @@ -1832,7 +1669,7 @@ mod tests { parquet_exec_multiple_sorted(vec![]) } - // Created a sorted parquet exec with multiple files + /// Created a sorted parquet exec with multiple files fn parquet_exec_multiple_sorted( output_ordering: Vec>, ) -> Arc { @@ -1844,12 +1681,11 @@ mod tests { vec![PartitionedFile::new("x".to_string(), 100)], vec![PartitionedFile::new("y".to_string(), 100)], ], - statistics: Statistics::default(), + statistics: Statistics::new_unknown(&schema()), projection: None, limit: None, table_partition_cols: vec![], output_ordering, - infinite_source: false, }, None, None, @@ -1866,12 +1702,11 @@ mod tests { object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), file_schema: schema(), file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], - statistics: Statistics::default(), + statistics: Statistics::new_unknown(&schema()), projection: None, limit: None, table_partition_cols: vec![], output_ordering, - infinite_source: false, }, false, b',', @@ -1897,12 +1732,11 @@ mod tests { vec![PartitionedFile::new("x".to_string(), 100)], vec![PartitionedFile::new("y".to_string(), 100)], ], - statistics: Statistics::default(), + statistics: Statistics::new_unknown(&schema()), projection: None, limit: None, table_partition_cols: vec![], output_ordering, - infinite_source: false, }, false, b',', @@ -1954,14 +1788,12 @@ mod tests { final_grouping, vec![], vec![], - vec![], Arc::new( AggregateExec::try_new( AggregateMode::Partial, group_by, vec![], vec![], - vec![], input, schema.clone(), ) @@ -2054,7 +1886,14 @@ mod tests { Arc::new(SortRequiredExec::new(input)) } - fn trim_plan_display(plan: &str) -> Vec<&str> { + fn sort_required_exec_with_req( + input: Arc, + sort_exprs: LexOrdering, + ) -> Arc { + Arc::new(SortRequiredExec::new_with_requirement(input, sort_exprs)) + } + + pub(crate) fn trim_plan_display(plan: &str) -> Vec<&str> { plan.split('\n') .map(|s| s.trim()) .filter(|s| !s.is_empty()) @@ -2064,7 +1903,7 @@ mod tests { fn ensure_distribution_helper( plan: Arc, target_partitions: usize, - bounded_order_preserving_variants: bool, + prefer_existing_sort: bool, ) -> Result> { let distribution_context = DistributionContext::new(plan); let mut config = ConfigOptions::new(); @@ -2072,8 +1911,7 @@ mod tests { config.optimizer.enable_round_robin_repartition = false; config.optimizer.repartition_file_scans = false; config.optimizer.repartition_file_min_size = 1024; - config.optimizer.bounded_order_preserving_variants = - bounded_order_preserving_variants; + config.optimizer.prefer_existing_sort = prefer_existing_sort; ensure_distribution(distribution_context, &config).map(|item| item.into().plan) } @@ -2095,33 +1933,48 @@ mod tests { } /// Runs the repartition optimizer and asserts the plan against the expected + /// Arguments + /// * `EXPECTED_LINES` - Expected output plan + /// * `PLAN` - Input plan + /// * `FIRST_ENFORCE_DIST` - + /// true: (EnforceDistribution, EnforceDistribution, EnforceSorting) + /// false: else runs (EnforceSorting, EnforceDistribution, EnforceDistribution) + /// * `PREFER_EXISTING_SORT` (optional) - if true, will not repartition / resort data if it is already sorted + /// * `TARGET_PARTITIONS` (optional) - number of partitions to repartition to + /// * `REPARTITION_FILE_SCANS` (optional) - if true, will repartition file scans + /// * `REPARTITION_FILE_MIN_SIZE` (optional) - minimum file size to repartition macro_rules! assert_optimized { ($EXPECTED_LINES: expr, $PLAN: expr, $FIRST_ENFORCE_DIST: expr) => { assert_optimized!($EXPECTED_LINES, $PLAN, $FIRST_ENFORCE_DIST, false, 10, false, 1024); }; - ($EXPECTED_LINES: expr, $PLAN: expr, $FIRST_ENFORCE_DIST: expr, $BOUNDED_ORDER_PRESERVING_VARIANTS: expr) => { - assert_optimized!($EXPECTED_LINES, $PLAN, $FIRST_ENFORCE_DIST, $BOUNDED_ORDER_PRESERVING_VARIANTS, 10, false, 1024); + ($EXPECTED_LINES: expr, $PLAN: expr, $FIRST_ENFORCE_DIST: expr, $PREFER_EXISTING_SORT: expr) => { + assert_optimized!($EXPECTED_LINES, $PLAN, $FIRST_ENFORCE_DIST, $PREFER_EXISTING_SORT, 10, false, 1024); }; - ($EXPECTED_LINES: expr, $PLAN: expr, $FIRST_ENFORCE_DIST: expr, $BOUNDED_ORDER_PRESERVING_VARIANTS: expr, $TARGET_PARTITIONS: expr, $REPARTITION_FILE_SCANS: expr, $REPARTITION_FILE_MIN_SIZE: expr) => { + ($EXPECTED_LINES: expr, $PLAN: expr, $FIRST_ENFORCE_DIST: expr, $PREFER_EXISTING_SORT: expr, $TARGET_PARTITIONS: expr, $REPARTITION_FILE_SCANS: expr, $REPARTITION_FILE_MIN_SIZE: expr) => { let expected_lines: Vec<&str> = $EXPECTED_LINES.iter().map(|s| *s).collect(); let mut config = ConfigOptions::new(); config.execution.target_partitions = $TARGET_PARTITIONS; config.optimizer.repartition_file_scans = $REPARTITION_FILE_SCANS; config.optimizer.repartition_file_min_size = $REPARTITION_FILE_MIN_SIZE; - config.optimizer.bounded_order_preserving_variants = $BOUNDED_ORDER_PRESERVING_VARIANTS; + config.optimizer.prefer_existing_sort = $PREFER_EXISTING_SORT; // NOTE: These tests verify the joint `EnforceDistribution` + `EnforceSorting` cascade // because they were written prior to the separation of `BasicEnforcement` into // `EnforceSorting` and `EnforceDistribution`. // TODO: Orthogonalize the tests here just to verify `EnforceDistribution` and create // new tests for the cascade. + + // Add the ancillary output requirements operator at the start: + let optimizer = OutputRequirements::new_add_mode(); + let optimized = optimizer.optimize($PLAN.clone(), &config)?; + let optimized = if $FIRST_ENFORCE_DIST { // Run enforce distribution rule first: let optimizer = EnforceDistribution::new(); - let optimized = optimizer.optimize($PLAN.clone(), &config)?; + let optimized = optimizer.optimize(optimized, &config)?; // The rule should be idempotent. // Re-running this rule shouldn't introduce unnecessary operators. let optimizer = EnforceDistribution::new(); @@ -2133,7 +1986,7 @@ mod tests { } else { // Run the enforce sorting rule first: let optimizer = EnforceSorting::new(); - let optimized = optimizer.optimize($PLAN.clone(), &config)?; + let optimized = optimizer.optimize(optimized, &config)?; // Run enforce distribution rule: let optimizer = EnforceDistribution::new(); let optimized = optimizer.optimize(optimized, &config)?; @@ -2144,6 +1997,10 @@ mod tests { optimized }; + // Remove the ancillary output requirements operator when done: + let optimizer = OutputRequirements::new_remove_mode(); + let optimized = optimizer.optimize(optimized, &config)?; + // Now format correctly let plan = displayable(optimized.as_ref()).indent(true).to_string(); let actual_lines = trim_plan_display(&plan); @@ -2978,7 +2835,7 @@ mod tests { format!("SortMergeJoin: join_type={join_type}, on=[(a@0, c@2)]"); let expected = match join_type { - // Should include 6 RepartitionExecs 3 SortExecs + // Should include 6 RepartitionExecs (3 hash, 3 round-robin), 3 SortExecs JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => vec![ top_join_plan.as_str(), @@ -2997,9 +2854,18 @@ mod tests { "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", ], - // Should include 7 RepartitionExecs + // Should include 7 RepartitionExecs (4 hash, 3 round-robin), 4 SortExecs + // Since ordering of the left child is not preserved after SortMergeJoin + // when mode is Right, RgihtSemi, RightAnti, Full + // - We need to add one additional SortExec after SortMergeJoin in contrast the test cases + // when mode is Inner, Left, LeftSemi, LeftAnti + // Similarly, since partitioning of the left side is not preserved + // when mode is Right, RgihtSemi, RightAnti, Full + // - We need to add one additional Hash Repartition after SortMergeJoin in contrast the test + // cases when mode is Inner, Left, LeftSemi, LeftAnti _ => vec![ top_join_plan.as_str(), + // Below 2 operators are differences introduced, when join mode is changed "SortExec: expr=[a@0 ASC]", "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", join_plan.as_str(), @@ -3021,43 +2887,52 @@ mod tests { assert_optimized!(expected, top_join.clone(), true, true); let expected_first_sort_enforcement = match join_type { - // Should include 3 RepartitionExecs 3 SortExecs + // Should include 6 RepartitionExecs (3 hash, 3 round-robin), 3 SortExecs JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => vec![ top_join_plan.as_str(), join_plan.as_str(), - "SortPreservingRepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", + "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[a@0 ASC]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortPreservingRepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10", + "RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, preserve_order=true, sort_exprs=b1@1 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[b1@1 ASC]", "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortPreservingRepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", + "RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=c@2 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[c@2 ASC]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", ], - // Should include 8 RepartitionExecs (4 of them preserves ordering) + // Should include 8 RepartitionExecs (4 hash, 8 round-robin), 4 SortExecs + // Since ordering of the left child is not preserved after SortMergeJoin + // when mode is Right, RgihtSemi, RightAnti, Full + // - We need to add one additional SortExec after SortMergeJoin in contrast the test cases + // when mode is Inner, Left, LeftSemi, LeftAnti + // Similarly, since partitioning of the left side is not preserved + // when mode is Right, RgihtSemi, RightAnti, Full + // - We need to add one additional Hash Repartition and Roundrobin repartition after + // SortMergeJoin in contrast the test cases when mode is Inner, Left, LeftSemi, LeftAnti _ => vec![ top_join_plan.as_str(), - "SortPreservingRepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", + // Below 4 operators are differences introduced, when join mode is changed + "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[a@0 ASC]", "CoalescePartitionsExec", join_plan.as_str(), - "SortPreservingRepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", + "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[a@0 ASC]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortPreservingRepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10", + "RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, preserve_order=true, sort_exprs=b1@1 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[b1@1 ASC]", "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortPreservingRepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", + "RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=c@2 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[c@2 ASC]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", @@ -3083,7 +2958,7 @@ mod tests { format!("SortMergeJoin: join_type={join_type}, on=[(b1@6, c@2)]"); let expected = match join_type { - // Should include 3 RepartitionExecs and 3 SortExecs + // Should include 6 RepartitionExecs(3 hash, 3 round-robin) and 3 SortExecs JoinType::Inner | JoinType::Right => vec![ top_join_plan.as_str(), join_plan.as_str(), @@ -3101,8 +2976,8 @@ mod tests { "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", ], - // Should include 4 RepartitionExecs and 4 SortExecs - _ => vec![ + // Should include 7 RepartitionExecs (4 hash, 3 round-robin) and 4 SortExecs + JoinType::Left | JoinType::Full => vec![ top_join_plan.as_str(), "SortExec: expr=[b1@6 ASC]", "RepartitionExec: partitioning=Hash([b1@6], 10), input_partitions=10", @@ -3121,6 +2996,8 @@ mod tests { "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", ], + // this match arm cannot be reached + _ => unreachable!() }; assert_optimized!(expected, top_join.clone(), true, true); @@ -3129,42 +3006,44 @@ mod tests { JoinType::Inner | JoinType::Right => vec![ top_join_plan.as_str(), join_plan.as_str(), - "SortPreservingRepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", + "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[a@0 ASC]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortPreservingRepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10", + "RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, preserve_order=true, sort_exprs=b1@1 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[b1@1 ASC]", "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortPreservingRepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", + "RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=c@2 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[c@2 ASC]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", ], // Should include 8 RepartitionExecs (4 of them preserves order) and 4 SortExecs - _ => vec![ + JoinType::Left | JoinType::Full => vec![ top_join_plan.as_str(), - "SortPreservingRepartitionExec: partitioning=Hash([b1@6], 10), input_partitions=10", + "RepartitionExec: partitioning=Hash([b1@6], 10), input_partitions=10, preserve_order=true, sort_exprs=b1@6 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[b1@6 ASC]", "CoalescePartitionsExec", join_plan.as_str(), - "SortPreservingRepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", + "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[a@0 ASC]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortPreservingRepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10", + "RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, preserve_order=true, sort_exprs=b1@1 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[b1@1 ASC]", "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortPreservingRepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", + "RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=c@2 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[c@2 ASC]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", ], + // this match arm cannot be reached + _ => unreachable!() }; assert_optimized!( expected_first_sort_enforcement, @@ -3249,7 +3128,7 @@ mod tests { let expected_first_sort_enforcement = &[ "SortMergeJoin: join_type=Inner, on=[(b3@1, b2@1), (a3@0, a2@0)]", - "SortPreservingRepartitionExec: partitioning=Hash([b3@1, a3@0], 10), input_partitions=10", + "RepartitionExec: partitioning=Hash([b3@1, a3@0], 10), input_partitions=10, preserve_order=true, sort_exprs=b3@1 ASC,a3@0 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[b3@1 ASC,a3@0 ASC]", "CoalescePartitionsExec", @@ -3260,7 +3139,7 @@ mod tests { "AggregateExec: mode=Partial, gby=[b@1 as b1, a@0 as a1], aggr=[]", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortPreservingRepartitionExec: partitioning=Hash([b2@1, a2@0], 10), input_partitions=10", + "RepartitionExec: partitioning=Hash([b2@1, a2@0], 10), input_partitions=10, preserve_order=true, sort_exprs=b2@1 ASC,a2@0 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[b2@1 ASC,a2@0 ASC]", "CoalescePartitionsExec", @@ -3301,6 +3180,16 @@ mod tests { "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]", ]; assert_optimized!(expected, exec, true); + // In this case preserving ordering through order preserving operators is not desirable + // (according to flag: PREFER_EXISTING_SORT) + // hence in this case ordering lost during CoalescePartitionsExec and re-introduced with + // SortExec at the top. + let expected = &[ + "SortExec: expr=[a@0 ASC]", + "CoalescePartitionsExec", + "CoalesceBatchesExec: target_batch_size=4096", + "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]", + ]; assert_optimized!(expected, exec, false); Ok(()) } @@ -3384,6 +3273,7 @@ mod tests { } #[test] + fn repartition_unsorted_limit() -> Result<()> { let plan = limit_exec(filter_exec(parquet_exec())); @@ -3435,7 +3325,7 @@ mod tests { sort_required_exec(filter_exec(sort_exec(sort_key, parquet_exec(), false))); let expected = &[ - "SortRequiredExec", + "SortRequiredExec: [c@2 ASC]", "FilterExec: c@2 = 0", // We can use repartition here, ordering requirement by SortRequiredExec // is still satisfied. @@ -3541,6 +3431,12 @@ mod tests { ]; assert_optimized!(expected, plan.clone(), true); + + let expected = &[ + "SortExec: expr=[c@2 ASC]", + "CoalescePartitionsExec", + "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", + ]; assert_optimized!(expected, plan, false); Ok(()) } @@ -3565,6 +3461,14 @@ mod tests { ]; assert_optimized!(expected, plan.clone(), true); + + let expected = &[ + "SortExec: expr=[c@2 ASC]", + "CoalescePartitionsExec", + "UnionExec", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", + ]; assert_optimized!(expected, plan, false); Ok(()) } @@ -3583,7 +3487,7 @@ mod tests { // during repartitioning ordering is preserved let expected = &[ - "SortRequiredExec", + "SortRequiredExec: [c@2 ASC]", "FilterExec: c@2 = 0", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", @@ -3619,7 +3523,7 @@ mod tests { let expected = &[ "UnionExec", // union input 1: no repartitioning - "SortRequiredExec", + "SortRequiredExec: [c@2 ASC]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", // union input 2: should repartition "FilterExec: c@2 = 0", @@ -3687,16 +3591,13 @@ mod tests { ("c".to_string(), "c".to_string()), ]; // sorted input - let plan = sort_preserving_merge_exec( - sort_key.clone(), - projection_exec_with_alias( - parquet_exec_multiple_sorted(vec![sort_key]), - alias, - ), - ); + let plan = sort_required_exec(projection_exec_with_alias( + parquet_exec_multiple_sorted(vec![sort_key]), + alias, + )); let expected = &[ - "SortPreservingMergeExec: [c@2 ASC]", + "SortRequiredExec: [c@2 ASC]", // Since this projection is trivial, increasing parallelism is not beneficial "ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c]", "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", @@ -3740,14 +3641,14 @@ mod tests { fn repartition_transitively_past_sort_with_filter() -> Result<()> { let schema = schema(); let sort_key = vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + expr: col("a", &schema).unwrap(), options: SortOptions::default(), }]; let plan = sort_exec(sort_key, filter_exec(parquet_exec()), false); let expected = &[ - "SortPreservingMergeExec: [c@2 ASC]", - "SortExec: expr=[c@2 ASC]", + "SortPreservingMergeExec: [a@0 ASC]", + "SortExec: expr=[a@0 ASC]", // Expect repartition on the input to the sort (as it can benefit from additional parallelism) "FilterExec: c@2 = 0", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", @@ -3757,7 +3658,7 @@ mod tests { assert_optimized!(expected, plan.clone(), true); let expected_first_sort_enforcement = &[ - "SortExec: expr=[c@2 ASC]", + "SortExec: expr=[a@0 ASC]", "CoalescePartitionsExec", "FilterExec: c@2 = 0", // Expect repartition on the input of the filter (as it can benefit from additional parallelism) @@ -3769,26 +3670,31 @@ mod tests { } #[test] + #[cfg(feature = "parquet")] fn repartition_transitively_past_sort_with_projection_and_filter() -> Result<()> { let schema = schema(); let sort_key = vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + expr: col("a", &schema).unwrap(), options: SortOptions::default(), }]; let plan = sort_exec( sort_key, projection_exec_with_alias( filter_exec(parquet_exec()), - vec![("a".to_string(), "a".to_string())], + vec![ + ("a".to_string(), "a".to_string()), + ("b".to_string(), "b".to_string()), + ("c".to_string(), "c".to_string()), + ], ), false, ); let expected = &[ - "SortPreservingMergeExec: [c@2 ASC]", + "SortPreservingMergeExec: [a@0 ASC]", // Expect repartition on the input to the sort (as it can benefit from additional parallelism) - "SortExec: expr=[c@2 ASC]", - "ProjectionExec: expr=[a@0 as a]", + "SortExec: expr=[a@0 ASC]", + "ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c]", "FilterExec: c@2 = 0", // repartition is lowest down "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", @@ -3798,9 +3704,9 @@ mod tests { assert_optimized!(expected, plan.clone(), true); let expected_first_sort_enforcement = &[ - "SortExec: expr=[c@2 ASC]", + "SortExec: expr=[a@0 ASC]", "CoalescePartitionsExec", - "ProjectionExec: expr=[a@0 as a]", + "ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c]", "FilterExec: c@2 = 0", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", @@ -3833,6 +3739,56 @@ mod tests { Ok(()) } + #[test] + fn parallelization_multiple_files() -> Result<()> { + let schema = schema(); + let sort_key = vec![PhysicalSortExpr { + expr: col("a", &schema).unwrap(), + options: SortOptions::default(), + }]; + + let plan = filter_exec(parquet_exec_multiple_sorted(vec![sort_key])); + let plan = sort_required_exec(plan); + + // The groups must have only contiguous ranges of rows from the same file + // if any group has rows from multiple files, the data is no longer sorted destroyed + // https://github.com/apache/arrow-datafusion/issues/8451 + let expected = [ + "SortRequiredExec: [a@0 ASC]", + "FilterExec: c@2 = 0", + "ParquetExec: file_groups={3 groups: [[x:0..50], [y:0..100], [x:50..100]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]", ]; + let target_partitions = 3; + let repartition_size = 1; + assert_optimized!( + expected, + plan, + true, + true, + target_partitions, + true, + repartition_size + ); + + let expected = [ + "SortRequiredExec: [a@0 ASC]", + "FilterExec: c@2 = 0", + "ParquetExec: file_groups={8 groups: [[x:0..25], [y:0..25], [x:25..50], [y:25..50], [x:50..75], [y:50..75], [x:75..100], [y:75..100]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]", + ]; + let target_partitions = 8; + let repartition_size = 1; + assert_optimized!( + expected, + plan, + true, + true, + target_partitions, + true, + repartition_size + ); + + Ok(()) + } + #[test] /// CsvExec on compressed csv file will not be partitioned /// (Not able to decompress chunked csv file) @@ -3876,12 +3832,11 @@ mod tests { "x".to_string(), 100, )]], - statistics: Statistics::default(), + statistics: Statistics::new_unknown(&schema()), projection: None, limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, false, b',', @@ -3909,14 +3864,14 @@ mod tests { "RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2", "AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", // Plan already has two partitions - "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e]", + "ParquetExec: file_groups={2 groups: [[x:0..100], [y:0..100]]}, projection=[a, b, c, d, e]", ]; let expected_csv = [ "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", "RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2", "AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", // Plan already has two partitions - "CsvExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], has_header=false", + "CsvExec: file_groups={2 groups: [[x:0..100], [y:0..100]]}, projection=[a, b, c, d, e], has_header=false", ]; assert_optimized!(expected_parquet, plan_parquet, true, false, 2, true, 10); @@ -4186,11 +4141,11 @@ mod tests { // no parallelization, because SortRequiredExec doesn't benefit from increased parallelism let expected_parquet = &[ - "SortRequiredExec", + "SortRequiredExec: [c@2 ASC]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", ]; let expected_csv = &[ - "SortRequiredExec", + "SortRequiredExec: [c@2 ASC]", "CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], has_header=false", ]; @@ -4318,11 +4273,11 @@ mod tests { let expected = &[ "SortPreservingMergeExec: [c@2 ASC]", "FilterExec: c@2 = 0", - "SortPreservingRepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true, sort_exprs=c@2 ASC", "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", ]; - // last flag sets config.optimizer.bounded_order_preserving_variants + // last flag sets config.optimizer.PREFER_EXISTING_SORT assert_optimized!(expected, physical_plan.clone(), true, true); assert_optimized!(expected, physical_plan, false, true); @@ -4331,6 +4286,38 @@ mod tests { #[test] fn do_not_preserve_ordering_through_repartition() -> Result<()> { + let schema = schema(); + let sort_key = vec![PhysicalSortExpr { + expr: col("a", &schema).unwrap(), + options: SortOptions::default(), + }]; + let input = parquet_exec_multiple_sorted(vec![sort_key.clone()]); + let physical_plan = sort_preserving_merge_exec(sort_key, filter_exec(input)); + + let expected = &[ + "SortPreservingMergeExec: [a@0 ASC]", + "SortExec: expr=[a@0 ASC]", + "FilterExec: c@2 = 0", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", + "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]", + ]; + + assert_optimized!(expected, physical_plan.clone(), true); + + let expected = &[ + "SortExec: expr=[a@0 ASC]", + "CoalescePartitionsExec", + "FilterExec: c@2 = 0", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", + "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]", + ]; + assert_optimized!(expected, physical_plan, false); + + Ok(()) + } + + #[test] + fn no_need_for_sort_after_filter() -> Result<()> { let schema = schema(); let sort_key = vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), @@ -4340,8 +4327,9 @@ mod tests { let physical_plan = sort_preserving_merge_exec(sort_key, filter_exec(input)); let expected = &[ - "SortPreservingMergeExec: [c@2 ASC]", - "SortExec: expr=[c@2 ASC]", + // After CoalescePartitionsExec c is still constant. Hence c@2 ASC ordering is already satisfied. + "CoalescePartitionsExec", + // Since after this stage c is constant. c@2 ASC ordering is already satisfied. "FilterExec: c@2 = 0", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", @@ -4377,6 +4365,15 @@ mod tests { ]; assert_optimized!(expected, physical_plan.clone(), true); + + let expected = &[ + "SortExec: expr=[a@0 ASC]", + "CoalescePartitionsExec", + "SortExec: expr=[a@0 ASC]", + "FilterExec: c@2 = 0", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", + "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", + ]; assert_optimized!(expected, physical_plan, false); Ok(()) @@ -4404,6 +4401,82 @@ mod tests { Ok(()) } + #[test] + fn do_not_put_sort_when_input_is_invalid() -> Result<()> { + let schema = schema(); + let sort_key = vec![PhysicalSortExpr { + expr: col("a", &schema).unwrap(), + options: SortOptions::default(), + }]; + let input = parquet_exec(); + let physical_plan = sort_required_exec_with_req(filter_exec(input), sort_key); + let expected = &[ + // Ordering requirement of sort required exec is NOT satisfied + // by existing ordering at the source. + "SortRequiredExec: [a@0 ASC]", + "FilterExec: c@2 = 0", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + assert_plan_txt!(expected, physical_plan); + + let expected = &[ + "SortRequiredExec: [a@0 ASC]", + // Since at the start of the rule ordering requirement is not satisfied + // EnforceDistribution rule doesn't satisfy this requirement either. + "FilterExec: c@2 = 0", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + + let mut config = ConfigOptions::new(); + config.execution.target_partitions = 10; + config.optimizer.enable_round_robin_repartition = true; + config.optimizer.prefer_existing_sort = false; + let distribution_plan = + EnforceDistribution::new().optimize(physical_plan, &config)?; + assert_plan_txt!(expected, distribution_plan); + + Ok(()) + } + + #[test] + fn put_sort_when_input_is_valid() -> Result<()> { + let schema = schema(); + let sort_key = vec![PhysicalSortExpr { + expr: col("a", &schema).unwrap(), + options: SortOptions::default(), + }]; + let input = parquet_exec_multiple_sorted(vec![sort_key.clone()]); + let physical_plan = sort_required_exec_with_req(filter_exec(input), sort_key); + + let expected = &[ + // Ordering requirement of sort required exec is satisfied + // by existing ordering at the source. + "SortRequiredExec: [a@0 ASC]", + "FilterExec: c@2 = 0", + "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]", + ]; + assert_plan_txt!(expected, physical_plan); + + let expected = &[ + // Since at the start of the rule ordering requirement is satisfied + // EnforceDistribution rule satisfy this requirement also. + "SortRequiredExec: [a@0 ASC]", + "FilterExec: c@2 = 0", + "ParquetExec: file_groups={10 groups: [[x:0..20], [y:0..20], [x:20..40], [y:20..40], [x:40..60], [y:40..60], [x:60..80], [y:60..80], [x:80..100], [y:80..100]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]", + ]; + + let mut config = ConfigOptions::new(); + config.execution.target_partitions = 10; + config.optimizer.enable_round_robin_repartition = true; + config.optimizer.prefer_existing_sort = false; + let distribution_plan = + EnforceDistribution::new().optimize(physical_plan, &config)?; + assert_plan_txt!(expected, distribution_plan); + + Ok(()) + } + #[test] fn do_not_add_unnecessary_hash() -> Result<()> { let schema = schema(); @@ -4458,4 +4531,51 @@ mod tests { Ok(()) } + + #[test] + fn optimize_away_unnecessary_repartition() -> Result<()> { + let physical_plan = coalesce_partitions_exec(repartition_exec(parquet_exec())); + let expected = &[ + "CoalescePartitionsExec", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + plans_matches_expected!(expected, physical_plan.clone()); + + let expected = + &["ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]"]; + + assert_optimized!(expected, physical_plan.clone(), true); + assert_optimized!(expected, physical_plan, false); + + Ok(()) + } + + #[test] + fn optimize_away_unnecessary_repartition2() -> Result<()> { + let physical_plan = filter_exec(repartition_exec(coalesce_partitions_exec( + filter_exec(repartition_exec(parquet_exec())), + ))); + let expected = &[ + "FilterExec: c@2 = 0", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " CoalescePartitionsExec", + " FilterExec: c@2 = 0", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + plans_matches_expected!(expected, physical_plan.clone()); + + let expected = &[ + "FilterExec: c@2 = 0", + "FilterExec: c@2 = 0", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + + assert_optimized!(expected, physical_plan.clone(), true); + assert_optimized!(expected, physical_plan, false); + + Ok(()) + } } diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index a149330181d9..f609ddea66cf 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -17,8 +17,8 @@ //! EnforceSorting optimizer rule inspects the physical plan with respect //! to local sorting requirements and does the following: -//! - Adds a [SortExec] when a requirement is not met, -//! - Removes an already-existing [SortExec] if it is possible to prove +//! - Adds a [`SortExec`] when a requirement is not met, +//! - Removes an already-existing [`SortExec`] if it is possible to prove //! that this sort is unnecessary //! The rule can work on valid *and* invalid physical plans with respect to //! sorting requirements, but always produces a valid physical plan in this sense. @@ -34,6 +34,7 @@ //! in the physical plan. The first sort is unnecessary since its result is overwritten //! by another [`SortExec`]. Therefore, this rule removes it from the physical plan. +use std::borrow::Cow; use std::sync::Arc; use crate::config::ConfigOptions; @@ -44,23 +45,23 @@ use crate::physical_optimizer::replace_with_order_preserving_variants::{ use crate::physical_optimizer::sort_pushdown::{pushdown_sorts, SortPushDown}; use crate::physical_optimizer::utils::{ add_sort_above, is_coalesce_partitions, is_limit, is_repartition, is_sort, - is_sort_preserving_merge, is_union, is_window, ExecTree, + is_sort_preserving_merge, is_union, is_window, }; use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use crate::physical_plan::windows::{ - get_best_fitting_window, BoundedWindowAggExec, PartitionSearchMode, WindowAggExec, + get_best_fitting_window, BoundedWindowAggExec, WindowAggExec, +}; +use crate::physical_plan::{ + with_new_children_if_necessary, Distribution, ExecutionPlan, InputOrderMode, }; -use crate::physical_plan::{with_new_children_if_necessary, Distribution, ExecutionPlan}; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{plan_err, DataFusionError}; -use datafusion_physical_expr::utils::{ - ordering_satisfy, ordering_satisfy_requirement_concrete, -}; use datafusion_physical_expr::{PhysicalSortExpr, PhysicalSortRequirement}; +use datafusion_physical_plan::repartition::RepartitionExec; use itertools::izip; @@ -81,231 +82,172 @@ impl EnforceSorting { #[derive(Debug, Clone)] struct PlanWithCorrespondingSort { plan: Arc, - // For every child, keep a subtree of `ExecutionPlan`s starting from the - // child until the `SortExec`(s) -- could be multiple for n-ary plans like - // Union -- that determine the output ordering of the child. If the child - // has no connection to any sort, simply store None (and not a subtree). - sort_onwards: Vec>, + // For every child, track `ExecutionPlan`s starting from the child until + // the `SortExec`(s). If the child has no connection to any sort, it simply + // stores false. + sort_connection: bool, + children_nodes: Vec, } impl PlanWithCorrespondingSort { fn new(plan: Arc) -> Self { - let length = plan.children().len(); - PlanWithCorrespondingSort { + let children = plan.children(); + Self { plan, - sort_onwards: vec![None; length], + sort_connection: false, + children_nodes: children.into_iter().map(Self::new).collect(), } } - fn new_from_children_nodes( - children_nodes: Vec, + fn update_children( parent_plan: Arc, + mut children_nodes: Vec, ) -> Result { - let children_plans = children_nodes - .iter() - .map(|item| item.plan.clone()) - .collect::>(); - let sort_onwards = children_nodes - .into_iter() - .enumerate() - .map(|(idx, item)| { - let plan = &item.plan; - // Leaves of `sort_onwards` are `SortExec` operators, which impose - // an ordering. This tree collects all the intermediate executors - // that maintain this ordering. If we just saw a order imposing - // operator, we reset the tree and start accumulating. - if is_sort(plan) { - return Some(ExecTree::new(item.plan, idx, vec![])); - } else if is_limit(plan) { - // There is no sort linkage for this path, it starts at a limit. - return None; - } + for node in children_nodes.iter_mut() { + let plan = &node.plan; + // Leaves of `sort_onwards` are `SortExec` operators, which impose + // an ordering. This tree collects all the intermediate executors + // that maintain this ordering. If we just saw a order imposing + // operator, we reset the tree and start accumulating. + node.sort_connection = if is_sort(plan) { + // Initiate connection + true + } else if is_limit(plan) { + // There is no sort linkage for this path, it starts at a limit. + false + } else { let is_spm = is_sort_preserving_merge(plan); let required_orderings = plan.required_input_ordering(); let flags = plan.maintains_input_order(); - let children = izip!(flags, item.sort_onwards, required_orderings) - .filter_map(|(maintains, element, required_ordering)| { - if (required_ordering.is_none() && maintains) || is_spm { - element - } else { - None - } - }) - .collect::>(); - if !children.is_empty() { - // Add parent node to the tree if there is at least one - // child with a subtree: - Some(ExecTree::new(item.plan, idx, children)) - } else { - // There is no sort linkage for this child, do nothing. - None - } - }) - .collect(); + // Add parent node to the tree if there is at least one + // child with a sort connection: + izip!(flags, required_orderings).any(|(maintains, required_ordering)| { + let propagates_ordering = + (maintains && required_ordering.is_none()) || is_spm; + let connected_to_sort = + node.children_nodes.iter().any(|item| item.sort_connection); + propagates_ordering && connected_to_sort + }) + } + } + let children_plans = children_nodes + .iter() + .map(|item| item.plan.clone()) + .collect::>(); let plan = with_new_children_if_necessary(parent_plan, children_plans)?.into(); - Ok(PlanWithCorrespondingSort { plan, sort_onwards }) - } - fn children(&self) -> Vec { - self.plan - .children() - .into_iter() - .map(|child| PlanWithCorrespondingSort::new(child)) - .collect() + Ok(Self { + plan, + sort_connection: false, + children_nodes, + }) } } impl TreeNode for PlanWithCorrespondingSort { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - let children = self.children(); - for child in children { - match op(&child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.children_nodes.iter().map(Cow::Borrowed).collect() } - fn map_children(self, transform: F) -> Result + fn map_children(mut self, transform: F) -> Result where F: FnMut(Self) -> Result, { - let children = self.children(); - if children.is_empty() { - Ok(self) - } else { - let children_nodes = children + if !self.children_nodes.is_empty() { + self.children_nodes = self + .children_nodes .into_iter() .map(transform) - .collect::>>()?; - PlanWithCorrespondingSort::new_from_children_nodes(children_nodes, self.plan) + .collect::>()?; + self.plan = with_new_children_if_necessary( + self.plan, + self.children_nodes.iter().map(|c| c.plan.clone()).collect(), + )? + .into(); } + Ok(self) } } -/// This object is used within the [EnforceSorting] rule to track the closest +/// This object is used within the [`EnforceSorting`] rule to track the closest /// [`CoalescePartitionsExec`] descendant(s) for every child of a plan. #[derive(Debug, Clone)] struct PlanWithCorrespondingCoalescePartitions { plan: Arc, - // For every child, keep a subtree of `ExecutionPlan`s starting from the - // child until the `CoalescePartitionsExec`(s) -- could be multiple for - // n-ary plans like Union -- that affect the output partitioning of the - // child. If the child has no connection to any `CoalescePartitionsExec`, - // simply store None (and not a subtree). - coalesce_onwards: Vec>, + // Stores whether the plan is a `CoalescePartitionsExec` or it is connected to + // a `CoalescePartitionsExec` via its children. + coalesce_connection: bool, + children_nodes: Vec, } impl PlanWithCorrespondingCoalescePartitions { + /// Creates an empty tree with empty connections. fn new(plan: Arc) -> Self { - let length = plan.children().len(); - PlanWithCorrespondingCoalescePartitions { + let children = plan.children(); + Self { plan, - coalesce_onwards: vec![None; length], + coalesce_connection: false, + children_nodes: children.into_iter().map(Self::new).collect(), } } - fn new_from_children_nodes( - children_nodes: Vec, - parent_plan: Arc, - ) -> Result { - let children_plans = children_nodes + fn update_children(mut self) -> Result { + self.coalesce_connection = if self.plan.children().is_empty() { + // Plan has no children, it cannot be a `CoalescePartitionsExec`. + false + } else if is_coalesce_partitions(&self.plan) { + // Initiate a connection + true + } else { + self.children_nodes + .iter() + .enumerate() + .map(|(idx, node)| { + // Only consider operators that don't require a + // single partition, and connected to any coalesce + node.coalesce_connection + && !matches!( + self.plan.required_input_distribution()[idx], + Distribution::SinglePartition + ) + // If all children are None. There is nothing to track, set connection false. + }) + .any(|c| c) + }; + + let children_plans = self + .children_nodes .iter() .map(|item| item.plan.clone()) .collect(); - let coalesce_onwards = children_nodes - .into_iter() - .enumerate() - .map(|(idx, item)| { - // Leaves of the `coalesce_onwards` tree are `CoalescePartitionsExec` - // operators. This tree collects all the intermediate executors that - // maintain a single partition. If we just saw a `CoalescePartitionsExec` - // operator, we reset the tree and start accumulating. - let plan = item.plan; - if plan.children().is_empty() { - // Plan has no children, there is nothing to propagate. - None - } else if is_coalesce_partitions(&plan) { - Some(ExecTree::new(plan, idx, vec![])) - } else { - let children = item - .coalesce_onwards - .into_iter() - .flatten() - .filter(|item| { - // Only consider operators that don't require a - // single partition. - !matches!( - plan.required_input_distribution()[item.idx], - Distribution::SinglePartition - ) - }) - .collect::>(); - if children.is_empty() { - None - } else { - Some(ExecTree::new(plan, idx, children)) - } - } - }) - .collect(); - let plan = with_new_children_if_necessary(parent_plan, children_plans)?.into(); - Ok(PlanWithCorrespondingCoalescePartitions { - plan, - coalesce_onwards, - }) - } - - fn children(&self) -> Vec { - self.plan - .children() - .into_iter() - .map(|child| PlanWithCorrespondingCoalescePartitions::new(child)) - .collect() + self.plan = with_new_children_if_necessary(self.plan, children_plans)?.into(); + Ok(self) } } impl TreeNode for PlanWithCorrespondingCoalescePartitions { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - let children = self.children(); - for child in children { - match op(&child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.children_nodes.iter().map(Cow::Borrowed).collect() } - fn map_children(self, transform: F) -> Result + fn map_children(mut self, transform: F) -> Result where F: FnMut(Self) -> Result, { - let children = self.children(); - if children.is_empty() { - Ok(self) - } else { - let children_nodes = children + if !self.children_nodes.is_empty() { + self.children_nodes = self + .children_nodes .into_iter() .map(transform) - .collect::>>()?; - PlanWithCorrespondingCoalescePartitions::new_from_children_nodes( - children_nodes, + .collect::>()?; + self.plan = with_new_children_if_necessary( self.plan, - ) + self.children_nodes.iter().map(|c| c.plan.clone()).collect(), + )? + .into(); } + Ok(self) } } @@ -332,6 +274,7 @@ impl PhysicalOptimizerRule for EnforceSorting { } else { adjusted.plan }; + let plan_with_pipeline_fixer = OrderPreservationContext::new(new_plan); let updated_plan = plan_with_pipeline_fixer.transform_up(&|plan_with_pipeline_fixer| { @@ -345,7 +288,8 @@ impl PhysicalOptimizerRule for EnforceSorting { // Execute a top-down traversal to exploit sort push-down opportunities // missed by the bottom-up traversal: - let sort_pushdown = SortPushDown::init(updated_plan.plan); + let mut sort_pushdown = SortPushDown::new(updated_plan.plan); + sort_pushdown.assign_initial_requirements(); let adjusted = sort_pushdown.transform_down(&pushdown_sorts)?; Ok(adjusted.plan) } @@ -376,16 +320,21 @@ impl PhysicalOptimizerRule for EnforceSorting { fn parallelize_sorts( requirements: PlanWithCorrespondingCoalescePartitions, ) -> Result> { - let plan = requirements.plan; - let mut coalesce_onwards = requirements.coalesce_onwards; - if plan.children().is_empty() || coalesce_onwards[0].is_none() { + let PlanWithCorrespondingCoalescePartitions { + mut plan, + coalesce_connection, + mut children_nodes, + } = requirements.update_children()?; + + if plan.children().is_empty() || !children_nodes[0].coalesce_connection { // We only take an action when the plan is either a SortExec, a // SortPreservingMergeExec or a CoalescePartitionsExec, and they // all have a single child. Therefore, if the first child is `None`, // we can return immediately. return Ok(Transformed::No(PlanWithCorrespondingCoalescePartitions { plan, - coalesce_onwards, + coalesce_connection, + children_nodes, })); } else if (is_sort(&plan) || is_sort_preserving_merge(&plan)) && plan.output_partitioning().partition_count() <= 1 @@ -395,30 +344,30 @@ fn parallelize_sorts( // executors don't require single partition), then we can replace // the CoalescePartitionsExec + Sort cascade with a SortExec + // SortPreservingMergeExec cascade to parallelize sorting. - let mut prev_layer = plan.clone(); - update_child_to_remove_coalesce(&mut prev_layer, &mut coalesce_onwards[0])?; let (sort_exprs, fetch) = get_sort_exprs(&plan)?; - add_sort_above(&mut prev_layer, sort_exprs.to_vec(), fetch)?; - let spm = SortPreservingMergeExec::new(sort_exprs.to_vec(), prev_layer) - .with_fetch(fetch); - return Ok(Transformed::Yes(PlanWithCorrespondingCoalescePartitions { - plan: Arc::new(spm), - coalesce_onwards: vec![None], - })); + let sort_reqs = PhysicalSortRequirement::from_sort_exprs(sort_exprs); + let sort_exprs = sort_exprs.to_vec(); + update_child_to_remove_coalesce(&mut plan, &mut children_nodes[0])?; + add_sort_above(&mut plan, &sort_reqs, fetch); + let spm = SortPreservingMergeExec::new(sort_exprs, plan).with_fetch(fetch); + + return Ok(Transformed::Yes( + PlanWithCorrespondingCoalescePartitions::new(Arc::new(spm)), + )); } else if is_coalesce_partitions(&plan) { // There is an unnecessary `CoalescePartitionsExec` in the plan. - let mut prev_layer = plan.clone(); - update_child_to_remove_coalesce(&mut prev_layer, &mut coalesce_onwards[0])?; - let new_plan = plan.with_new_children(vec![prev_layer])?; - return Ok(Transformed::Yes(PlanWithCorrespondingCoalescePartitions { - plan: new_plan, - coalesce_onwards: vec![None], - })); + update_child_to_remove_coalesce(&mut plan, &mut children_nodes[0])?; + + let new_plan = Arc::new(CoalescePartitionsExec::new(plan)) as _; + return Ok(Transformed::Yes( + PlanWithCorrespondingCoalescePartitions::new(new_plan), + )); } Ok(Transformed::Yes(PlanWithCorrespondingCoalescePartitions { plan, - coalesce_onwards, + coalesce_connection, + children_nodes, })) } @@ -427,99 +376,107 @@ fn parallelize_sorts( fn ensure_sorting( requirements: PlanWithCorrespondingSort, ) -> Result> { + let requirements = PlanWithCorrespondingSort::update_children( + requirements.plan, + requirements.children_nodes, + )?; + // Perform naive analysis at the beginning -- remove already-satisfied sorts: if requirements.plan.children().is_empty() { return Ok(Transformed::No(requirements)); } - let plan = requirements.plan; - let mut children = plan.children(); - let mut sort_onwards = requirements.sort_onwards; - if let Some(result) = analyze_immediate_sort_removal(&plan, &sort_onwards) { + if let Some(result) = analyze_immediate_sort_removal(&requirements) { return Ok(Transformed::Yes(result)); } - for (idx, (child, sort_onwards, required_ordering)) in izip!( - children.iter_mut(), - sort_onwards.iter_mut(), - plan.required_input_ordering() - ) - .enumerate() + + let plan = requirements.plan; + let mut children_nodes = requirements.children_nodes; + + for (idx, (child_node, required_ordering)) in + izip!(children_nodes.iter_mut(), plan.required_input_ordering()).enumerate() { - let physical_ordering = child.output_ordering(); + let mut child_plan = child_node.plan.clone(); + let physical_ordering = child_plan.output_ordering(); match (required_ordering, physical_ordering) { - (Some(required_ordering), Some(physical_ordering)) => { - if !ordering_satisfy_requirement_concrete( - physical_ordering, - &required_ordering, - || child.equivalence_properties(), - || child.ordering_equivalence_properties(), - ) { + (Some(required_ordering), Some(_)) => { + if !child_plan + .equivalence_properties() + .ordering_satisfy_requirement(&required_ordering) + { // Make sure we preserve the ordering requirements: - update_child_to_remove_unnecessary_sort(child, sort_onwards, &plan)?; - let sort_expr = - PhysicalSortRequirement::to_sort_exprs(required_ordering); - add_sort_above(child, sort_expr, None)?; - if is_sort(child) { - *sort_onwards = Some(ExecTree::new(child.clone(), idx, vec![])); - } else { - *sort_onwards = None; + update_child_to_remove_unnecessary_sort(idx, child_node, &plan)?; + add_sort_above(&mut child_plan, &required_ordering, None); + if is_sort(&child_plan) { + *child_node = PlanWithCorrespondingSort::update_children( + child_plan, + vec![child_node.clone()], + )?; + child_node.sort_connection = true; } } } (Some(required), None) => { // Ordering requirement is not met, we should add a `SortExec` to the plan. - let sort_expr = PhysicalSortRequirement::to_sort_exprs(required); - add_sort_above(child, sort_expr, None)?; - *sort_onwards = Some(ExecTree::new(child.clone(), idx, vec![])); + add_sort_above(&mut child_plan, &required, None); + *child_node = PlanWithCorrespondingSort::update_children( + child_plan, + vec![child_node.clone()], + )?; + child_node.sort_connection = true; } (None, Some(_)) => { // We have a `SortExec` whose effect may be neutralized by // another order-imposing operator. Remove this sort. if !plan.maintains_input_order()[idx] || is_union(&plan) { - update_child_to_remove_unnecessary_sort(child, sort_onwards, &plan)?; + update_child_to_remove_unnecessary_sort(idx, child_node, &plan)?; } } - (None, None) => {} + (None, None) => { + update_child_to_remove_unnecessary_sort(idx, child_node, &plan)?; + } } } // For window expressions, we can remove some sorts when we can // calculate the result in reverse: - if is_window(&plan) { - if let Some(tree) = &mut sort_onwards[0] { - if let Some(result) = analyze_window_sort_removal(tree, &plan)? { - return Ok(Transformed::Yes(result)); - } + if is_window(&plan) && children_nodes[0].sort_connection { + if let Some(result) = analyze_window_sort_removal(&mut children_nodes[0], &plan)? + { + return Ok(Transformed::Yes(result)); } } else if is_sort_preserving_merge(&plan) - && children[0].output_partitioning().partition_count() <= 1 + && children_nodes[0] + .plan + .output_partitioning() + .partition_count() + <= 1 { // This SortPreservingMergeExec is unnecessary, input already has a // single partition. - return Ok(Transformed::Yes(PlanWithCorrespondingSort { - plan: children[0].clone(), - sort_onwards: vec![sort_onwards[0].clone()], - })); + let child_node = children_nodes.swap_remove(0); + return Ok(Transformed::Yes(child_node)); } - Ok(Transformed::Yes(PlanWithCorrespondingSort { - plan: plan.with_new_children(children)?, - sort_onwards, - })) + Ok(Transformed::Yes( + PlanWithCorrespondingSort::update_children(plan, children_nodes)?, + )) } /// Analyzes a given [`SortExec`] (`plan`) to determine whether its input /// already has a finer ordering than it enforces. fn analyze_immediate_sort_removal( - plan: &Arc, - sort_onwards: &[Option], + node: &PlanWithCorrespondingSort, ) -> Option { + let PlanWithCorrespondingSort { + plan, + children_nodes, + .. + } = node; if let Some(sort_exec) = plan.as_any().downcast_ref::() { let sort_input = sort_exec.input().clone(); // If this sort is unnecessary, we should remove it: - if ordering_satisfy( - sort_input.output_ordering(), - sort_exec.output_ordering(), - || sort_input.equivalence_properties(), - || sort_input.ordering_equivalence_properties(), - ) { + if sort_input + .equivalence_properties() + .ordering_satisfy(sort_exec.output_ordering().unwrap_or(&[])) + { // Since we know that a `SortExec` has exactly one child, // we can use the zero index safely: return Some( @@ -532,20 +489,33 @@ fn analyze_immediate_sort_removal( sort_exec.expr().to_vec(), sort_input, )); - let new_tree = ExecTree::new( - new_plan.clone(), - 0, - sort_onwards.iter().flat_map(|e| e.clone()).collect(), - ); PlanWithCorrespondingSort { plan: new_plan, - sort_onwards: vec![Some(new_tree)], + // SortPreservingMergeExec has single child. + sort_connection: false, + children_nodes: children_nodes + .iter() + .cloned() + .map(|mut node| { + node.sort_connection = false; + node + }) + .collect(), } } else { // Remove the sort: PlanWithCorrespondingSort { plan: sort_input, - sort_onwards: sort_onwards.to_vec(), + sort_connection: false, + children_nodes: children_nodes[0] + .children_nodes + .iter() + .cloned() + .map(|mut node| { + node.sort_connection = false; + node + }) + .collect(), } }, ); @@ -557,16 +527,15 @@ fn analyze_immediate_sort_removal( /// Analyzes a [`WindowAggExec`] or a [`BoundedWindowAggExec`] to determine /// whether it may allow removing a sort. fn analyze_window_sort_removal( - sort_tree: &mut ExecTree, + sort_tree: &mut PlanWithCorrespondingSort, window_exec: &Arc, ) -> Result> { let requires_single_partition = matches!( - window_exec.required_input_distribution()[sort_tree.idx], + window_exec.required_input_distribution()[0], Distribution::SinglePartition ); - let mut window_child = - remove_corresponding_sort_from_sub_plan(sort_tree, requires_single_partition)?; - + remove_corresponding_sort_from_sub_plan(sort_tree, requires_single_partition)?; + let mut window_child = sort_tree.plan.clone(); let (window_expr, new_window) = if let Some(exec) = window_exec.as_any().downcast_ref::() { ( @@ -602,26 +571,22 @@ fn analyze_window_sort_removal( let reqs = window_exec .required_input_ordering() .swap_remove(0) - .unwrap_or(vec![]); - let sort_expr = PhysicalSortRequirement::to_sort_exprs(reqs); + .unwrap_or_default(); // Satisfy the ordering requirement so that the window can run: - add_sort_above(&mut window_child, sort_expr, None)?; + add_sort_above(&mut window_child, &reqs, None); let uses_bounded_memory = window_expr.iter().all(|e| e.uses_bounded_memory()); - let input_schema = window_child.schema(); let new_window = if uses_bounded_memory { Arc::new(BoundedWindowAggExec::try_new( window_expr.to_vec(), window_child, - input_schema, partitionby_exprs.to_vec(), - PartitionSearchMode::Sorted, + InputOrderMode::Sorted, )?) as _ } else { Arc::new(WindowAggExec::try_new( window_expr.to_vec(), window_child, - input_schema, partitionby_exprs.to_vec(), )?) as _ }; @@ -632,9 +597,9 @@ fn analyze_window_sort_removal( /// Updates child to remove the unnecessary [`CoalescePartitionsExec`] below it. fn update_child_to_remove_coalesce( child: &mut Arc, - coalesce_onwards: &mut Option, + coalesce_onwards: &mut PlanWithCorrespondingCoalescePartitions, ) -> Result<()> { - if let Some(coalesce_onwards) = coalesce_onwards { + if coalesce_onwards.coalesce_connection { *child = remove_corresponding_coalesce_in_sub_plan(coalesce_onwards, child)?; } Ok(()) @@ -642,91 +607,125 @@ fn update_child_to_remove_coalesce( /// Removes the [`CoalescePartitionsExec`] from the plan in `coalesce_onwards`. fn remove_corresponding_coalesce_in_sub_plan( - coalesce_onwards: &mut ExecTree, + coalesce_onwards: &mut PlanWithCorrespondingCoalescePartitions, parent: &Arc, ) -> Result> { - Ok(if is_coalesce_partitions(&coalesce_onwards.plan) { + if is_coalesce_partitions(&coalesce_onwards.plan) { // We can safely use the 0th index since we have a `CoalescePartitionsExec`. let mut new_plan = coalesce_onwards.plan.children()[0].clone(); while new_plan.output_partitioning() == parent.output_partitioning() && is_repartition(&new_plan) && is_repartition(parent) { - new_plan = new_plan.children()[0].clone() + new_plan = new_plan.children().swap_remove(0) } - new_plan + Ok(new_plan) } else { let plan = coalesce_onwards.plan.clone(); let mut children = plan.children(); - for item in &mut coalesce_onwards.children { - children[item.idx] = remove_corresponding_coalesce_in_sub_plan(item, &plan)?; + for (idx, node) in coalesce_onwards.children_nodes.iter_mut().enumerate() { + if node.coalesce_connection { + children[idx] = remove_corresponding_coalesce_in_sub_plan(node, &plan)?; + } } - plan.with_new_children(children)? - }) + plan.with_new_children(children) + } } /// Updates child to remove the unnecessary sort below it. fn update_child_to_remove_unnecessary_sort( - child: &mut Arc, - sort_onwards: &mut Option, + child_idx: usize, + sort_onwards: &mut PlanWithCorrespondingSort, parent: &Arc, ) -> Result<()> { - if let Some(sort_onwards) = sort_onwards { + if sort_onwards.sort_connection { let requires_single_partition = matches!( - parent.required_input_distribution()[sort_onwards.idx], + parent.required_input_distribution()[child_idx], Distribution::SinglePartition ); - *child = remove_corresponding_sort_from_sub_plan( - sort_onwards, - requires_single_partition, - )?; + remove_corresponding_sort_from_sub_plan(sort_onwards, requires_single_partition)?; } - *sort_onwards = None; + sort_onwards.sort_connection = false; Ok(()) } /// Removes the sort from the plan in `sort_onwards`. fn remove_corresponding_sort_from_sub_plan( - sort_onwards: &mut ExecTree, + sort_onwards: &mut PlanWithCorrespondingSort, requires_single_partition: bool, -) -> Result> { +) -> Result<()> { // A `SortExec` is always at the bottom of the tree. - let mut updated_plan = if is_sort(&sort_onwards.plan) { - sort_onwards.plan.children()[0].clone() + if is_sort(&sort_onwards.plan) { + *sort_onwards = sort_onwards.children_nodes.swap_remove(0); } else { - let plan = &sort_onwards.plan; - let mut children = plan.children(); - for item in &mut sort_onwards.children { - let requires_single_partition = matches!( - plan.required_input_distribution()[item.idx], - Distribution::SinglePartition - ); - children[item.idx] = - remove_corresponding_sort_from_sub_plan(item, requires_single_partition)?; + let PlanWithCorrespondingSort { + plan, + sort_connection: _, + children_nodes, + } = sort_onwards; + let mut any_connection = false; + for (child_idx, child_node) in children_nodes.iter_mut().enumerate() { + if child_node.sort_connection { + any_connection = true; + let requires_single_partition = matches!( + plan.required_input_distribution()[child_idx], + Distribution::SinglePartition + ); + remove_corresponding_sort_from_sub_plan( + child_node, + requires_single_partition, + )?; + } } + if any_connection || children_nodes.is_empty() { + *sort_onwards = PlanWithCorrespondingSort::update_children( + plan.clone(), + children_nodes.clone(), + )?; + } + let PlanWithCorrespondingSort { + plan, + children_nodes, + .. + } = sort_onwards; + // Replace with variants that do not preserve order. if is_sort_preserving_merge(plan) { - children[0].clone() - } else { - plan.clone().with_new_children(children)? + children_nodes.swap_remove(0); + *plan = plan.children().swap_remove(0); + } else if let Some(repartition) = plan.as_any().downcast_ref::() + { + *plan = Arc::new(RepartitionExec::try_new( + children_nodes[0].plan.clone(), + repartition.output_partitioning(), + )?) as _; } }; // Deleting a merging sort may invalidate distribution requirements. // Ensure that we stay compliant with such requirements: if requires_single_partition - && updated_plan.output_partitioning().partition_count() > 1 + && sort_onwards.plan.output_partitioning().partition_count() > 1 { // If there is existing ordering, to preserve ordering use SortPreservingMergeExec // instead of CoalescePartitionsExec. - if let Some(ordering) = updated_plan.output_ordering() { - updated_plan = Arc::new(SortPreservingMergeExec::new( + if let Some(ordering) = sort_onwards.plan.output_ordering() { + let plan = Arc::new(SortPreservingMergeExec::new( ordering.to_vec(), - updated_plan, - )); + sort_onwards.plan.clone(), + )) as _; + *sort_onwards = PlanWithCorrespondingSort::update_children( + plan, + vec![sort_onwards.clone()], + )?; } else { - updated_plan = Arc::new(CoalescePartitionsExec::new(updated_plan.clone())); + let plan = + Arc::new(CoalescePartitionsExec::new(sort_onwards.plan.clone())) as _; + *sort_onwards = PlanWithCorrespondingSort::update_children( + plan, + vec![sort_onwards.clone()], + )?; } } - Ok(updated_plan) + Ok(()) } /// Converts an [ExecutionPlan] trait object to a [PhysicalSortExpr] slice when possible. @@ -758,20 +757,20 @@ mod tests { coalesce_partitions_exec, filter_exec, global_limit_exec, hash_join_exec, limit_exec, local_limit_exec, memory_exec, parquet_exec, parquet_exec_sorted, repartition_exec, sort_exec, sort_expr, sort_expr_options, sort_merge_join_exec, - sort_preserving_merge_exec, union_exec, + sort_preserving_merge_exec, spr_repartition_exec, union_exec, }; - use crate::physical_optimizer::utils::get_plan_string; use crate::physical_plan::repartition::RepartitionExec; - use crate::physical_plan::{displayable, Partitioning}; + use crate::physical_plan::{displayable, get_plan_string, Partitioning}; use crate::prelude::{SessionConfig, SessionContext}; - use crate::test::csv_exec_sorted; + use crate::test::{csv_exec_ordered, csv_exec_sorted, stream_exec_ordered}; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::Result; use datafusion_expr::JoinType; - use datafusion_physical_expr::expressions::Column; - use datafusion_physical_expr::expressions::{col, NotExpr}; + use datafusion_physical_expr::expressions::{col, Column, NotExpr}; + + use rstest::rstest; fn create_test_schema() -> Result { let nullable_column = Field::new("nullable_col", DataType::Int32, true); @@ -810,7 +809,7 @@ mod tests { macro_rules! assert_optimized { ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr, $REPARTITION_SORTS: expr) => { let config = SessionConfig::new().with_repartition_sorts($REPARTITION_SORTS); - let session_ctx = SessionContext::with_config(config); + let session_ctx = SessionContext::new_with_config(config); let state = session_ctx.state(); let physical_plan = $PLAN; @@ -1635,14 +1634,16 @@ mod tests { // During the removal of `SortExec`s, it should be able to remove the // corresponding SortExecs together. Also, the inputs of these `SortExec`s // are not necessarily the same to be able to remove them. - let expected_input = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", + let expected_input = [ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", " SortPreservingMergeExec: [nullable_col@0 DESC NULLS LAST]", " UnionExec", " SortExec: expr=[nullable_col@0 DESC NULLS LAST]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC, non_nullable_col@1 ASC]", " SortExec: expr=[nullable_col@0 DESC NULLS LAST]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]"]; - let expected_optimized = ["WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(NULL) }]", + let expected_optimized = [ + "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(NULL) }]", " SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC, non_nullable_col@1 ASC]", @@ -2035,7 +2036,6 @@ mod tests { let orig_plan = Arc::new(SortExec::new(sort_exprs, repartition)) as Arc; let actual = get_plan_string(&orig_plan); - println!("{:?}", actual); let expected_input = vec![ "SortExec: expr=[nullable_col@0 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", @@ -2112,7 +2112,7 @@ mod tests { async fn test_with_lost_ordering_bounded() -> Result<()> { let schema = create_test_schema3()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, false); + let source = csv_exec_sorted(&schema, sort_exprs); let repartition_rr = repartition_exec(source); let repartition_hash = Arc::new(RepartitionExec::try_new( repartition_rr, @@ -2135,11 +2135,19 @@ mod tests { Ok(()) } + #[rstest] #[tokio::test] - async fn test_with_lost_ordering_unbounded() -> Result<()> { + async fn test_with_lost_ordering_unbounded_bounded( + #[values(false, true)] source_unbounded: bool, + ) -> Result<()> { let schema = create_test_schema3()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + // create either bounded or unbounded source + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_ordered(&schema, sort_exprs) + }; let repartition_rr = repartition_exec(source); let repartition_hash = Arc::new(RepartitionExec::try_new( repartition_rr, @@ -2148,42 +2156,71 @@ mod tests { let coalesce_partitions = coalesce_partitions_exec(repartition_hash); let physical_plan = sort_exec(vec![sort_expr("a", &schema)], coalesce_partitions); - let expected_input = ["SortExec: expr=[a@0 ASC]", + // Expected inputs unbounded and bounded + let expected_input_unbounded = vec![ + "SortExec: expr=[a@0 ASC]", " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC], has_header=false"]; - let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC]", - " SortPreservingRepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC], has_header=false"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); - Ok(()) - } + " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC]", + ]; + let expected_input_bounded = vec![ + "SortExec: expr=[a@0 ASC]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], has_header=true", + ]; - #[tokio::test] - async fn test_with_lost_ordering_unbounded_parallelize_off() -> Result<()> { - let schema = create_test_schema3()?; - let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); - let repartition_rr = repartition_exec(source); - let repartition_hash = Arc::new(RepartitionExec::try_new( - repartition_rr, - Partitioning::Hash(vec![col("c", &schema).unwrap()], 10), - )?) as _; - let coalesce_partitions = coalesce_partitions_exec(repartition_hash); - let physical_plan = sort_exec(vec![sort_expr("a", &schema)], coalesce_partitions); + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = vec![ + "SortPreservingMergeExec: [a@0 ASC]", + " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC]", + ]; - let expected_input = ["SortExec: expr=[a@0 ASC]", + // Expected bounded results with and without flag + let expected_optimized_bounded = vec![ + "SortExec: expr=[a@0 ASC]", " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC], has_header=false"]; - let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC]", - " SortPreservingRepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC], has_header=false"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, false); + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], has_header=true", + ]; + let expected_optimized_bounded_parallelize_sort = vec![ + "SortPreservingMergeExec: [a@0 ASC]", + " SortExec: expr=[a@0 ASC]", + " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], has_header=true", + ]; + let (expected_input, expected_optimized, expected_optimized_sort_parallelize) = + if source_unbounded { + ( + expected_input_unbounded, + expected_optimized_unbounded.clone(), + expected_optimized_unbounded, + ) + } else { + ( + expected_input_bounded, + expected_optimized_bounded, + expected_optimized_bounded_parallelize_sort, + ) + }; + assert_optimized!( + expected_input, + expected_optimized, + physical_plan.clone(), + false + ); + assert_optimized!( + expected_input, + expected_optimized_sort_parallelize, + physical_plan, + true + ); Ok(()) } @@ -2191,7 +2228,7 @@ mod tests { async fn test_do_not_pushdown_through_spm() -> Result<()> { let schema = create_test_schema3()?; let sort_exprs = vec![sort_expr("a", &schema), sort_expr("b", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs.clone(), false); + let source = csv_exec_sorted(&schema, sort_exprs.clone()); let repartition_rr = repartition_exec(source); let spm = sort_preserving_merge_exec(sort_exprs, repartition_rr); let physical_plan = sort_exec(vec![sort_expr("b", &schema)], spm); @@ -2212,7 +2249,7 @@ mod tests { async fn test_pushdown_through_spm() -> Result<()> { let schema = create_test_schema3()?; let sort_exprs = vec![sort_expr("a", &schema), sort_expr("b", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs.clone(), false); + let source = csv_exec_sorted(&schema, sort_exprs.clone()); let repartition_rr = repartition_exec(source); let spm = sort_preserving_merge_exec(sort_exprs, repartition_rr); let physical_plan = sort_exec( @@ -2235,4 +2272,36 @@ mod tests { assert_optimized!(expected_input, expected_optimized, physical_plan, false); Ok(()) } + + #[tokio::test] + async fn test_window_multi_layer_requirement() -> Result<()> { + let schema = create_test_schema3()?; + let sort_exprs = vec![sort_expr("a", &schema), sort_expr("b", &schema)]; + let source = csv_exec_sorted(&schema, vec![]); + let sort = sort_exec(sort_exprs.clone(), source); + let repartition = repartition_exec(sort); + let repartition = spr_repartition_exec(repartition); + let spm = sort_preserving_merge_exec(sort_exprs.clone(), repartition); + + let physical_plan = bounded_window_exec("a", sort_exprs, spm); + + let expected_input = [ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", + " SortPreservingMergeExec: [a@0 ASC,b@1 ASC]", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC,b@1 ASC", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " SortExec: expr=[a@0 ASC,b@1 ASC]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + ]; + let expected_optimized = [ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", + " SortExec: expr=[a@0 ASC,b@1 ASC]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan, false); + Ok(()) + } } diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index 4cff4a8f6c55..6b2fe24acf00 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -60,21 +60,29 @@ impl JoinSelection { // TODO: We need some performance test for Right Semi/Right Join swap to Left Semi/Left Join in case that the right side is smaller but not much smaller. // TODO: In PrestoSQL, the optimizer flips join sides only if one side is much smaller than the other by more than SIZE_DIFFERENCE_THRESHOLD times, by default is is 8 times. /// Checks statistics for join swap. -fn should_swap_join_order(left: &dyn ExecutionPlan, right: &dyn ExecutionPlan) -> bool { +fn should_swap_join_order( + left: &dyn ExecutionPlan, + right: &dyn ExecutionPlan, +) -> Result { // Get the left and right table's total bytes // If both the left and right tables contain total_byte_size statistics, // use `total_byte_size` to determine `should_swap_join_order`, else use `num_rows` - let (left_size, right_size) = match ( - left.statistics().total_byte_size, - right.statistics().total_byte_size, + let left_stats = left.statistics()?; + let right_stats = right.statistics()?; + // First compare `total_byte_size` of left and right side, + // if information in this field is insufficient fallback to the `num_rows` + match ( + left_stats.total_byte_size.get_value(), + right_stats.total_byte_size.get_value(), ) { - (Some(l), Some(r)) => (Some(l), Some(r)), - _ => (left.statistics().num_rows, right.statistics().num_rows), - }; - - match (left_size, right_size) { - (Some(l), Some(r)) => l > r, - _ => false, + (Some(l), Some(r)) => Ok(l > r), + _ => match ( + left_stats.num_rows.get_value(), + right_stats.num_rows.get_value(), + ) { + (Some(l), Some(r)) => Ok(l > r), + _ => Ok(false), + }, } } @@ -84,10 +92,14 @@ fn supports_collect_by_size( ) -> bool { // Currently we do not trust the 0 value from stats, due to stats collection might have bug // TODO check the logic in datasource::get_statistics_with_limit() - if let Some(size) = plan.statistics().total_byte_size { - size != 0 && size < collection_size_threshold - } else if let Some(row_count) = plan.statistics().num_rows { - row_count != 0 && row_count < collection_size_threshold + let Ok(stats) = plan.statistics() else { + return false; + }; + + if let Some(size) = stats.total_byte_size.get_value() { + *size != 0 && *size < collection_size_threshold + } else if let Some(row_count) = stats.num_rows.get_value() { + *row_count != 0 && *row_count < collection_size_threshold } else { false } @@ -294,7 +306,7 @@ fn try_collect_left( }; match (left_can_collect, right_can_collect) { (true, true) => { - if should_swap_join_order(&**left, &**right) + if should_swap_join_order(&**left, &**right)? && supports_swap(*hash_join.join_type()) { Ok(Some(swap_hash_join(hash_join, PartitionMode::CollectLeft)?)) @@ -333,7 +345,7 @@ fn try_collect_left( fn partitioned_hash_join(hash_join: &HashJoinExec) -> Result> { let left = hash_join.left(); let right = hash_join.right(); - if should_swap_join_order(&**left, &**right) && supports_swap(*hash_join.join_type()) + if should_swap_join_order(&**left, &**right)? && supports_swap(*hash_join.join_type()) { swap_hash_join(hash_join, PartitionMode::Partitioned) } else { @@ -373,7 +385,7 @@ fn statistical_join_selection_subrule( PartitionMode::Partitioned => { let left = hash_join.left(); let right = hash_join.right(); - if should_swap_join_order(&**left, &**right) + if should_swap_join_order(&**left, &**right)? && supports_swap(*hash_join.join_type()) { swap_hash_join(hash_join, PartitionMode::Partitioned).map(Some)? @@ -385,7 +397,7 @@ fn statistical_join_selection_subrule( } else if let Some(cross_join) = plan.as_any().downcast_ref::() { let left = cross_join.left(); let right = cross_join.right(); - if should_swap_join_order(&**left, &**right) { + if should_swap_join_order(&**left, &**right)? { let new_join = CrossJoinExec::new(Arc::clone(right), Arc::clone(left)); // TODO avoid adding ProjectionExec again and again, only adding Final Projection let proj: Arc = Arc::new(ProjectionExec::try_new( @@ -422,7 +434,7 @@ fn hash_join_convert_symmetric_subrule( config_options: &ConfigOptions, ) -> Option> { if let Some(hash_join) = input.plan.as_any().downcast_ref::() { - let ub_flags = &input.children_unbounded; + let ub_flags = input.children_unbounded(); let (left_unbounded, right_unbounded) = (ub_flags[0], ub_flags[1]); input.unbounded = left_unbounded || right_unbounded; let result = if left_unbounded && right_unbounded { @@ -499,7 +511,7 @@ fn hash_join_swap_subrule( _config_options: &ConfigOptions, ) -> Option> { if let Some(hash_join) = input.plan.as_any().downcast_ref::() { - let ub_flags = &input.children_unbounded; + let ub_flags = input.children_unbounded(); let (left_unbounded, right_unbounded) = (ub_flags[0], ub_flags[1]); input.unbounded = left_unbounded || right_unbounded; let result = if left_unbounded @@ -565,7 +577,7 @@ fn apply_subrules( } let is_unbounded = input .plan - .unbounded_output(&input.children_unbounded) + .unbounded_output(&input.children_unbounded()) // Treat the case where an operator can not run on unbounded data as // if it can and it outputs unbounded data. Do not raise an error yet. // Such operators may be fixed, adjusted or replaced later on during @@ -579,6 +591,8 @@ fn apply_subrules( #[cfg(test)] mod tests_statistical { + use std::sync::Arc; + use super::*; use crate::{ physical_plan::{ @@ -587,28 +601,26 @@ mod tests_statistical { test::StatisticsExec, }; - use std::sync::Arc; - use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::{JoinType, ScalarValue}; + use datafusion_common::{stats::Precision, JoinType, ScalarValue}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::PhysicalExpr; fn create_big_and_small() -> (Arc, Arc) { let big = Arc::new(StatisticsExec::new( Statistics { - num_rows: Some(10), - total_byte_size: Some(100000), - ..Default::default() + num_rows: Precision::Inexact(10), + total_byte_size: Precision::Inexact(100000), + column_statistics: vec![ColumnStatistics::new_unknown()], }, Schema::new(vec![Field::new("big_col", DataType::Int32, false)]), )); let small = Arc::new(StatisticsExec::new( Statistics { - num_rows: Some(100000), - total_byte_size: Some(10), - ..Default::default() + num_rows: Precision::Inexact(100000), + total_byte_size: Precision::Inexact(10), + column_statistics: vec![ColumnStatistics::new_unknown()], }, Schema::new(vec![Field::new("small_col", DataType::Int32, false)]), )); @@ -624,13 +636,19 @@ mod tests_statistical { min: Option, max: Option, distinct_count: Option, - ) -> Option> { - Some(vec![ColumnStatistics { - distinct_count, - min_value: min.map(|size| ScalarValue::UInt64(Some(size))), - max_value: max.map(|size| ScalarValue::UInt64(Some(size))), + ) -> Vec { + vec![ColumnStatistics { + distinct_count: distinct_count + .map(Precision::Inexact) + .unwrap_or(Precision::Absent), + min_value: min + .map(|size| Precision::Inexact(ScalarValue::UInt64(Some(size)))) + .unwrap_or(Precision::Absent), + max_value: max + .map(|size| Precision::Inexact(ScalarValue::UInt64(Some(size)))) + .unwrap_or(Precision::Absent), ..Default::default() - }]) + }] } /// Returns three plans with statistics of (min, max, distinct_count) @@ -644,39 +662,39 @@ mod tests_statistical { ) { let big = Arc::new(StatisticsExec::new( Statistics { - num_rows: Some(100_000), + num_rows: Precision::Inexact(100_000), column_statistics: create_column_stats( Some(0), Some(50_000), Some(50_000), ), - ..Default::default() + total_byte_size: Precision::Absent, }, Schema::new(vec![Field::new("big_col", DataType::Int32, false)]), )); let medium = Arc::new(StatisticsExec::new( Statistics { - num_rows: Some(10_000), + num_rows: Precision::Inexact(10_000), column_statistics: create_column_stats( Some(1000), Some(5000), Some(1000), ), - ..Default::default() + total_byte_size: Precision::Absent, }, Schema::new(vec![Field::new("medium_col", DataType::Int32, false)]), )); let small = Arc::new(StatisticsExec::new( Statistics { - num_rows: Some(1000), + num_rows: Precision::Inexact(1000), column_statistics: create_column_stats( Some(0), Some(100_000), Some(1000), ), - ..Default::default() + total_byte_size: Precision::Absent, }, Schema::new(vec![Field::new("small_col", DataType::Int32, false)]), )); @@ -725,10 +743,13 @@ mod tests_statistical { .downcast_ref::() .expect("The type of the plan should not be changed"); - assert_eq!(swapped_join.left().statistics().total_byte_size, Some(10)); assert_eq!( - swapped_join.right().statistics().total_byte_size, - Some(100000) + swapped_join.left().statistics().unwrap().total_byte_size, + Precision::Inexact(10) + ); + assert_eq!( + swapped_join.right().statistics().unwrap().total_byte_size, + Precision::Inexact(100000) ); } @@ -774,10 +795,13 @@ mod tests_statistical { .expect("The type of the plan should not be changed"); assert_eq!( - swapped_join.left().statistics().total_byte_size, - Some(100000) + swapped_join.left().statistics().unwrap().total_byte_size, + Precision::Inexact(100000) + ); + assert_eq!( + swapped_join.right().statistics().unwrap().total_byte_size, + Precision::Inexact(10) ); - assert_eq!(swapped_join.right().statistics().total_byte_size, Some(10)); } #[tokio::test] @@ -815,10 +839,13 @@ mod tests_statistical { assert_eq!(swapped_join.schema().fields().len(), 1); - assert_eq!(swapped_join.left().statistics().total_byte_size, Some(10)); assert_eq!( - swapped_join.right().statistics().total_byte_size, - Some(100000) + swapped_join.left().statistics().unwrap().total_byte_size, + Precision::Inexact(10) + ); + assert_eq!( + swapped_join.right().statistics().unwrap().total_byte_size, + Precision::Inexact(100000) ); assert_eq!(original_schema, swapped_join.schema()); @@ -893,9 +920,9 @@ mod tests_statistical { " HashJoinExec: mode=CollectLeft, join_type=Right, on=[(small_col@1, medium_col@0)]", " ProjectionExec: expr=[big_col@1 as big_col, small_col@0 as small_col]", " HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(small_col@0, big_col@0)]", - " StatisticsExec: col_count=1, row_count=Some(1000)", - " StatisticsExec: col_count=1, row_count=Some(100000)", - " StatisticsExec: col_count=1, row_count=Some(10000)", + " StatisticsExec: col_count=1, row_count=Inexact(1000)", + " StatisticsExec: col_count=1, row_count=Inexact(100000)", + " StatisticsExec: col_count=1, row_count=Inexact(10000)", "", ]; assert_optimized!(expected, join); @@ -927,10 +954,13 @@ mod tests_statistical { .downcast_ref::() .expect("The type of the plan should not be changed"); - assert_eq!(swapped_join.left().statistics().total_byte_size, Some(10)); assert_eq!( - swapped_join.right().statistics().total_byte_size, - Some(100000) + swapped_join.left().statistics().unwrap().total_byte_size, + Precision::Inexact(10) + ); + assert_eq!( + swapped_join.right().statistics().unwrap().total_byte_size, + Precision::Inexact(100000) ); } @@ -973,27 +1003,27 @@ mod tests_statistical { async fn test_join_selection_collect_left() { let big = Arc::new(StatisticsExec::new( Statistics { - num_rows: Some(10000000), - total_byte_size: Some(10000000), - ..Default::default() + num_rows: Precision::Inexact(10000000), + total_byte_size: Precision::Inexact(10000000), + column_statistics: vec![ColumnStatistics::new_unknown()], }, Schema::new(vec![Field::new("big_col", DataType::Int32, false)]), )); let small = Arc::new(StatisticsExec::new( Statistics { - num_rows: Some(10), - total_byte_size: Some(10), - ..Default::default() + num_rows: Precision::Inexact(10), + total_byte_size: Precision::Inexact(10), + column_statistics: vec![ColumnStatistics::new_unknown()], }, Schema::new(vec![Field::new("small_col", DataType::Int32, false)]), )); let empty = Arc::new(StatisticsExec::new( Statistics { - num_rows: None, - total_byte_size: None, - ..Default::default() + num_rows: Precision::Absent, + total_byte_size: Precision::Absent, + column_statistics: vec![ColumnStatistics::new_unknown()], }, Schema::new(vec![Field::new("empty_col", DataType::Int32, false)]), )); @@ -1051,27 +1081,27 @@ mod tests_statistical { async fn test_join_selection_partitioned() { let big1 = Arc::new(StatisticsExec::new( Statistics { - num_rows: Some(10000000), - total_byte_size: Some(10000000), - ..Default::default() + num_rows: Precision::Inexact(10000000), + total_byte_size: Precision::Inexact(10000000), + column_statistics: vec![ColumnStatistics::new_unknown()], }, Schema::new(vec![Field::new("big_col1", DataType::Int32, false)]), )); let big2 = Arc::new(StatisticsExec::new( Statistics { - num_rows: Some(20000000), - total_byte_size: Some(20000000), - ..Default::default() + num_rows: Precision::Inexact(20000000), + total_byte_size: Precision::Inexact(20000000), + column_statistics: vec![ColumnStatistics::new_unknown()], }, Schema::new(vec![Field::new("big_col2", DataType::Int32, false)]), )); let empty = Arc::new(StatisticsExec::new( Statistics { - num_rows: None, - total_byte_size: None, - ..Default::default() + num_rows: Precision::Absent, + total_byte_size: Precision::Absent, + column_statistics: vec![ColumnStatistics::new_unknown()], }, Schema::new(vec![Field::new("empty_col", DataType::Int32, false)]), )); @@ -1173,34 +1203,40 @@ mod tests_statistical { #[cfg(test)] mod util_tests { + use std::sync::Arc; + + use arrow_schema::{DataType, Field, Schema}; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{BinaryExpr, Column, NegativeExpr}; use datafusion_physical_expr::intervals::utils::check_support; use datafusion_physical_expr::PhysicalExpr; - use std::sync::Arc; #[test] fn check_expr_supported() { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, false), + ])); let supported_expr = Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 0)), Operator::Plus, Arc::new(Column::new("a", 0)), )) as Arc; - assert!(check_support(&supported_expr)); + assert!(check_support(&supported_expr, &schema)); let supported_expr_2 = Arc::new(Column::new("a", 0)) as Arc; - assert!(check_support(&supported_expr_2)); + assert!(check_support(&supported_expr_2, &schema)); let unsupported_expr = Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 0)), Operator::Or, Arc::new(Column::new("a", 0)), )) as Arc; - assert!(!check_support(&unsupported_expr)); + assert!(!check_support(&unsupported_expr, &schema)); let unsupported_expr_2 = Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 0)), Operator::Or, Arc::new(NegativeExpr::new(Arc::new(Column::new("a", 0)))), )) as Arc; - assert!(!check_support(&unsupported_expr_2)); + assert!(!check_support(&unsupported_expr_2, &schema)); } } @@ -1217,6 +1253,7 @@ mod hash_join_tests { use arrow::record_batch::RecordBatch; use datafusion_common::utils::DataPtr; use datafusion_common::JoinType; + use datafusion_physical_plan::empty::EmptyExec; use std::sync::Arc; struct TestCase { @@ -1584,10 +1621,22 @@ mod hash_join_tests { false, )?; + let children = vec![ + PipelineStatePropagator { + plan: Arc::new(EmptyExec::new(Arc::new(Schema::empty()))), + unbounded: left_unbounded, + children: vec![], + }, + PipelineStatePropagator { + plan: Arc::new(EmptyExec::new(Arc::new(Schema::empty()))), + unbounded: right_unbounded, + children: vec![], + }, + ]; let initial_hash_join_state = PipelineStatePropagator { plan: Arc::new(join), unbounded: false, - children_unbounded: vec![left_unbounded, right_unbounded], + children, }; let optimized_hash_join = diff --git a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs new file mode 100644 index 000000000000..540f9a6a132b --- /dev/null +++ b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs @@ -0,0 +1,609 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! A special-case optimizer rule that pushes limit into a grouped aggregation +//! which has no aggregate expressions or sorting requirements + +use crate::physical_optimizer::PhysicalOptimizerRule; +use crate::physical_plan::aggregates::AggregateExec; +use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; +use crate::physical_plan::ExecutionPlan; +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::Result; +use itertools::Itertools; +use std::sync::Arc; + +/// An optimizer rule that passes a `limit` hint into grouped aggregations which don't require all +/// rows in the group to be processed for correctness. Example queries fitting this description are: +/// `SELECT distinct l_orderkey FROM lineitem LIMIT 10;` +/// `SELECT l_orderkey FROM lineitem GROUP BY l_orderkey LIMIT 10;` +pub struct LimitedDistinctAggregation {} + +impl LimitedDistinctAggregation { + /// Create a new `LimitedDistinctAggregation` + pub fn new() -> Self { + Self {} + } + + fn transform_agg( + aggr: &AggregateExec, + limit: usize, + ) -> Option> { + // rules for transforming this Aggregate are held in this method + if !aggr.is_unordered_unfiltered_group_by_distinct() { + return None; + } + + // We found what we want: clone, copy the limit down, and return modified node + let new_aggr = AggregateExec::try_new( + *aggr.mode(), + aggr.group_by().clone(), + aggr.aggr_expr().to_vec(), + aggr.filter_expr().to_vec(), + aggr.input().clone(), + aggr.input_schema(), + ) + .expect("Unable to copy Aggregate!") + .with_limit(Some(limit)); + Some(Arc::new(new_aggr)) + } + + /// transform_limit matches an `AggregateExec` as the child of a `LocalLimitExec` + /// or `GlobalLimitExec` and pushes the limit into the aggregation as a soft limit when + /// there is a group by, but no sorting, no aggregate expressions, and no filters in the + /// aggregation + fn transform_limit(plan: Arc) -> Option> { + let limit: usize; + let mut global_fetch: Option = None; + let mut global_skip: usize = 0; + let children: Vec>; + let mut is_global_limit = false; + if let Some(local_limit) = plan.as_any().downcast_ref::() { + limit = local_limit.fetch(); + children = local_limit.children(); + } else if let Some(global_limit) = plan.as_any().downcast_ref::() + { + global_fetch = global_limit.fetch(); + global_fetch?; + global_skip = global_limit.skip(); + // the aggregate must read at least fetch+skip number of rows + limit = global_fetch.unwrap() + global_skip; + children = global_limit.children(); + is_global_limit = true + } else { + return None; + } + let child = children.iter().exactly_one().ok()?; + // ensure there is no output ordering; can this rule be relaxed? + if plan.output_ordering().is_some() { + return None; + } + // ensure no ordering is required on the input + if plan.required_input_ordering()[0].is_some() { + return None; + } + + // if found_match_aggr is true, match_aggr holds a parent aggregation whose group_by + // must match that of a child aggregation in order to rewrite the child aggregation + let mut match_aggr: Arc = plan; + let mut found_match_aggr = false; + + let mut rewrite_applicable = true; + let mut closure = |plan: Arc| { + if !rewrite_applicable { + return Ok(Transformed::No(plan)); + } + if let Some(aggr) = plan.as_any().downcast_ref::() { + if found_match_aggr { + if let Some(parent_aggr) = + match_aggr.as_any().downcast_ref::() + { + if !parent_aggr.group_by().eq(aggr.group_by()) { + // a partial and final aggregation with different groupings disqualifies + // rewriting the child aggregation + rewrite_applicable = false; + return Ok(Transformed::No(plan)); + } + } + } + // either we run into an Aggregate and transform it, or disable the rewrite + // for subsequent children + match Self::transform_agg(aggr, limit) { + None => {} + Some(new_aggr) => { + match_aggr = plan; + found_match_aggr = true; + return Ok(Transformed::Yes(new_aggr)); + } + } + } + rewrite_applicable = false; + Ok(Transformed::No(plan)) + }; + let child = child.clone().transform_down_mut(&mut closure).ok()?; + if is_global_limit { + return Some(Arc::new(GlobalLimitExec::new( + child, + global_skip, + global_fetch, + ))); + } + Some(Arc::new(LocalLimitExec::new(child, limit))) + } +} + +impl Default for LimitedDistinctAggregation { + fn default() -> Self { + Self::new() + } +} + +impl PhysicalOptimizerRule for LimitedDistinctAggregation { + fn optimize( + &self, + plan: Arc, + config: &ConfigOptions, + ) -> Result> { + let plan = if config.optimizer.enable_distinct_aggregation_soft_limit { + plan.transform_down(&|plan| { + Ok( + if let Some(plan) = + LimitedDistinctAggregation::transform_limit(plan.clone()) + { + Transformed::Yes(plan) + } else { + Transformed::No(plan) + }, + ) + })? + } else { + plan + }; + Ok(plan) + } + + fn name(&self) -> &str { + "LimitedDistinctAggregation" + } + + fn schema_check(&self) -> bool { + true + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::error::Result; + use crate::physical_optimizer::aggregate_statistics::tests::TestAggregate; + use crate::physical_optimizer::enforce_distribution::tests::{ + parquet_exec_with_sort, schema, trim_plan_display, + }; + use crate::physical_plan::aggregates::{AggregateExec, PhysicalGroupBy}; + use crate::physical_plan::collect; + use crate::physical_plan::memory::MemoryExec; + use crate::prelude::SessionContext; + use arrow::array::Int32Array; + use arrow::compute::SortOptions; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use arrow::util::pretty::pretty_format_batches; + use arrow_schema::SchemaRef; + use datafusion_execution::config::SessionConfig; + use datafusion_expr::Operator; + use datafusion_physical_expr::expressions::cast; + use datafusion_physical_expr::expressions::col; + use datafusion_physical_expr::PhysicalSortExpr; + use datafusion_physical_expr::{expressions, PhysicalExpr}; + use datafusion_physical_plan::aggregates::AggregateMode; + use datafusion_physical_plan::displayable; + use std::sync::Arc; + + fn mock_data() -> Result> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ])); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![ + Some(1), + Some(2), + None, + Some(1), + Some(4), + Some(5), + ])), + Arc::new(Int32Array::from(vec![ + Some(1), + None, + Some(6), + Some(2), + Some(8), + Some(9), + ])), + ], + )?; + + Ok(Arc::new(MemoryExec::try_new( + &[vec![batch]], + Arc::clone(&schema), + None, + )?)) + } + + fn assert_plan_matches_expected( + plan: &Arc, + expected: &[&str], + ) -> Result<()> { + let expected_lines: Vec<&str> = expected.to_vec(); + let session_ctx = SessionContext::new(); + let state = session_ctx.state(); + + let optimized = LimitedDistinctAggregation::new() + .optimize(Arc::clone(plan), state.config_options())?; + + let optimized_result = displayable(optimized.as_ref()).indent(true).to_string(); + let actual_lines = trim_plan_display(&optimized_result); + + assert_eq!( + &expected_lines, &actual_lines, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected_lines, actual_lines + ); + + Ok(()) + } + + async fn assert_results_match_expected( + plan: Arc, + expected: &str, + ) -> Result<()> { + let cfg = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::new_with_config(cfg); + let batches = collect(plan, ctx.task_ctx()).await?; + let actual = format!("{}", pretty_format_batches(&batches)?); + assert_eq!(actual, expected); + Ok(()) + } + + pub fn build_group_by( + input_schema: &SchemaRef, + columns: Vec, + ) -> PhysicalGroupBy { + let mut group_by_expr: Vec<(Arc, String)> = vec![]; + for column in columns.iter() { + group_by_expr.push((col(column, input_schema).unwrap(), column.to_string())); + } + PhysicalGroupBy::new_single(group_by_expr.clone()) + } + + #[tokio::test] + async fn test_partial_final() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + + // `SELECT a FROM MemoryExec GROUP BY a LIMIT 4;`, Partial/Final AggregateExec + let partial_agg = AggregateExec::try_new( + AggregateMode::Partial, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![], /* aggr_expr */ + vec![None], /* filter_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let final_agg = AggregateExec::try_new( + AggregateMode::Final, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![], /* aggr_expr */ + vec![None], /* filter_expr */ + Arc::new(partial_agg), /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = LocalLimitExec::new( + Arc::new(final_agg), + 4, // fetch + ); + // expected to push the limit to the Partial and Final AggregateExecs + let expected = [ + "LocalLimitExec: fetch=4", + "AggregateExec: mode=Final, gby=[a@0 as a], aggr=[], lim=[4]", + "AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[], lim=[4]", + "MemoryExec: partitions=1, partition_sizes=[1]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + let expected = r#" ++---+ +| a | ++---+ +| 1 | +| 2 | +| | +| 4 | ++---+ +"# + .trim(); + assert_results_match_expected(plan, expected).await?; + Ok(()) + } + + #[tokio::test] + async fn test_single_local() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + + // `SELECT a FROM MemoryExec GROUP BY a LIMIT 4;`, Single AggregateExec + let single_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![], /* aggr_expr */ + vec![None], /* filter_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = LocalLimitExec::new( + Arc::new(single_agg), + 4, // fetch + ); + // expected to push the limit to the AggregateExec + let expected = [ + "LocalLimitExec: fetch=4", + "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], lim=[4]", + "MemoryExec: partitions=1, partition_sizes=[1]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + let expected = r#" ++---+ +| a | ++---+ +| 1 | +| 2 | +| | +| 4 | ++---+ +"# + .trim(); + assert_results_match_expected(plan, expected).await?; + Ok(()) + } + + #[tokio::test] + async fn test_single_global() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + + // `SELECT a FROM MemoryExec GROUP BY a LIMIT 4;`, Single AggregateExec + let single_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![], /* aggr_expr */ + vec![None], /* filter_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = GlobalLimitExec::new( + Arc::new(single_agg), + 1, // skip + Some(3), // fetch + ); + // expected to push the skip+fetch limit to the AggregateExec + let expected = [ + "GlobalLimitExec: skip=1, fetch=3", + "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], lim=[4]", + "MemoryExec: partitions=1, partition_sizes=[1]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + let expected = r#" ++---+ +| a | ++---+ +| 2 | +| | +| 4 | ++---+ +"# + .trim(); + assert_results_match_expected(plan, expected).await?; + Ok(()) + } + + #[tokio::test] + async fn test_distinct_cols_different_than_group_by_cols() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + + // `SELECT distinct a FROM MemoryExec GROUP BY a, b LIMIT 4;`, Single/Single AggregateExec + let group_by_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema.clone(), vec!["a".to_string(), "b".to_string()]), + vec![], /* aggr_expr */ + vec![None], /* filter_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let distinct_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![], /* aggr_expr */ + vec![None], /* filter_expr */ + Arc::new(group_by_agg), /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = LocalLimitExec::new( + Arc::new(distinct_agg), + 4, // fetch + ); + // expected to push the limit to the outer AggregateExec only + let expected = [ + "LocalLimitExec: fetch=4", + "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], lim=[4]", + "AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[]", + "MemoryExec: partitions=1, partition_sizes=[1]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + let expected = r#" ++---+ +| a | ++---+ +| 1 | +| 2 | +| | +| 4 | ++---+ +"# + .trim(); + assert_results_match_expected(plan, expected).await?; + Ok(()) + } + + #[test] + fn test_no_group_by() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + + // `SELECT FROM MemoryExec LIMIT 10;`, Single AggregateExec + let single_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema.clone(), vec![]), + vec![], /* aggr_expr */ + vec![None], /* filter_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = LocalLimitExec::new( + Arc::new(single_agg), + 10, // fetch + ); + // expected not to push the limit to the AggregateExec + let expected = [ + "LocalLimitExec: fetch=10", + "AggregateExec: mode=Single, gby=[], aggr=[]", + "MemoryExec: partitions=1, partition_sizes=[1]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + Ok(()) + } + + #[test] + fn test_has_aggregate_expression() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + let agg = TestAggregate::new_count_star(); + + // `SELECT FROM MemoryExec LIMIT 10;`, Single AggregateExec + let single_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![agg.count_expr()], /* aggr_expr */ + vec![None], /* filter_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = LocalLimitExec::new( + Arc::new(single_agg), + 10, // fetch + ); + // expected not to push the limit to the AggregateExec + let expected = [ + "LocalLimitExec: fetch=10", + "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[COUNT(*)]", + "MemoryExec: partitions=1, partition_sizes=[1]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + Ok(()) + } + + #[test] + fn test_has_filter() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + + // `SELECT a FROM MemoryExec WHERE a > 1 GROUP BY a LIMIT 10;`, Single AggregateExec + // the `a > 1` filter is applied in the AggregateExec + let filter_expr = Some(expressions::binary( + expressions::col("a", &schema)?, + Operator::Gt, + cast(expressions::lit(1u32), &schema, DataType::Int32)?, + &schema, + )?); + let single_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![], /* aggr_expr */ + vec![filter_expr], /* filter_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = LocalLimitExec::new( + Arc::new(single_agg), + 10, // fetch + ); + // expected not to push the limit to the AggregateExec + // TODO(msirek): open an issue for `filter_expr` of `AggregateExec` not printing out + let expected = [ + "LocalLimitExec: fetch=10", + "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[]", + "MemoryExec: partitions=1, partition_sizes=[1]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + Ok(()) + } + + #[test] + fn test_has_order_by() -> Result<()> { + let sort_key = vec![PhysicalSortExpr { + expr: expressions::col("a", &schema()).unwrap(), + options: SortOptions::default(), + }]; + let source = parquet_exec_with_sort(vec![sort_key]); + let schema = source.schema(); + + // `SELECT a FROM MemoryExec WHERE a > 1 GROUP BY a LIMIT 10;`, Single AggregateExec + // the `a > 1` filter is applied in the AggregateExec + let single_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![], /* aggr_expr */ + vec![None], /* filter_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = LocalLimitExec::new( + Arc::new(single_agg), + 10, // fetch + ); + // expected not to push the limit to the AggregateExec + let expected = [ + "LocalLimitExec: fetch=10", + "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], ordering_mode=Sorted", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + Ok(()) + } +} diff --git a/datafusion/core/src/physical_optimizer/mod.rs b/datafusion/core/src/physical_optimizer/mod.rs index 0801a9bc595c..e990fead610d 100644 --- a/datafusion/core/src/physical_optimizer/mod.rs +++ b/datafusion/core/src/physical_optimizer/mod.rs @@ -27,8 +27,11 @@ pub mod combine_partial_final_agg; pub mod enforce_distribution; pub mod enforce_sorting; pub mod join_selection; +pub mod limited_distinct_aggregation; pub mod optimizer; +pub mod output_requirements; pub mod pipeline_checker; +mod projection_pushdown; pub mod pruning; pub mod replace_with_order_preserving_variants; mod sort_pushdown; diff --git a/datafusion/core/src/physical_optimizer/optimizer.rs b/datafusion/core/src/physical_optimizer/optimizer.rs index 5de70efe3c47..f8c82576e254 100644 --- a/datafusion/core/src/physical_optimizer/optimizer.rs +++ b/datafusion/core/src/physical_optimizer/optimizer.rs @@ -19,6 +19,7 @@ use std::sync::Arc; +use super::projection_pushdown::ProjectionPushdown; use crate::config::ConfigOptions; use crate::physical_optimizer::aggregate_statistics::AggregateStatistics; use crate::physical_optimizer::coalesce_batches::CoalesceBatches; @@ -26,6 +27,8 @@ use crate::physical_optimizer::combine_partial_final_agg::CombinePartialFinalAgg use crate::physical_optimizer::enforce_distribution::EnforceDistribution; use crate::physical_optimizer::enforce_sorting::EnforceSorting; use crate::physical_optimizer::join_selection::JoinSelection; +use crate::physical_optimizer::limited_distinct_aggregation::LimitedDistinctAggregation; +use crate::physical_optimizer::output_requirements::OutputRequirements; use crate::physical_optimizer::pipeline_checker::PipelineChecker; use crate::physical_optimizer::topk_aggregation::TopKAggregation; use crate::{error::Result, physical_plan::ExecutionPlan}; @@ -68,6 +71,9 @@ impl PhysicalOptimizer { /// Create a new optimizer using the recommended list of rules pub fn new() -> Self { let rules: Vec> = vec![ + // If there is a output requirement of the query, make sure that + // this information is not lost across different rules during optimization. + Arc::new(OutputRequirements::new_add_mode()), Arc::new(AggregateStatistics::new()), // Statistics-based join selection will change the Auto mode to a real join implementation, // like collect left, or hash join, or future sort merge join, which will influence the @@ -75,6 +81,10 @@ impl PhysicalOptimizer { // repartitioning and local sorting steps to meet distribution and ordering requirements. // Therefore, it should run before EnforceDistribution and EnforceSorting. Arc::new(JoinSelection::new()), + // The LimitedDistinctAggregation rule should be applied before the EnforceDistribution rule, + // as that rule may inject other operations in between the different AggregateExecs. + // Applying the rule early means only directly-connected AggregateExecs must be examined. + Arc::new(LimitedDistinctAggregation::new()), // The EnforceDistribution rule is for adding essential repartitioning to satisfy distribution // requirements. Please make sure that the whole plan tree is determined before this rule. // This rule increases parallelism if doing so is beneficial to the physical plan; i.e. at @@ -90,6 +100,9 @@ impl PhysicalOptimizer { // The CoalesceBatches rule will not influence the distribution and ordering of the // whole plan tree. Therefore, to avoid influencing other rules, it should run last. Arc::new(CoalesceBatches::new()), + // Remove the ancillary output requirement operator since we are done with the planning + // phase. + Arc::new(OutputRequirements::new_remove_mode()), // The PipelineChecker rule will reject non-runnable query plans that use // pipeline-breaking operators on infinite input(s). The rule generates a // diagnostic error message when this happens. It makes no changes to the @@ -100,6 +113,13 @@ impl PhysicalOptimizer { // into an `order by max(x) limit y`. In this case it will copy the limit value down // to the aggregation, allowing it to use only y number of accumulators. Arc::new(TopKAggregation::new()), + // The ProjectionPushdown rule tries to push projections towards + // the sources in the execution plan. As a result of this process, + // a projection can disappear if it reaches the source providers, and + // sequential projections can merge into one. Even if these two cases + // are not present, the load of executors such as join or union will be + // reduced by narrowing their input tables. + Arc::new(ProjectionPushdown::new()), ]; Self::with_rules(rules) diff --git a/datafusion/core/src/physical_optimizer/output_requirements.rs b/datafusion/core/src/physical_optimizer/output_requirements.rs new file mode 100644 index 000000000000..4d03840d3dd3 --- /dev/null +++ b/datafusion/core/src/physical_optimizer/output_requirements.rs @@ -0,0 +1,279 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! The GlobalOrderRequire optimizer rule either: +//! - Adds an auxiliary `OutputRequirementExec` operator to keep track of global +//! ordering and distribution requirement across rules, or +//! - Removes the auxiliary `OutputRequirementExec` operator from the physical plan. +//! Since the `OutputRequirementExec` operator is only a helper operator, it +//! shouldn't occur in the final plan (i.e. the executed plan). + +use std::sync::Arc; + +use crate::physical_optimizer::PhysicalOptimizerRule; +use crate::physical_plan::sorts::sort::SortExec; +use crate::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; + +use arrow_schema::SchemaRef; +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{Result, Statistics}; +use datafusion_physical_expr::{ + Distribution, LexRequirement, PhysicalSortExpr, PhysicalSortRequirement, +}; +use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; + +/// This rule either adds or removes [`OutputRequirements`]s to/from the physical +/// plan according to its `mode` attribute, which is set by the constructors +/// `new_add_mode` and `new_remove_mode`. With this rule, we can keep track of +/// the global requirements (ordering and distribution) across rules. +/// +/// The primary usecase of this node and rule is to specify and preserve the desired output +/// ordering and distribution the entire plan. When sending to a single client, a single partition may +/// be desirable, but when sending to a multi-partitioned writer, keeping multiple partitions may be +/// better. +#[derive(Debug)] +pub struct OutputRequirements { + mode: RuleMode, +} + +impl OutputRequirements { + /// Create a new rule which works in `Add` mode; i.e. it simply adds a + /// top-level [`OutputRequirementExec`] into the physical plan to keep track + /// of global ordering and distribution requirements if there are any. + /// Note that this rule should run at the beginning. + pub fn new_add_mode() -> Self { + Self { + mode: RuleMode::Add, + } + } + + /// Create a new rule which works in `Remove` mode; i.e. it simply removes + /// the top-level [`OutputRequirementExec`] from the physical plan if there is + /// any. We do this because a `OutputRequirementExec` is an ancillary, + /// non-executable operator whose sole purpose is to track global + /// requirements during optimization. Therefore, a + /// `OutputRequirementExec` should not appear in the final plan. + pub fn new_remove_mode() -> Self { + Self { + mode: RuleMode::Remove, + } + } +} + +#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Hash)] +enum RuleMode { + Add, + Remove, +} + +/// An ancillary, non-executable operator whose sole purpose is to track global +/// requirements during optimization. It imposes +/// - the ordering requirement in its `order_requirement` attribute. +/// - the distribution requirement in its `dist_requirement` attribute. +/// +/// See [`OutputRequirements`] for more details +#[derive(Debug)] +pub(crate) struct OutputRequirementExec { + input: Arc, + order_requirement: Option, + dist_requirement: Distribution, +} + +impl OutputRequirementExec { + pub(crate) fn new( + input: Arc, + requirements: Option, + dist_requirement: Distribution, + ) -> Self { + Self { + input, + order_requirement: requirements, + dist_requirement, + } + } + + pub(crate) fn input(&self) -> Arc { + self.input.clone() + } +} + +impl DisplayAs for OutputRequirementExec { + fn fmt_as( + &self, + _t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + write!(f, "OutputRequirementExec") + } +} + +impl ExecutionPlan for OutputRequirementExec { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> SchemaRef { + self.input.schema() + } + + fn output_partitioning(&self) -> crate::physical_plan::Partitioning { + self.input.output_partitioning() + } + + fn benefits_from_input_partitioning(&self) -> Vec { + vec![false] + } + + fn required_input_distribution(&self) -> Vec { + vec![self.dist_requirement.clone()] + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + self.input.output_ordering() + } + + fn maintains_input_order(&self) -> Vec { + vec![true] + } + + fn children(&self) -> Vec> { + vec![self.input.clone()] + } + + fn required_input_ordering(&self) -> Vec>> { + vec![self.order_requirement.clone()] + } + + fn unbounded_output(&self, children: &[bool]) -> Result { + // Has a single child + Ok(children[0]) + } + + fn with_new_children( + self: Arc, + mut children: Vec>, + ) -> Result> { + Ok(Arc::new(Self::new( + children.remove(0), // has a single child + self.order_requirement.clone(), + self.dist_requirement.clone(), + ))) + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + unreachable!(); + } + + fn statistics(&self) -> Result { + self.input.statistics() + } +} + +impl PhysicalOptimizerRule for OutputRequirements { + fn optimize( + &self, + plan: Arc, + _config: &ConfigOptions, + ) -> Result> { + match self.mode { + RuleMode::Add => require_top_ordering(plan), + RuleMode::Remove => plan.transform_up(&|plan| { + if let Some(sort_req) = + plan.as_any().downcast_ref::() + { + Ok(Transformed::Yes(sort_req.input())) + } else { + Ok(Transformed::No(plan)) + } + }), + } + } + + fn name(&self) -> &str { + "OutputRequirements" + } + + fn schema_check(&self) -> bool { + true + } +} + +/// This functions adds ancillary `OutputRequirementExec` to the the physical plan, so that +/// global requirements are not lost during optimization. +fn require_top_ordering(plan: Arc) -> Result> { + let (new_plan, is_changed) = require_top_ordering_helper(plan)?; + if is_changed { + Ok(new_plan) + } else { + // Add `OutputRequirementExec` to the top, with no specified ordering and distribution requirement. + Ok(Arc::new(OutputRequirementExec::new( + new_plan, + // there is no ordering requirement + None, + Distribution::UnspecifiedDistribution, + )) as _) + } +} + +/// Helper function that adds an ancillary `OutputRequirementExec` to the given plan. +/// First entry in the tuple is resulting plan, second entry indicates whether any +/// `OutputRequirementExec` is added to the plan. +fn require_top_ordering_helper( + plan: Arc, +) -> Result<(Arc, bool)> { + let mut children = plan.children(); + // Global ordering defines desired ordering in the final result. + if children.len() != 1 { + Ok((plan, false)) + } else if let Some(sort_exec) = plan.as_any().downcast_ref::() { + let req_ordering = sort_exec.output_ordering().unwrap_or(&[]); + let req_dist = sort_exec.required_input_distribution()[0].clone(); + let reqs = PhysicalSortRequirement::from_sort_exprs(req_ordering); + Ok(( + Arc::new(OutputRequirementExec::new(plan, Some(reqs), req_dist)) as _, + true, + )) + } else if let Some(spm) = plan.as_any().downcast_ref::() { + let reqs = PhysicalSortRequirement::from_sort_exprs(spm.expr()); + Ok(( + Arc::new(OutputRequirementExec::new( + plan, + Some(reqs), + Distribution::SinglePartition, + )) as _, + true, + )) + } else if plan.maintains_input_order()[0] + && plan.required_input_ordering()[0].is_none() + { + // Keep searching for a `SortExec` as long as ordering is maintained, + // and on-the-way operators do not themselves require an ordering. + // When an operator requires an ordering, any `SortExec` below can not + // be responsible for (i.e. the originator of) the global ordering. + let (new_child, is_changed) = + require_top_ordering_helper(children.swap_remove(0))?; + Ok((plan.with_new_children(vec![new_child])?, is_changed)) + } else { + // Stop searching, there is no global ordering desired for the query. + Ok((plan, false)) + } +} diff --git a/datafusion/core/src/physical_optimizer/pipeline_checker.rs b/datafusion/core/src/physical_optimizer/pipeline_checker.rs index 44679647b5b2..e281d0e7c23e 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_checker.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_checker.rs @@ -18,17 +18,20 @@ //! The [PipelineChecker] rule ensures that a given plan can accommodate its //! infinite sources, if there are any. It will reject non-runnable query plans //! that use pipeline-breaking operators on infinite input(s). -//! + +use std::borrow::Cow; +use std::sync::Arc; + use crate::config::ConfigOptions; use crate::error::Result; use crate::physical_optimizer::PhysicalOptimizerRule; -use crate::physical_plan::joins::SymmetricHashJoinExec; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; + use datafusion_common::config::OptimizerOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{plan_err, DataFusionError}; use datafusion_physical_expr::intervals::utils::{check_support, is_datatype_supported}; -use std::sync::Arc; +use datafusion_physical_plan::joins::SymmetricHashJoinExec; /// The PipelineChecker rule rejects non-runnable query plans that use /// pipeline-breaking operators on infinite input(s). @@ -68,65 +71,48 @@ impl PhysicalOptimizerRule for PipelineChecker { pub struct PipelineStatePropagator { pub(crate) plan: Arc, pub(crate) unbounded: bool, - pub(crate) children_unbounded: Vec, + pub(crate) children: Vec, } impl PipelineStatePropagator { /// Constructs a new, default pipelining state. pub fn new(plan: Arc) -> Self { - let length = plan.children().len(); - PipelineStatePropagator { + let children = plan.children(); + Self { plan, unbounded: false, - children_unbounded: vec![false; length], + children: children.into_iter().map(Self::new).collect(), } } + + /// Returns the children unboundedness information. + pub fn children_unbounded(&self) -> Vec { + self.children.iter().map(|c| c.unbounded).collect() + } } impl TreeNode for PipelineStatePropagator { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - let children = self.plan.children(); - for child in children { - match op(&PipelineStatePropagator::new(child))? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.children.iter().map(Cow::Borrowed).collect() } - fn map_children(self, transform: F) -> Result + fn map_children(mut self, transform: F) -> Result where F: FnMut(Self) -> Result, { - let children = self.plan.children(); - if !children.is_empty() { - let new_children = children + if !self.children.is_empty() { + self.children = self + .children .into_iter() - .map(|child| PipelineStatePropagator::new(child)) .map(transform) - .collect::>>()?; - let children_unbounded = new_children - .iter() - .map(|c| c.unbounded) - .collect::>(); - let children_plans = new_children - .into_iter() - .map(|child| child.plan) - .collect::>(); - Ok(PipelineStatePropagator { - plan: with_new_children_if_necessary(self.plan, children_plans)?.into(), - unbounded: self.unbounded, - children_unbounded, - }) - } else { - Ok(self) + .collect::>()?; + self.plan = with_new_children_if_necessary( + self.plan, + self.children.iter().map(|c| c.plan.clone()).collect(), + )? + .into(); } + Ok(self) } } @@ -147,7 +133,7 @@ pub fn check_finiteness_requirements( } input .plan - .unbounded_output(&input.children_unbounded) + .unbounded_output(&input.children_unbounded()) .map(|value| { input.unbounded = value; Transformed::Yes(input) @@ -163,7 +149,7 @@ pub fn check_finiteness_requirements( /// [`Operator`]: datafusion_expr::Operator fn is_prunable(join: &SymmetricHashJoinExec) -> bool { join.filter().map_or(false, |filter| { - check_support(filter.expression()) + check_support(filter.expression(), &join.schema()) && filter .schema() .fields() diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs new file mode 100644 index 000000000000..d237a3e8607e --- /dev/null +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -0,0 +1,2313 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This file implements the `ProjectionPushdown` physical optimization rule. +//! The function [`remove_unnecessary_projections`] tries to push down all +//! projections one by one if the operator below is amenable to this. If a +//! projection reaches a source, it can even dissappear from the plan entirely. + +use std::collections::HashMap; +use std::sync::Arc; + +use super::output_requirements::OutputRequirementExec; +use super::PhysicalOptimizerRule; +use crate::datasource::physical_plan::CsvExec; +use crate::error::Result; +use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; +use crate::physical_plan::filter::FilterExec; +use crate::physical_plan::joins::utils::{ColumnIndex, JoinFilter}; +use crate::physical_plan::joins::{ + CrossJoinExec, HashJoinExec, NestedLoopJoinExec, SortMergeJoinExec, + SymmetricHashJoinExec, +}; +use crate::physical_plan::memory::MemoryExec; +use crate::physical_plan::projection::ProjectionExec; +use crate::physical_plan::repartition::RepartitionExec; +use crate::physical_plan::sorts::sort::SortExec; +use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; +use crate::physical_plan::{Distribution, ExecutionPlan}; + +use arrow_schema::SchemaRef; +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::JoinSide; +use datafusion_physical_expr::expressions::{Column, Literal}; +use datafusion_physical_expr::{ + Partitioning, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, +}; +use datafusion_physical_plan::streaming::StreamingTableExec; +use datafusion_physical_plan::union::UnionExec; + +use itertools::Itertools; + +/// This rule inspects [`ProjectionExec`]'s in the given physical plan and tries to +/// remove or swap with its child. +#[derive(Default)] +pub struct ProjectionPushdown {} + +impl ProjectionPushdown { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +impl PhysicalOptimizerRule for ProjectionPushdown { + fn optimize( + &self, + plan: Arc, + _config: &ConfigOptions, + ) -> Result> { + plan.transform_down(&remove_unnecessary_projections) + } + + fn name(&self) -> &str { + "ProjectionPushdown" + } + + fn schema_check(&self) -> bool { + true + } +} + +/// This function checks if `plan` is a [`ProjectionExec`], and inspects its +/// input(s) to test whether it can push `plan` under its input(s). This function +/// will operate on the entire tree and may ultimately remove `plan` entirely +/// by leveraging source providers with built-in projection capabilities. +pub fn remove_unnecessary_projections( + plan: Arc, +) -> Result>> { + let maybe_modified = if let Some(projection) = + plan.as_any().downcast_ref::() + { + // If the projection does not cause any change on the input, we can + // safely remove it: + if is_projection_removable(projection) { + return Ok(Transformed::Yes(projection.input().clone())); + } + // If it does, check if we can push it under its child(ren): + let input = projection.input().as_any(); + if let Some(csv) = input.downcast_ref::() { + try_swapping_with_csv(projection, csv) + } else if let Some(memory) = input.downcast_ref::() { + try_swapping_with_memory(projection, memory)? + } else if let Some(child_projection) = input.downcast_ref::() { + let maybe_unified = try_unifying_projections(projection, child_projection)?; + return if let Some(new_plan) = maybe_unified { + // To unify 3 or more sequential projections: + remove_unnecessary_projections(new_plan) + } else { + Ok(Transformed::No(plan)) + }; + } else if let Some(output_req) = input.downcast_ref::() { + try_swapping_with_output_req(projection, output_req)? + } else if input.is::() { + try_swapping_with_coalesce_partitions(projection)? + } else if let Some(filter) = input.downcast_ref::() { + try_swapping_with_filter(projection, filter)? + } else if let Some(repartition) = input.downcast_ref::() { + try_swapping_with_repartition(projection, repartition)? + } else if let Some(sort) = input.downcast_ref::() { + try_swapping_with_sort(projection, sort)? + } else if let Some(spm) = input.downcast_ref::() { + try_swapping_with_sort_preserving_merge(projection, spm)? + } else if let Some(union) = input.downcast_ref::() { + try_pushdown_through_union(projection, union)? + } else if let Some(hash_join) = input.downcast_ref::() { + try_pushdown_through_hash_join(projection, hash_join)? + } else if let Some(cross_join) = input.downcast_ref::() { + try_swapping_with_cross_join(projection, cross_join)? + } else if let Some(nl_join) = input.downcast_ref::() { + try_swapping_with_nested_loop_join(projection, nl_join)? + } else if let Some(sm_join) = input.downcast_ref::() { + try_swapping_with_sort_merge_join(projection, sm_join)? + } else if let Some(sym_join) = input.downcast_ref::() { + try_swapping_with_sym_hash_join(projection, sym_join)? + } else if let Some(ste) = input.downcast_ref::() { + try_swapping_with_streaming_table(projection, ste)? + } else { + // If the input plan of the projection is not one of the above, we + // conservatively assume that pushing the projection down may hurt. + // When adding new operators, consider adding them here if you + // think pushing projections under them is beneficial. + None + } + } else { + return Ok(Transformed::No(plan)); + }; + + Ok(maybe_modified.map_or(Transformed::No(plan), Transformed::Yes)) +} + +/// Tries to embed `projection` to its input (`csv`). If possible, returns +/// [`CsvExec`] as the top plan. Otherwise, returns `None`. +fn try_swapping_with_csv( + projection: &ProjectionExec, + csv: &CsvExec, +) -> Option> { + // If there is any non-column or alias-carrier expression, Projection should not be removed. + // This process can be moved into CsvExec, but it would be an overlap of their responsibility. + all_alias_free_columns(projection.expr()).then(|| { + let mut file_scan = csv.base_config().clone(); + let new_projections = + new_projections_for_columns(projection, &file_scan.projection); + file_scan.projection = Some(new_projections); + + Arc::new(CsvExec::new( + file_scan, + csv.has_header(), + csv.delimiter(), + csv.quote(), + csv.escape(), + csv.file_compression_type, + )) as _ + }) +} + +/// Tries to embed `projection` to its input (`memory`). If possible, returns +/// [`MemoryExec`] as the top plan. Otherwise, returns `None`. +fn try_swapping_with_memory( + projection: &ProjectionExec, + memory: &MemoryExec, +) -> Result>> { + // If there is any non-column or alias-carrier expression, Projection should not be removed. + // This process can be moved into MemoryExec, but it would be an overlap of their responsibility. + all_alias_free_columns(projection.expr()) + .then(|| { + let new_projections = + new_projections_for_columns(projection, memory.projection()); + + MemoryExec::try_new( + memory.partitions(), + memory.original_schema(), + Some(new_projections), + ) + .map(|e| Arc::new(e) as _) + }) + .transpose() +} + +/// Tries to embed `projection` to its input (`streaming table`). +/// If possible, returns [`StreamingTableExec`] as the top plan. Otherwise, +/// returns `None`. +fn try_swapping_with_streaming_table( + projection: &ProjectionExec, + streaming_table: &StreamingTableExec, +) -> Result>> { + if !all_alias_free_columns(projection.expr()) { + return Ok(None); + } + + let streaming_table_projections = streaming_table + .projection() + .as_ref() + .map(|i| i.as_ref().to_vec()); + let new_projections = + new_projections_for_columns(projection, &streaming_table_projections); + + let mut lex_orderings = vec![]; + for lex_ordering in streaming_table.projected_output_ordering().into_iter() { + let mut orderings = vec![]; + for order in lex_ordering { + let Some(new_ordering) = update_expr(&order.expr, projection.expr(), false)? + else { + return Ok(None); + }; + orderings.push(PhysicalSortExpr { + expr: new_ordering, + options: order.options, + }); + } + lex_orderings.push(orderings); + } + + StreamingTableExec::try_new( + streaming_table.partition_schema().clone(), + streaming_table.partitions().clone(), + Some(&new_projections), + lex_orderings, + streaming_table.is_infinite(), + ) + .map(|e| Some(Arc::new(e) as _)) +} + +/// Unifies `projection` with its input (which is also a [`ProjectionExec`]). +fn try_unifying_projections( + projection: &ProjectionExec, + child: &ProjectionExec, +) -> Result>> { + let mut projected_exprs = vec![]; + let mut column_ref_map: HashMap = HashMap::new(); + + // Collect the column references usage in the outer projection. + projection.expr().iter().for_each(|(expr, _)| { + expr.apply(&mut |expr| { + Ok({ + if let Some(column) = expr.as_any().downcast_ref::() { + *column_ref_map.entry(column.clone()).or_default() += 1; + } + VisitRecursion::Continue + }) + }) + .unwrap(); + }); + + // Merging these projections is not beneficial, e.g + // If an expression is not trivial and it is referred more than 1, unifies projections will be + // beneficial as caching mechanism for non-trivial computations. + // See discussion in: https://github.com/apache/arrow-datafusion/issues/8296 + if column_ref_map.iter().any(|(column, count)| { + *count > 1 && !is_expr_trivial(&child.expr()[column.index()].0.clone()) + }) { + return Ok(None); + } + + for (expr, alias) in projection.expr() { + // If there is no match in the input projection, we cannot unify these + // projections. This case will arise if the projection expression contains + // a `PhysicalExpr` variant `update_expr` doesn't support. + let Some(expr) = update_expr(expr, child.expr(), true)? else { + return Ok(None); + }; + projected_exprs.push((expr, alias.clone())); + } + + ProjectionExec::try_new(projected_exprs, child.input().clone()) + .map(|e| Some(Arc::new(e) as _)) +} + +/// Checks if the given expression is trivial. +/// An expression is considered trivial if it is either a `Column` or a `Literal`. +fn is_expr_trivial(expr: &Arc) -> bool { + expr.as_any().downcast_ref::().is_some() + || expr.as_any().downcast_ref::().is_some() +} + +/// Tries to swap `projection` with its input (`output_req`). If possible, +/// performs the swap and returns [`OutputRequirementExec`] as the top plan. +/// Otherwise, returns `None`. +fn try_swapping_with_output_req( + projection: &ProjectionExec, + output_req: &OutputRequirementExec, +) -> Result>> { + // If the projection does not narrow the the schema, we should not try to push it down: + if projection.expr().len() >= projection.input().schema().fields().len() { + return Ok(None); + } + + let mut updated_sort_reqs = vec![]; + // None or empty_vec can be treated in the same way. + if let Some(reqs) = &output_req.required_input_ordering()[0] { + for req in reqs { + let Some(new_expr) = update_expr(&req.expr, projection.expr(), false)? else { + return Ok(None); + }; + updated_sort_reqs.push(PhysicalSortRequirement { + expr: new_expr, + options: req.options, + }); + } + } + + let dist_req = match &output_req.required_input_distribution()[0] { + Distribution::HashPartitioned(exprs) => { + let mut updated_exprs = vec![]; + for expr in exprs { + let Some(new_expr) = update_expr(expr, projection.expr(), false)? else { + return Ok(None); + }; + updated_exprs.push(new_expr); + } + Distribution::HashPartitioned(updated_exprs) + } + dist => dist.clone(), + }; + + make_with_child(projection, &output_req.input()) + .map(|input| { + OutputRequirementExec::new( + input, + (!updated_sort_reqs.is_empty()).then_some(updated_sort_reqs), + dist_req, + ) + }) + .map(|e| Some(Arc::new(e) as _)) +} + +/// Tries to swap `projection` with its input, which is known to be a +/// [`CoalescePartitionsExec`]. If possible, performs the swap and returns +/// [`CoalescePartitionsExec`] as the top plan. Otherwise, returns `None`. +fn try_swapping_with_coalesce_partitions( + projection: &ProjectionExec, +) -> Result>> { + // If the projection does not narrow the the schema, we should not try to push it down: + if projection.expr().len() >= projection.input().schema().fields().len() { + return Ok(None); + } + // CoalescePartitionsExec always has a single child, so zero indexing is safe. + make_with_child(projection, &projection.input().children()[0]) + .map(|e| Some(Arc::new(CoalescePartitionsExec::new(e)) as _)) +} + +/// Tries to swap `projection` with its input (`filter`). If possible, performs +/// the swap and returns [`FilterExec`] as the top plan. Otherwise, returns `None`. +fn try_swapping_with_filter( + projection: &ProjectionExec, + filter: &FilterExec, +) -> Result>> { + // If the projection does not narrow the the schema, we should not try to push it down: + if projection.expr().len() >= projection.input().schema().fields().len() { + return Ok(None); + } + // Each column in the predicate expression must exist after the projection. + let Some(new_predicate) = update_expr(filter.predicate(), projection.expr(), false)? + else { + return Ok(None); + }; + + FilterExec::try_new(new_predicate, make_with_child(projection, filter.input())?) + .and_then(|e| { + let selectivity = filter.default_selectivity(); + e.with_default_selectivity(selectivity) + }) + .map(|e| Some(Arc::new(e) as _)) +} + +/// Tries to swap the projection with its input [`RepartitionExec`]. If it can be done, +/// it returns the new swapped version having the [`RepartitionExec`] as the top plan. +/// Otherwise, it returns None. +fn try_swapping_with_repartition( + projection: &ProjectionExec, + repartition: &RepartitionExec, +) -> Result>> { + // If the projection does not narrow the the schema, we should not try to push it down. + if projection.expr().len() >= projection.input().schema().fields().len() { + return Ok(None); + } + + // If pushdown is not beneficial or applicable, break it. + if projection.benefits_from_input_partitioning()[0] || !all_columns(projection.expr()) + { + return Ok(None); + } + + let new_projection = make_with_child(projection, repartition.input())?; + + let new_partitioning = match repartition.partitioning() { + Partitioning::Hash(partitions, size) => { + let mut new_partitions = vec![]; + for partition in partitions { + let Some(new_partition) = + update_expr(partition, projection.expr(), false)? + else { + return Ok(None); + }; + new_partitions.push(new_partition); + } + Partitioning::Hash(new_partitions, *size) + } + others => others.clone(), + }; + + Ok(Some(Arc::new(RepartitionExec::try_new( + new_projection, + new_partitioning, + )?))) +} + +/// Tries to swap the projection with its input [`SortExec`]. If it can be done, +/// it returns the new swapped version having the [`SortExec`] as the top plan. +/// Otherwise, it returns None. +fn try_swapping_with_sort( + projection: &ProjectionExec, + sort: &SortExec, +) -> Result>> { + // If the projection does not narrow the the schema, we should not try to push it down. + if projection.expr().len() >= projection.input().schema().fields().len() { + return Ok(None); + } + + let mut updated_exprs = vec![]; + for sort in sort.expr() { + let Some(new_expr) = update_expr(&sort.expr, projection.expr(), false)? else { + return Ok(None); + }; + updated_exprs.push(PhysicalSortExpr { + expr: new_expr, + options: sort.options, + }); + } + + Ok(Some(Arc::new( + SortExec::new(updated_exprs, make_with_child(projection, sort.input())?) + .with_fetch(sort.fetch()) + .with_preserve_partitioning(sort.preserve_partitioning()), + ))) +} + +/// Tries to swap the projection with its input [`SortPreservingMergeExec`]. +/// If this is possible, it returns the new [`SortPreservingMergeExec`] whose +/// child is a projection. Otherwise, it returns None. +fn try_swapping_with_sort_preserving_merge( + projection: &ProjectionExec, + spm: &SortPreservingMergeExec, +) -> Result>> { + // If the projection does not narrow the the schema, we should not try to push it down. + if projection.expr().len() >= projection.input().schema().fields().len() { + return Ok(None); + } + + let mut updated_exprs = vec![]; + for sort in spm.expr() { + let Some(updated_expr) = update_expr(&sort.expr, projection.expr(), false)? + else { + return Ok(None); + }; + updated_exprs.push(PhysicalSortExpr { + expr: updated_expr, + options: sort.options, + }); + } + + Ok(Some(Arc::new( + SortPreservingMergeExec::new( + updated_exprs, + make_with_child(projection, spm.input())?, + ) + .with_fetch(spm.fetch()), + ))) +} + +/// Tries to push `projection` down through `union`. If possible, performs the +/// pushdown and returns a new [`UnionExec`] as the top plan which has projections +/// as its children. Otherwise, returns `None`. +fn try_pushdown_through_union( + projection: &ProjectionExec, + union: &UnionExec, +) -> Result>> { + // If the projection doesn't narrow the schema, we shouldn't try to push it down. + if projection.expr().len() >= projection.input().schema().fields().len() { + return Ok(None); + } + + let new_children = union + .children() + .into_iter() + .map(|child| make_with_child(projection, &child)) + .collect::>>()?; + + Ok(Some(Arc::new(UnionExec::new(new_children)))) +} + +/// Tries to push `projection` down through `hash_join`. If possible, performs the +/// pushdown and returns a new [`HashJoinExec`] as the top plan which has projections +/// as its children. Otherwise, returns `None`. +fn try_pushdown_through_hash_join( + projection: &ProjectionExec, + hash_join: &HashJoinExec, +) -> Result>> { + // Convert projected expressions to columns. We can not proceed if this is + // not possible. + let Some(projection_as_columns) = physical_to_column_exprs(projection.expr()) else { + return Ok(None); + }; + + let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders( + hash_join.left().schema().fields().len(), + &projection_as_columns, + ); + + if !join_allows_pushdown( + &projection_as_columns, + hash_join.schema(), + far_right_left_col_ind, + far_left_right_col_ind, + ) { + return Ok(None); + } + + let Some(new_on) = update_join_on( + &projection_as_columns[0..=far_right_left_col_ind as _], + &projection_as_columns[far_left_right_col_ind as _..], + hash_join.on(), + ) else { + return Ok(None); + }; + + let new_filter = if let Some(filter) = hash_join.filter() { + match update_join_filter( + &projection_as_columns[0..=far_right_left_col_ind as _], + &projection_as_columns[far_left_right_col_ind as _..], + filter, + hash_join.left(), + hash_join.right(), + ) { + Some(updated_filter) => Some(updated_filter), + None => return Ok(None), + } + } else { + None + }; + + let (new_left, new_right) = new_join_children( + projection_as_columns, + far_right_left_col_ind, + far_left_right_col_ind, + hash_join.left(), + hash_join.right(), + )?; + + Ok(Some(Arc::new(HashJoinExec::try_new( + Arc::new(new_left), + Arc::new(new_right), + new_on, + new_filter, + hash_join.join_type(), + *hash_join.partition_mode(), + hash_join.null_equals_null, + )?))) +} + +/// Tries to swap the projection with its input [`CrossJoinExec`]. If it can be done, +/// it returns the new swapped version having the [`CrossJoinExec`] as the top plan. +/// Otherwise, it returns None. +fn try_swapping_with_cross_join( + projection: &ProjectionExec, + cross_join: &CrossJoinExec, +) -> Result>> { + // Convert projected PhysicalExpr's to columns. If not possible, we cannot proceed. + let Some(projection_as_columns) = physical_to_column_exprs(projection.expr()) else { + return Ok(None); + }; + + let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders( + cross_join.left().schema().fields().len(), + &projection_as_columns, + ); + + if !join_allows_pushdown( + &projection_as_columns, + cross_join.schema(), + far_right_left_col_ind, + far_left_right_col_ind, + ) { + return Ok(None); + } + + let (new_left, new_right) = new_join_children( + projection_as_columns, + far_right_left_col_ind, + far_left_right_col_ind, + cross_join.left(), + cross_join.right(), + )?; + + Ok(Some(Arc::new(CrossJoinExec::new( + Arc::new(new_left), + Arc::new(new_right), + )))) +} + +/// Tries to swap the projection with its input [`NestedLoopJoinExec`]. If it can be done, +/// it returns the new swapped version having the [`NestedLoopJoinExec`] as the top plan. +/// Otherwise, it returns None. +fn try_swapping_with_nested_loop_join( + projection: &ProjectionExec, + nl_join: &NestedLoopJoinExec, +) -> Result>> { + // Convert projected PhysicalExpr's to columns. If not possible, we cannot proceed. + let Some(projection_as_columns) = physical_to_column_exprs(projection.expr()) else { + return Ok(None); + }; + + let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders( + nl_join.left().schema().fields().len(), + &projection_as_columns, + ); + + if !join_allows_pushdown( + &projection_as_columns, + nl_join.schema(), + far_right_left_col_ind, + far_left_right_col_ind, + ) { + return Ok(None); + } + + let new_filter = if let Some(filter) = nl_join.filter() { + match update_join_filter( + &projection_as_columns[0..=far_right_left_col_ind as _], + &projection_as_columns[far_left_right_col_ind as _..], + filter, + nl_join.left(), + nl_join.right(), + ) { + Some(updated_filter) => Some(updated_filter), + None => return Ok(None), + } + } else { + None + }; + + let (new_left, new_right) = new_join_children( + projection_as_columns, + far_right_left_col_ind, + far_left_right_col_ind, + nl_join.left(), + nl_join.right(), + )?; + + Ok(Some(Arc::new(NestedLoopJoinExec::try_new( + Arc::new(new_left), + Arc::new(new_right), + new_filter, + nl_join.join_type(), + )?))) +} + +/// Tries to swap the projection with its input [`SortMergeJoinExec`]. If it can be done, +/// it returns the new swapped version having the [`SortMergeJoinExec`] as the top plan. +/// Otherwise, it returns None. +fn try_swapping_with_sort_merge_join( + projection: &ProjectionExec, + sm_join: &SortMergeJoinExec, +) -> Result>> { + // Convert projected PhysicalExpr's to columns. If not possible, we cannot proceed. + let Some(projection_as_columns) = physical_to_column_exprs(projection.expr()) else { + return Ok(None); + }; + + let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders( + sm_join.left().schema().fields().len(), + &projection_as_columns, + ); + + if !join_allows_pushdown( + &projection_as_columns, + sm_join.schema(), + far_right_left_col_ind, + far_left_right_col_ind, + ) { + return Ok(None); + } + + let Some(new_on) = update_join_on( + &projection_as_columns[0..=far_right_left_col_ind as _], + &projection_as_columns[far_left_right_col_ind as _..], + sm_join.on(), + ) else { + return Ok(None); + }; + + let (new_left, new_right) = new_join_children( + projection_as_columns, + far_right_left_col_ind, + far_left_right_col_ind, + &sm_join.children()[0], + &sm_join.children()[1], + )?; + + Ok(Some(Arc::new(SortMergeJoinExec::try_new( + Arc::new(new_left), + Arc::new(new_right), + new_on, + sm_join.join_type, + sm_join.sort_options.clone(), + sm_join.null_equals_null, + )?))) +} + +/// Tries to swap the projection with its input [`SymmetricHashJoinExec`]. If it can be done, +/// it returns the new swapped version having the [`SymmetricHashJoinExec`] as the top plan. +/// Otherwise, it returns None. +fn try_swapping_with_sym_hash_join( + projection: &ProjectionExec, + sym_join: &SymmetricHashJoinExec, +) -> Result>> { + // Convert projected PhysicalExpr's to columns. If not possible, we cannot proceed. + let Some(projection_as_columns) = physical_to_column_exprs(projection.expr()) else { + return Ok(None); + }; + + let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders( + sym_join.left().schema().fields().len(), + &projection_as_columns, + ); + + if !join_allows_pushdown( + &projection_as_columns, + sym_join.schema(), + far_right_left_col_ind, + far_left_right_col_ind, + ) { + return Ok(None); + } + + let Some(new_on) = update_join_on( + &projection_as_columns[0..=far_right_left_col_ind as _], + &projection_as_columns[far_left_right_col_ind as _..], + sym_join.on(), + ) else { + return Ok(None); + }; + + let new_filter = if let Some(filter) = sym_join.filter() { + match update_join_filter( + &projection_as_columns[0..=far_right_left_col_ind as _], + &projection_as_columns[far_left_right_col_ind as _..], + filter, + sym_join.left(), + sym_join.right(), + ) { + Some(updated_filter) => Some(updated_filter), + None => return Ok(None), + } + } else { + None + }; + + let (new_left, new_right) = new_join_children( + projection_as_columns, + far_right_left_col_ind, + far_left_right_col_ind, + sym_join.left(), + sym_join.right(), + )?; + + Ok(Some(Arc::new(SymmetricHashJoinExec::try_new( + Arc::new(new_left), + Arc::new(new_right), + new_on, + new_filter, + sym_join.join_type(), + sym_join.null_equals_null(), + sym_join.partition_mode(), + )?))) +} + +/// Compare the inputs and outputs of the projection. If the projection causes +/// any change in the fields, it returns `false`. +fn is_projection_removable(projection: &ProjectionExec) -> bool { + all_alias_free_columns(projection.expr()) && { + let schema = projection.schema(); + let input_schema = projection.input().schema(); + let fields = schema.fields(); + let input_fields = input_schema.fields(); + fields.len() == input_fields.len() + && fields + .iter() + .zip(input_fields.iter()) + .all(|(out, input)| out.eq(input)) + } +} + +/// Given the expression set of a projection, checks if the projection causes +/// any renaming or constructs a non-`Column` physical expression. +fn all_alias_free_columns(exprs: &[(Arc, String)]) -> bool { + exprs.iter().all(|(expr, alias)| { + expr.as_any() + .downcast_ref::() + .map(|column| column.name() == alias) + .unwrap_or(false) + }) +} + +/// Updates a source provider's projected columns according to the given +/// projection operator's expressions. To use this function safely, one must +/// ensure that all expressions are `Column` expressions without aliases. +fn new_projections_for_columns( + projection: &ProjectionExec, + source: &Option>, +) -> Vec { + projection + .expr() + .iter() + .filter_map(|(expr, _)| { + expr.as_any() + .downcast_ref::() + .and_then(|expr| source.as_ref().map(|proj| proj[expr.index()])) + }) + .collect() +} + +/// The function operates in two modes: +/// +/// 1) When `sync_with_child` is `true`: +/// +/// The function updates the indices of `expr` if the expression resides +/// in the input plan. For instance, given the expressions `a@1 + b@2` +/// and `c@0` with the input schema `c@2, a@0, b@1`, the expressions are +/// updated to `a@0 + b@1` and `c@2`. +/// +/// 2) When `sync_with_child` is `false`: +/// +/// The function determines how the expression would be updated if a projection +/// was placed before the plan associated with the expression. If the expression +/// cannot be rewritten after the projection, it returns `None`. For example, +/// given the expressions `c@0`, `a@1` and `b@2`, and the [`ProjectionExec`] with +/// an output schema of `a, c_new`, then `c@0` becomes `c_new@1`, `a@1` becomes +/// `a@0`, but `b@2` results in `None` since the projection does not include `b`. +fn update_expr( + expr: &Arc, + projected_exprs: &[(Arc, String)], + sync_with_child: bool, +) -> Result>> { + #[derive(Debug, PartialEq)] + enum RewriteState { + /// The expression is unchanged. + Unchanged, + /// Some part of the expression has been rewritten + RewrittenValid, + /// Some part of the expression has been rewritten, but some column + /// references could not be. + RewrittenInvalid, + } + + let mut state = RewriteState::Unchanged; + + let new_expr = expr + .clone() + .transform_up_mut(&mut |expr: Arc| { + if state == RewriteState::RewrittenInvalid { + return Ok(Transformed::No(expr)); + } + + let Some(column) = expr.as_any().downcast_ref::() else { + return Ok(Transformed::No(expr)); + }; + if sync_with_child { + state = RewriteState::RewrittenValid; + // Update the index of `column`: + Ok(Transformed::Yes(projected_exprs[column.index()].0.clone())) + } else { + // default to invalid, in case we can't find the relevant column + state = RewriteState::RewrittenInvalid; + // Determine how to update `column` to accommodate `projected_exprs` + projected_exprs + .iter() + .enumerate() + .find_map(|(index, (projected_expr, alias))| { + projected_expr.as_any().downcast_ref::().and_then( + |projected_column| { + column.name().eq(projected_column.name()).then(|| { + state = RewriteState::RewrittenValid; + Arc::new(Column::new(alias, index)) as _ + }) + }, + ) + }) + .map_or_else( + || Ok(Transformed::No(expr)), + |c| Ok(Transformed::Yes(c)), + ) + } + }); + + new_expr.map(|e| (state == RewriteState::RewrittenValid).then_some(e)) +} + +/// Creates a new [`ProjectionExec`] instance with the given child plan and +/// projected expressions. +fn make_with_child( + projection: &ProjectionExec, + child: &Arc, +) -> Result> { + ProjectionExec::try_new(projection.expr().to_vec(), child.clone()) + .map(|e| Arc::new(e) as _) +} + +/// Returns `true` if all the expressions in the argument are `Column`s. +fn all_columns(exprs: &[(Arc, String)]) -> bool { + exprs.iter().all(|(expr, _)| expr.as_any().is::()) +} + +/// Downcasts all the expressions in `exprs` to `Column`s. If any of the given +/// expressions is not a `Column`, returns `None`. +fn physical_to_column_exprs( + exprs: &[(Arc, String)], +) -> Option> { + exprs + .iter() + .map(|(expr, alias)| { + expr.as_any() + .downcast_ref::() + .map(|col| (col.clone(), alias.clone())) + }) + .collect() +} + +/// Returns the last index before encountering a column coming from the right table when traveling +/// through the projection from left to right, and the last index before encountering a column +/// coming from the left table when traveling through the projection from right to left. +/// If there is no column in the projection coming from the left side, it returns (-1, ...), +/// if there is no column in the projection coming from the right side, it returns (..., projection length). +fn join_table_borders( + left_table_column_count: usize, + projection_as_columns: &[(Column, String)], +) -> (i32, i32) { + let far_right_left_col_ind = projection_as_columns + .iter() + .enumerate() + .take_while(|(_, (projection_column, _))| { + projection_column.index() < left_table_column_count + }) + .last() + .map(|(index, _)| index as i32) + .unwrap_or(-1); + + let far_left_right_col_ind = projection_as_columns + .iter() + .enumerate() + .rev() + .take_while(|(_, (projection_column, _))| { + projection_column.index() >= left_table_column_count + }) + .last() + .map(|(index, _)| index as i32) + .unwrap_or(projection_as_columns.len() as i32); + + (far_right_left_col_ind, far_left_right_col_ind) +} + +/// Tries to update the equi-join `Column`'s of a join as if the the input of +/// the join was replaced by a projection. +fn update_join_on( + proj_left_exprs: &[(Column, String)], + proj_right_exprs: &[(Column, String)], + hash_join_on: &[(Column, Column)], +) -> Option> { + // TODO: Clippy wants the "map" call removed, but doing so generates + // a compilation error. Remove the clippy directive once this + // issue is fixed. + #[allow(clippy::map_identity)] + let (left_idx, right_idx): (Vec<_>, Vec<_>) = hash_join_on + .iter() + .map(|(left, right)| (left, right)) + .unzip(); + + let new_left_columns = new_columns_for_join_on(&left_idx, proj_left_exprs); + let new_right_columns = new_columns_for_join_on(&right_idx, proj_right_exprs); + + match (new_left_columns, new_right_columns) { + (Some(left), Some(right)) => Some(left.into_iter().zip(right).collect()), + _ => None, + } +} + +/// This function generates a new set of columns to be used in a hash join +/// operation based on a set of equi-join conditions (`hash_join_on`) and a +/// list of projection expressions (`projection_exprs`). +fn new_columns_for_join_on( + hash_join_on: &[&Column], + projection_exprs: &[(Column, String)], +) -> Option> { + let new_columns = hash_join_on + .iter() + .filter_map(|on| { + projection_exprs + .iter() + .enumerate() + .find(|(_, (proj_column, _))| on.name() == proj_column.name()) + .map(|(index, (_, alias))| Column::new(alias, index)) + }) + .collect::>(); + (new_columns.len() == hash_join_on.len()).then_some(new_columns) +} + +/// Tries to update the column indices of a [`JoinFilter`] as if the the input of +/// the join was replaced by a projection. +fn update_join_filter( + projection_left_exprs: &[(Column, String)], + projection_right_exprs: &[(Column, String)], + join_filter: &JoinFilter, + join_left: &Arc, + join_right: &Arc, +) -> Option { + let mut new_left_indices = new_indices_for_join_filter( + join_filter, + JoinSide::Left, + projection_left_exprs, + join_left.schema(), + ) + .into_iter(); + let mut new_right_indices = new_indices_for_join_filter( + join_filter, + JoinSide::Right, + projection_right_exprs, + join_right.schema(), + ) + .into_iter(); + + // Check if all columns match: + (new_right_indices.len() + new_left_indices.len() + == join_filter.column_indices().len()) + .then(|| { + JoinFilter::new( + join_filter.expression().clone(), + join_filter + .column_indices() + .iter() + .map(|col_idx| ColumnIndex { + index: if col_idx.side == JoinSide::Left { + new_left_indices.next().unwrap() + } else { + new_right_indices.next().unwrap() + }, + side: col_idx.side, + }) + .collect(), + join_filter.schema().clone(), + ) + }) +} + +/// This function determines and returns a vector of indices representing the +/// positions of columns in `projection_exprs` that are involved in `join_filter`, +/// and correspond to a particular side (`join_side`) of the join operation. +fn new_indices_for_join_filter( + join_filter: &JoinFilter, + join_side: JoinSide, + projection_exprs: &[(Column, String)], + join_child_schema: SchemaRef, +) -> Vec { + join_filter + .column_indices() + .iter() + .filter(|col_idx| col_idx.side == join_side) + .filter_map(|col_idx| { + projection_exprs.iter().position(|(col, _)| { + col.name() == join_child_schema.fields()[col_idx.index].name() + }) + }) + .collect() +} + +/// Checks three conditions for pushing a projection down through a join: +/// - Projection must narrow the join output schema. +/// - Columns coming from left/right tables must be collected at the left/right +/// sides of the output table. +/// - Left or right table is not lost after the projection. +fn join_allows_pushdown( + projection_as_columns: &[(Column, String)], + join_schema: SchemaRef, + far_right_left_col_ind: i32, + far_left_right_col_ind: i32, +) -> bool { + // Projection must narrow the join output: + projection_as_columns.len() < join_schema.fields().len() + // Are the columns from different tables mixed? + && (far_right_left_col_ind + 1 == far_left_right_col_ind) + // Left or right table is not lost after the projection. + && far_right_left_col_ind >= 0 + && far_left_right_col_ind < projection_as_columns.len() as i32 +} + +/// If pushing down the projection over this join's children seems possible, +/// this function constructs the new [`ProjectionExec`]s that will come on top +/// of the original children of the join. +fn new_join_children( + projection_as_columns: Vec<(Column, String)>, + far_right_left_col_ind: i32, + far_left_right_col_ind: i32, + left_child: &Arc, + right_child: &Arc, +) -> Result<(ProjectionExec, ProjectionExec)> { + let new_left = ProjectionExec::try_new( + projection_as_columns[0..=far_right_left_col_ind as _] + .iter() + .map(|(col, alias)| { + ( + Arc::new(Column::new(col.name(), col.index())) as _, + alias.clone(), + ) + }) + .collect_vec(), + left_child.clone(), + )?; + let left_size = left_child.schema().fields().len() as i32; + let new_right = ProjectionExec::try_new( + projection_as_columns[far_left_right_col_ind as _..] + .iter() + .map(|(col, alias)| { + ( + Arc::new(Column::new( + col.name(), + // Align projected expressions coming from the right + // table with the new right child projection: + (col.index() as i32 - left_size) as _, + )) as _, + alias.clone(), + ) + }) + .collect_vec(), + right_child.clone(), + )?; + + Ok((new_left, new_right)) +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::datasource::file_format::file_compression_type::FileCompressionType; + use crate::datasource::listing::PartitionedFile; + use crate::datasource::physical_plan::{CsvExec, FileScanConfig}; + use crate::physical_optimizer::output_requirements::OutputRequirementExec; + use crate::physical_optimizer::projection_pushdown::{ + join_table_borders, update_expr, ProjectionPushdown, + }; + use crate::physical_optimizer::PhysicalOptimizerRule; + use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; + use crate::physical_plan::filter::FilterExec; + use crate::physical_plan::joins::utils::{ColumnIndex, JoinFilter}; + use crate::physical_plan::joins::StreamJoinPartitionMode; + use crate::physical_plan::memory::MemoryExec; + use crate::physical_plan::projection::ProjectionExec; + use crate::physical_plan::repartition::RepartitionExec; + use crate::physical_plan::sorts::sort::SortExec; + use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; + use crate::physical_plan::{get_plan_string, ExecutionPlan}; + + use arrow_schema::{DataType, Field, Schema, SchemaRef, SortOptions}; + use datafusion_common::config::ConfigOptions; + use datafusion_common::{JoinSide, JoinType, Result, ScalarValue, Statistics}; + use datafusion_execution::object_store::ObjectStoreUrl; + use datafusion_execution::{SendableRecordBatchStream, TaskContext}; + use datafusion_expr::{ColumnarValue, Operator}; + use datafusion_physical_expr::expressions::{ + BinaryExpr, CaseExpr, CastExpr, Column, Literal, NegativeExpr, + }; + use datafusion_physical_expr::{ + Distribution, Partitioning, PhysicalExpr, PhysicalSortExpr, + PhysicalSortRequirement, ScalarFunctionExpr, + }; + use datafusion_physical_plan::joins::SymmetricHashJoinExec; + use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; + use datafusion_physical_plan::union::UnionExec; + + use itertools::Itertools; + + #[test] + fn test_update_matching_exprs() -> Result<()> { + let exprs: Vec> = vec![ + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 3)), + Operator::Divide, + Arc::new(Column::new("e", 5)), + )), + Arc::new(CastExpr::new( + Arc::new(Column::new("a", 3)), + DataType::Float32, + None, + )), + Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))), + Arc::new(ScalarFunctionExpr::new( + "scalar_expr", + Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")), + vec![ + Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 1)), + Operator::Divide, + Arc::new(Column::new("c", 0)), + )), + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 0)), + Operator::Divide, + Arc::new(Column::new("b", 1)), + )), + ], + DataType::Int32, + None, + )), + Arc::new(CaseExpr::try_new( + Some(Arc::new(Column::new("d", 2))), + vec![ + ( + Arc::new(Column::new("a", 3)) as Arc, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("d", 2)), + Operator::Plus, + Arc::new(Column::new("e", 5)), + )) as Arc, + ), + ( + Arc::new(Column::new("a", 3)) as Arc, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("e", 5)), + Operator::Plus, + Arc::new(Column::new("d", 2)), + )) as Arc, + ), + ], + Some(Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 3)), + Operator::Modulo, + Arc::new(Column::new("e", 5)), + ))), + )?), + ]; + let child: Vec<(Arc, String)> = vec![ + (Arc::new(Column::new("c", 2)), "c".to_owned()), + (Arc::new(Column::new("b", 1)), "b".to_owned()), + (Arc::new(Column::new("d", 3)), "d".to_owned()), + (Arc::new(Column::new("a", 0)), "a".to_owned()), + (Arc::new(Column::new("f", 5)), "f".to_owned()), + (Arc::new(Column::new("e", 4)), "e".to_owned()), + ]; + + let expected_exprs: Vec> = vec![ + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Divide, + Arc::new(Column::new("e", 4)), + )), + Arc::new(CastExpr::new( + Arc::new(Column::new("a", 0)), + DataType::Float32, + None, + )), + Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 5)))), + Arc::new(ScalarFunctionExpr::new( + "scalar_expr", + Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")), + vec![ + Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 1)), + Operator::Divide, + Arc::new(Column::new("c", 2)), + )), + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::Divide, + Arc::new(Column::new("b", 1)), + )), + ], + DataType::Int32, + None, + )), + Arc::new(CaseExpr::try_new( + Some(Arc::new(Column::new("d", 3))), + vec![ + ( + Arc::new(Column::new("a", 0)) as Arc, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("d", 3)), + Operator::Plus, + Arc::new(Column::new("e", 4)), + )) as Arc, + ), + ( + Arc::new(Column::new("a", 0)) as Arc, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("e", 4)), + Operator::Plus, + Arc::new(Column::new("d", 3)), + )) as Arc, + ), + ], + Some(Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Modulo, + Arc::new(Column::new("e", 4)), + ))), + )?), + ]; + + for (expr, expected_expr) in exprs.into_iter().zip(expected_exprs.into_iter()) { + assert!(update_expr(&expr, &child, true)? + .unwrap() + .eq(&expected_expr)); + } + + Ok(()) + } + + #[test] + fn test_update_projected_exprs() -> Result<()> { + let exprs: Vec> = vec![ + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 3)), + Operator::Divide, + Arc::new(Column::new("e", 5)), + )), + Arc::new(CastExpr::new( + Arc::new(Column::new("a", 3)), + DataType::Float32, + None, + )), + Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))), + Arc::new(ScalarFunctionExpr::new( + "scalar_expr", + Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")), + vec![ + Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 1)), + Operator::Divide, + Arc::new(Column::new("c", 0)), + )), + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 0)), + Operator::Divide, + Arc::new(Column::new("b", 1)), + )), + ], + DataType::Int32, + None, + )), + Arc::new(CaseExpr::try_new( + Some(Arc::new(Column::new("d", 2))), + vec![ + ( + Arc::new(Column::new("a", 3)) as Arc, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("d", 2)), + Operator::Plus, + Arc::new(Column::new("e", 5)), + )) as Arc, + ), + ( + Arc::new(Column::new("a", 3)) as Arc, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("e", 5)), + Operator::Plus, + Arc::new(Column::new("d", 2)), + )) as Arc, + ), + ], + Some(Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 3)), + Operator::Modulo, + Arc::new(Column::new("e", 5)), + ))), + )?), + ]; + let projected_exprs: Vec<(Arc, String)> = vec![ + (Arc::new(Column::new("a", 0)), "a".to_owned()), + (Arc::new(Column::new("b", 1)), "b_new".to_owned()), + (Arc::new(Column::new("c", 2)), "c".to_owned()), + (Arc::new(Column::new("d", 3)), "d_new".to_owned()), + (Arc::new(Column::new("e", 4)), "e".to_owned()), + (Arc::new(Column::new("f", 5)), "f_new".to_owned()), + ]; + + let expected_exprs: Vec> = vec![ + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Divide, + Arc::new(Column::new("e", 4)), + )), + Arc::new(CastExpr::new( + Arc::new(Column::new("a", 0)), + DataType::Float32, + None, + )), + Arc::new(NegativeExpr::new(Arc::new(Column::new("f_new", 5)))), + Arc::new(ScalarFunctionExpr::new( + "scalar_expr", + Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")), + vec![ + Arc::new(BinaryExpr::new( + Arc::new(Column::new("b_new", 1)), + Operator::Divide, + Arc::new(Column::new("c", 2)), + )), + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::Divide, + Arc::new(Column::new("b_new", 1)), + )), + ], + DataType::Int32, + None, + )), + Arc::new(CaseExpr::try_new( + Some(Arc::new(Column::new("d_new", 3))), + vec![ + ( + Arc::new(Column::new("a", 0)) as Arc, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("d_new", 3)), + Operator::Plus, + Arc::new(Column::new("e", 4)), + )) as Arc, + ), + ( + Arc::new(Column::new("a", 0)) as Arc, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("e", 4)), + Operator::Plus, + Arc::new(Column::new("d_new", 3)), + )) as Arc, + ), + ], + Some(Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Modulo, + Arc::new(Column::new("e", 4)), + ))), + )?), + ]; + + for (expr, expected_expr) in exprs.into_iter().zip(expected_exprs.into_iter()) { + assert!(update_expr(&expr, &projected_exprs, false)? + .unwrap() + .eq(&expected_expr)); + } + + Ok(()) + } + + #[test] + fn test_join_table_borders() -> Result<()> { + let projections = vec![ + (Column::new("b", 1), "b".to_owned()), + (Column::new("c", 2), "c".to_owned()), + (Column::new("e", 4), "e".to_owned()), + (Column::new("d", 3), "d".to_owned()), + (Column::new("c", 2), "c".to_owned()), + (Column::new("f", 5), "f".to_owned()), + (Column::new("h", 7), "h".to_owned()), + (Column::new("g", 6), "g".to_owned()), + ]; + let left_table_column_count = 5; + assert_eq!( + join_table_borders(left_table_column_count, &projections), + (4, 5) + ); + + let left_table_column_count = 8; + assert_eq!( + join_table_borders(left_table_column_count, &projections), + (7, 8) + ); + + let left_table_column_count = 1; + assert_eq!( + join_table_borders(left_table_column_count, &projections), + (-1, 0) + ); + + let projections = vec![ + (Column::new("a", 0), "a".to_owned()), + (Column::new("b", 1), "b".to_owned()), + (Column::new("d", 3), "d".to_owned()), + (Column::new("g", 6), "g".to_owned()), + (Column::new("e", 4), "e".to_owned()), + (Column::new("f", 5), "f".to_owned()), + (Column::new("e", 4), "e".to_owned()), + (Column::new("h", 7), "h".to_owned()), + ]; + let left_table_column_count = 5; + assert_eq!( + join_table_borders(left_table_column_count, &projections), + (2, 7) + ); + + let left_table_column_count = 7; + assert_eq!( + join_table_borders(left_table_column_count, &projections), + (6, 7) + ); + + Ok(()) + } + + fn create_simple_csv_exec() -> Arc { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Int32, true), + ])); + Arc::new(CsvExec::new( + FileScanConfig { + object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), + file_schema: schema.clone(), + file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], + statistics: Statistics::new_unknown(&schema), + projection: Some(vec![0, 1, 2, 3, 4]), + limit: None, + table_partition_cols: vec![], + output_ordering: vec![vec![]], + }, + false, + 0, + 0, + None, + FileCompressionType::UNCOMPRESSED, + )) + } + + fn create_projecting_csv_exec() -> Arc { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + ])); + Arc::new(CsvExec::new( + FileScanConfig { + object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), + file_schema: schema.clone(), + file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], + statistics: Statistics::new_unknown(&schema), + projection: Some(vec![3, 2, 1]), + limit: None, + table_partition_cols: vec![], + output_ordering: vec![vec![]], + }, + false, + 0, + 0, + None, + FileCompressionType::UNCOMPRESSED, + )) + } + + fn create_projecting_memory_exec() -> Arc { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Int32, true), + ])); + + Arc::new(MemoryExec::try_new(&[], schema, Some(vec![2, 0, 3, 4])).unwrap()) + } + + #[test] + fn test_csv_after_projection() -> Result<()> { + let csv = create_projecting_csv_exec(); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("b", 2)), "b".to_string()), + (Arc::new(Column::new("d", 0)), "d".to_string()), + ], + csv.clone(), + )?); + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[b@2 as b, d@0 as d]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[d, c, b], has_header=false", + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected = [ + "CsvExec: file_groups={1 group: [[x]]}, projection=[b, d], has_header=false", + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } + + #[test] + fn test_memory_after_projection() -> Result<()> { + let memory = create_projecting_memory_exec(); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("d", 2)), "d".to_string()), + (Arc::new(Column::new("e", 3)), "e".to_string()), + (Arc::new(Column::new("a", 1)), "a".to_string()), + ], + memory.clone(), + )?); + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[d@2 as d, e@3 as e, a@1 as a]", + " MemoryExec: partitions=0, partition_sizes=[]", + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected = ["MemoryExec: partitions=0, partition_sizes=[]"]; + assert_eq!(get_plan_string(&after_optimize), expected); + assert_eq!( + after_optimize + .clone() + .as_any() + .downcast_ref::() + .unwrap() + .projection() + .clone() + .unwrap(), + vec![3, 4, 0] + ); + + Ok(()) + } + + #[test] + fn test_streaming_table_after_projection() -> Result<()> { + struct DummyStreamPartition { + schema: SchemaRef, + } + impl PartitionStream for DummyStreamPartition { + fn schema(&self) -> &SchemaRef { + &self.schema + } + fn execute(&self, _ctx: Arc) -> SendableRecordBatchStream { + unreachable!() + } + } + + let streaming_table = StreamingTableExec::try_new( + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Int32, true), + ])), + vec![Arc::new(DummyStreamPartition { + schema: Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Int32, true), + ])), + }) as _], + Some(&vec![0_usize, 2, 4, 3]), + vec![ + vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("e", 2)), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: SortOptions::default(), + }, + ], + vec![PhysicalSortExpr { + expr: Arc::new(Column::new("d", 3)), + options: SortOptions::default(), + }], + ] + .into_iter(), + true, + )?; + let projection = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("d", 3)), "d".to_string()), + (Arc::new(Column::new("e", 2)), "e".to_string()), + (Arc::new(Column::new("a", 0)), "a".to_string()), + ], + Arc::new(streaming_table) as _, + )?) as _; + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let result = after_optimize + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + result.partition_schema(), + &Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Int32, true), + ])) + ); + assert_eq!( + result.projection().clone().unwrap().to_vec(), + vec![3_usize, 4, 0] + ); + assert_eq!( + result.projected_schema(), + &Schema::new(vec![ + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Int32, true), + Field::new("a", DataType::Int32, true), + ]) + ); + assert_eq!( + result.projected_output_ordering().into_iter().collect_vec(), + vec![ + vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("e", 1)), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 2)), + options: SortOptions::default(), + }, + ], + vec![PhysicalSortExpr { + expr: Arc::new(Column::new("d", 0)), + options: SortOptions::default(), + }], + ] + ); + assert!(result.is_infinite()); + + Ok(()) + } + + #[test] + fn test_projection_after_projection() -> Result<()> { + let csv = create_simple_csv_exec(); + let child_projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("c", 2)), "c".to_string()), + (Arc::new(Column::new("e", 4)), "new_e".to_string()), + (Arc::new(Column::new("a", 0)), "a".to_string()), + (Arc::new(Column::new("b", 1)), "new_b".to_string()), + ], + csv.clone(), + )?); + let top_projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("new_b", 3)), "new_b".to_string()), + ( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 0)), + Operator::Plus, + Arc::new(Column::new("new_e", 1)), + )), + "binary".to_string(), + ), + (Arc::new(Column::new("new_b", 3)), "newest_b".to_string()), + ], + child_projection.clone(), + )?); + + let initial = get_plan_string(&top_projection); + let expected_initial = [ + "ProjectionExec: expr=[new_b@3 as new_b, c@0 + new_e@1 as binary, new_b@3 as newest_b]", + " ProjectionExec: expr=[c@2 as c, e@4 as new_e, a@0 as a, b@1 as new_b]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(top_projection, &ConfigOptions::new())?; + + let expected = [ + "ProjectionExec: expr=[b@1 as new_b, c@2 + e@4 as binary, b@1 as newest_b]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } + + #[test] + fn test_output_req_after_projection() -> Result<()> { + let csv = create_simple_csv_exec(); + let sort_req: Arc = Arc::new(OutputRequirementExec::new( + csv.clone(), + Some(vec![ + PhysicalSortRequirement { + expr: Arc::new(Column::new("b", 1)), + options: Some(SortOptions::default()), + }, + PhysicalSortRequirement { + expr: Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::Plus, + Arc::new(Column::new("a", 0)), + )), + options: Some(SortOptions::default()), + }, + ]), + Distribution::HashPartitioned(vec![ + Arc::new(Column::new("a", 0)), + Arc::new(Column::new("b", 1)), + ]), + )); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("c", 2)), "c".to_string()), + (Arc::new(Column::new("a", 0)), "new_a".to_string()), + (Arc::new(Column::new("b", 1)), "b".to_string()), + ], + sort_req.clone(), + )?); + + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " OutputRequirementExec", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected: [&str; 3] = [ + "OutputRequirementExec", + " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + + assert_eq!(get_plan_string(&after_optimize), expected); + let expected_reqs = vec![ + PhysicalSortRequirement { + expr: Arc::new(Column::new("b", 2)), + options: Some(SortOptions::default()), + }, + PhysicalSortRequirement { + expr: Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 0)), + Operator::Plus, + Arc::new(Column::new("new_a", 1)), + )), + options: Some(SortOptions::default()), + }, + ]; + assert_eq!( + after_optimize + .as_any() + .downcast_ref::() + .unwrap() + .required_input_ordering()[0] + .clone() + .unwrap(), + expected_reqs + ); + let expected_distribution: Vec> = vec![ + Arc::new(Column::new("new_a", 1)), + Arc::new(Column::new("b", 2)), + ]; + if let Distribution::HashPartitioned(vec) = after_optimize + .as_any() + .downcast_ref::() + .unwrap() + .required_input_distribution()[0] + .clone() + { + assert!(vec + .iter() + .zip(expected_distribution) + .all(|(actual, expected)| actual.eq(&expected))); + } else { + panic!("Expected HashPartitioned distribution!"); + }; + + Ok(()) + } + + #[test] + fn test_coalesce_partitions_after_projection() -> Result<()> { + let csv = create_simple_csv_exec(); + let coalesce_partitions: Arc = + Arc::new(CoalescePartitionsExec::new(csv)); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("b", 1)), "b".to_string()), + (Arc::new(Column::new("a", 0)), "a_new".to_string()), + (Arc::new(Column::new("d", 3)), "d".to_string()), + ], + coalesce_partitions, + )?); + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[b@1 as b, a@0 as a_new, d@3 as d]", + " CoalescePartitionsExec", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected = [ + "CoalescePartitionsExec", + " ProjectionExec: expr=[b@1 as b, a@0 as a_new, d@3 as d]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } + + #[test] + fn test_filter_after_projection() -> Result<()> { + let csv = create_simple_csv_exec(); + let predicate = Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 1)), + Operator::Minus, + Arc::new(Column::new("a", 0)), + )), + Operator::Gt, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("d", 3)), + Operator::Minus, + Arc::new(Column::new("a", 0)), + )), + )); + let filter: Arc = + Arc::new(FilterExec::try_new(predicate, csv)?); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("a", 0)), "a_new".to_string()), + (Arc::new(Column::new("b", 1)), "b".to_string()), + (Arc::new(Column::new("d", 3)), "d".to_string()), + ], + filter.clone(), + )?); + + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[a@0 as a_new, b@1 as b, d@3 as d]", + " FilterExec: b@1 - a@0 > d@3 - a@0", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected = [ + "FilterExec: b@1 - a_new@0 > d@2 - a_new@0", + " ProjectionExec: expr=[a@0 as a_new, b@1 as b, d@3 as d]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } + + #[test] + fn test_join_after_projection() -> Result<()> { + let left_csv = create_simple_csv_exec(); + let right_csv = create_simple_csv_exec(); + + let join: Arc = Arc::new(SymmetricHashJoinExec::try_new( + left_csv, + right_csv, + vec![(Column::new("b", 1), Column::new("c", 2))], + // b_left-(1+a_right)<=a_right+c_left + Some(JoinFilter::new( + Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("b_left_inter", 0)), + Operator::Minus, + Arc::new(BinaryExpr::new( + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + Operator::Plus, + Arc::new(Column::new("a_right_inter", 1)), + )), + )), + Operator::LtEq, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a_right_inter", 1)), + Operator::Plus, + Arc::new(Column::new("c_left_inter", 2)), + )), + )), + vec![ + ColumnIndex { + index: 1, + side: JoinSide::Left, + }, + ColumnIndex { + index: 0, + side: JoinSide::Right, + }, + ColumnIndex { + index: 2, + side: JoinSide::Left, + }, + ], + Schema::new(vec![ + Field::new("b_left_inter", DataType::Int32, true), + Field::new("a_right_inter", DataType::Int32, true), + Field::new("c_left_inter", DataType::Int32, true), + ]), + )), + &JoinType::Inner, + true, + StreamJoinPartitionMode::SinglePartition, + )?); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("c", 2)), "c_from_left".to_string()), + (Arc::new(Column::new("b", 1)), "b_from_left".to_string()), + (Arc::new(Column::new("a", 0)), "a_from_left".to_string()), + (Arc::new(Column::new("a", 5)), "a_from_right".to_string()), + (Arc::new(Column::new("c", 7)), "c_from_right".to_string()), + ], + join, + )?); + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left, a@5 as a_from_right, c@7 as c_from_right]", + " SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected = [ + "SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(b_from_left@1, c_from_right@1)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2", + " ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + " ProjectionExec: expr=[a@0 as a_from_right, c@2 as c_from_right]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + let expected_filter_col_ind = vec![ + ColumnIndex { + index: 1, + side: JoinSide::Left, + }, + ColumnIndex { + index: 0, + side: JoinSide::Right, + }, + ColumnIndex { + index: 0, + side: JoinSide::Left, + }, + ]; + + assert_eq!( + expected_filter_col_ind, + after_optimize + .as_any() + .downcast_ref::() + .unwrap() + .filter() + .unwrap() + .column_indices() + ); + + Ok(()) + } + + #[test] + fn test_repartition_after_projection() -> Result<()> { + let csv = create_simple_csv_exec(); + let repartition: Arc = Arc::new(RepartitionExec::try_new( + csv, + Partitioning::Hash( + vec![ + Arc::new(Column::new("a", 0)), + Arc::new(Column::new("b", 1)), + Arc::new(Column::new("d", 3)), + ], + 6, + ), + )?); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("b", 1)), "b_new".to_string()), + (Arc::new(Column::new("a", 0)), "a".to_string()), + (Arc::new(Column::new("d", 3)), "d_new".to_string()), + ], + repartition, + )?); + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[b@1 as b_new, a@0 as a, d@3 as d_new]", + " RepartitionExec: partitioning=Hash([a@0, b@1, d@3], 6), input_partitions=1", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected = [ + "RepartitionExec: partitioning=Hash([a@1, b_new@0, d_new@2], 6), input_partitions=1", + " ProjectionExec: expr=[b@1 as b_new, a@0 as a, d@3 as d_new]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + assert_eq!( + after_optimize + .as_any() + .downcast_ref::() + .unwrap() + .partitioning() + .clone(), + Partitioning::Hash( + vec![ + Arc::new(Column::new("a", 1)), + Arc::new(Column::new("b_new", 0)), + Arc::new(Column::new("d_new", 2)), + ], + 6, + ), + ); + + Ok(()) + } + + #[test] + fn test_sort_after_projection() -> Result<()> { + let csv = create_simple_csv_exec(); + let sort_req: Arc = Arc::new(SortExec::new( + vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("b", 1)), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::Plus, + Arc::new(Column::new("a", 0)), + )), + options: SortOptions::default(), + }, + ], + csv.clone(), + )); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("c", 2)), "c".to_string()), + (Arc::new(Column::new("a", 0)), "new_a".to_string()), + (Arc::new(Column::new("b", 1)), "b".to_string()), + ], + sort_req.clone(), + )?); + + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " SortExec: expr=[b@1 ASC,c@2 + a@0 ASC]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected = [ + "SortExec: expr=[b@2 ASC,c@0 + new_a@1 ASC]", + " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } + + #[test] + fn test_sort_preserving_after_projection() -> Result<()> { + let csv = create_simple_csv_exec(); + let sort_req: Arc = Arc::new(SortPreservingMergeExec::new( + vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("b", 1)), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::Plus, + Arc::new(Column::new("a", 0)), + )), + options: SortOptions::default(), + }, + ], + csv.clone(), + )); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("c", 2)), "c".to_string()), + (Arc::new(Column::new("a", 0)), "new_a".to_string()), + (Arc::new(Column::new("b", 1)), "b".to_string()), + ], + sort_req.clone(), + )?); + + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " SortPreservingMergeExec: [b@1 ASC,c@2 + a@0 ASC]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected = [ + "SortPreservingMergeExec: [b@2 ASC,c@0 + new_a@1 ASC]", + " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } + + #[test] + fn test_union_after_projection() -> Result<()> { + let csv = create_simple_csv_exec(); + let union: Arc = + Arc::new(UnionExec::new(vec![csv.clone(), csv.clone(), csv])); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("c", 2)), "c".to_string()), + (Arc::new(Column::new("a", 0)), "new_a".to_string()), + (Arc::new(Column::new("b", 1)), "b".to_string()), + ], + union.clone(), + )?); + + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " UnionExec", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected = [ + "UnionExec", + " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } +} diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index 2987ec6d6552..0cbbaf2bf6cd 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -15,19 +15,10 @@ // specific language governing permissions and limitations // under the License. -//! This module contains code to prune "containers" of row groups -//! based on statistics prior to execution. This can lead to -//! significant performance improvements by avoiding the need -//! to evaluate a plan on entire containers (e.g. an entire file) +//! [`PruningPredicate`] to apply filter [`Expr`] to prune "containers" +//! based on statistics (e.g. Parquet Row Groups) //! -//! For example, DataFusion uses this code to prune (skip) row groups -//! while reading parquet files if it can be determined from the -//! predicate that nothing in the row group can match. -//! -//! This code can also be used by other systems to prune other -//! entities (e.g. entire files) if the statistics are known via some -//! other source (e.g. a catalog) - +//! [`Expr`]: crate::prelude::Expr use std::collections::HashSet; use std::convert::TryFrom; use std::sync::Arc; @@ -44,27 +35,42 @@ use arrow::{ datatypes::{DataType, Field, Schema, SchemaRef}, record_batch::RecordBatch, }; -use datafusion_common::{downcast_value, ScalarValue}; +use arrow_array::cast::AsArray; use datafusion_common::{ internal_err, plan_err, tree_node::{Transformed, TreeNode}, }; -use datafusion_physical_expr::utils::collect_columns; +use datafusion_common::{plan_datafusion_err, ScalarValue}; +use datafusion_physical_expr::utils::{collect_columns, Guarantee, LiteralGuarantee}; use datafusion_physical_expr::{expressions as phys_expr, PhysicalExprRef}; use log::trace; -/// Interface to pass statistics information to [`PruningPredicate`] +/// A source of runtime statistical information to [`PruningPredicate`]s. +/// +/// # Supported Information +/// +/// 1. Minimum and maximum values for columns /// -/// Returns statistics for containers / files of data in Arrays. +/// 2. Null counts for columns /// -/// For example, for the following three files with a single column +/// 3. Whether the values in a column are contained in a set of literals +/// +/// # Vectorized Interface +/// +/// Information for containers / files are returned as Arrow [`ArrayRef`], so +/// the evaluation happens once on a single `RecordBatch`, which amortizes the +/// overhead of evaluating the predicate. This is important when pruning 1000s +/// of containers which often happens in analytic systems that have 1000s of +/// potential files to consider. +/// +/// For example, for the following three files with a single column `a`: /// ```text /// file1: column a: min=5, max=10 /// file2: column a: No stats /// file2: column a: min=20, max=30 /// ``` /// -/// PruningStatistics should return: +/// PruningStatistics would return: /// /// ```text /// min_values("a") -> Some([5, Null, 20]) @@ -72,39 +78,125 @@ use log::trace; /// min_values("X") -> None /// ``` pub trait PruningStatistics { - /// return the minimum values for the named column, if known. - /// Note: the returned array must contain `num_containers()` rows + /// Return the minimum values for the named column, if known. + /// + /// If the minimum value for a particular container is not known, the + /// returned array should have `null` in that row. If the minimum value is + /// not known for any row, return `None`. + /// + /// Note: the returned array must contain [`Self::num_containers`] rows fn min_values(&self, column: &Column) -> Option; - /// return the maximum values for the named column, if known. - /// Note: the returned array must contain `num_containers()` rows. + /// Return the maximum values for the named column, if known. + /// + /// See [`Self::min_values`] for when to return `None` and null values. + /// + /// Note: the returned array must contain [`Self::num_containers`] rows fn max_values(&self, column: &Column) -> Option; - /// return the number of containers (e.g. row groups) being - /// pruned with these statistics + /// Return the number of containers (e.g. Row Groups) being pruned with + /// these statistics. + /// + /// This value corresponds to the size of the [`ArrayRef`] returned by + /// [`Self::min_values`], [`Self::max_values`], and [`Self::null_counts`]. fn num_containers(&self) -> usize; - /// return the number of null values for the named column as an + /// Return the number of null values for the named column as an /// `Option`. /// - /// Note: the returned array must contain `num_containers()` rows. + /// See [`Self::min_values`] for when to return `None` and null values. + /// + /// Note: the returned array must contain [`Self::num_containers`] rows fn null_counts(&self, column: &Column) -> Option; + + /// Returns [`BooleanArray`] where each row represents information known + /// about specific literal `values` in a column. + /// + /// For example, Parquet Bloom Filters implement this API to communicate + /// that `values` are known not to be present in a Row Group. + /// + /// The returned array has one row for each container, with the following + /// meanings: + /// * `true` if the values in `column` ONLY contain values from `values` + /// * `false` if the values in `column` are NOT ANY of `values` + /// * `null` if the neither of the above holds or is unknown. + /// + /// If these statistics can not determine column membership for any + /// container, return `None` (the default). + /// + /// Note: the returned array must contain [`Self::num_containers`] rows + fn contained( + &self, + column: &Column, + values: &HashSet, + ) -> Option; } -/// Evaluates filter expressions on statistics in order to -/// prune data containers (e.g. parquet row group) +/// Used to prove that arbitrary predicates (boolean expression) can not +/// possibly evaluate to `true` given information about a column provided by +/// [`PruningStatistics`]. +/// +/// `PruningPredicate` analyzes filter expressions using statistics such as +/// min/max values and null counts, attempting to prove a "container" (e.g. +/// Parquet Row Group) can be skipped without reading the actual data, +/// potentially leading to significant performance improvements. +/// +/// For example, `PruningPredicate`s are used to prune Parquet Row Groups based +/// on the min/max values found in the Parquet metadata. If the +/// `PruningPredicate` can prove that the filter can never evaluate to `true` +/// for any row in the Row Group, the entire Row Group is skipped during query +/// execution. +/// +/// The `PruningPredicate` API is designed to be general, so it can used for +/// pruning other types of containers (e.g. files) based on statistics that may +/// be known from external catalogs (e.g. Delta Lake) or other sources. +/// +/// It currently supports: +/// +/// 1. Arbitrary expressions (including user defined functions) +/// +/// 2. Vectorized evaluation (provide more than one set of statistics at a time) +/// so it is suitable for pruning 1000s of containers. +/// +/// 3. Any source of information that implements the [`PruningStatistics`] trait +/// (not just Parquet metadata). +/// +/// # Example +/// +/// Given an expression like `x = 5` and statistics for 3 containers (Row +/// Groups, files, etc) `A`, `B`, and `C`: +/// +/// ```text +/// A: {x_min = 0, x_max = 4} +/// B: {x_min = 2, x_max = 10} +/// C: {x_min = 5, x_max = 8} +/// ``` +/// +/// `PruningPredicate` will conclude that the rows in container `A` can never +/// be true (as the maximum value is only `4`), so it can be pruned: +/// +/// ```text +/// A: false (no rows could possibly match x = 5) +/// B: true (rows might match x = 5) +/// C: true (rows might match x = 5) +/// ``` /// -/// See [`PruningPredicate::try_new`] for more information. +/// See [`PruningPredicate::try_new`] and [`PruningPredicate::prune`] for more information. #[derive(Debug, Clone)] pub struct PruningPredicate { /// The input schema against which the predicate will be evaluated schema: SchemaRef, - /// Actual pruning predicate (rewritten in terms of column min/max statistics) + /// A min/max pruning predicate (rewritten in terms of column min/max + /// values, which are supplied by statistics) predicate_expr: Arc, - /// The statistics required to evaluate this predicate - required_columns: RequiredStatColumns, - /// Original physical predicate from which this predicate expr is derived (required for serialization) + /// Description of which statistics are required to evaluate `predicate_expr` + required_columns: RequiredColumns, + /// Original physical predicate from which this predicate expr is derived + /// (required for serialization) orig_expr: Arc, + /// [`LiteralGuarantee`]s that are used to try and prove a predicate can not + /// possibly evaluate to `true`. + literal_guarantees: Vec, } impl PruningPredicate { @@ -129,14 +221,18 @@ impl PruningPredicate { /// `(column_min / 2) <= 4 && 4 <= (column_max / 2))` pub fn try_new(expr: Arc, schema: SchemaRef) -> Result { // build predicate expression once - let mut required_columns = RequiredStatColumns::new(); + let mut required_columns = RequiredColumns::new(); let predicate_expr = build_predicate_expression(&expr, schema.as_ref(), &mut required_columns); + + let literal_guarantees = LiteralGuarantee::analyze(&expr); + Ok(Self { schema, predicate_expr, required_columns, orig_expr: expr, + literal_guarantees, }) } @@ -146,52 +242,61 @@ impl PruningPredicate { /// /// `true`: There MAY be rows that match the predicate /// - /// `false`: There are no rows that could match the predicate + /// `false`: There are no rows that could possibly match the predicate /// - /// Note this function takes a slice of statistics as a parameter - /// to amortize the cost of the evaluation of the predicate - /// against a single record batch. - /// - /// Note: the predicate passed to `prune` should be simplified as + /// Note: the predicate passed to `prune` should already be simplified as /// much as possible (e.g. this pass doesn't handle some /// expressions like `b = false`, but it does handle the - /// simplified version `b`. The predicates are simplified via the - /// ConstantFolding optimizer pass + /// simplified version `b`. See [`ExprSimplifier`] to simplify expressions. + /// + /// [`ExprSimplifier`]: crate::optimizer::simplify_expressions::ExprSimplifier pub fn prune(&self, statistics: &S) -> Result> { + let mut builder = BoolVecBuilder::new(statistics.num_containers()); + + // Try to prove the predicate can't be true for the containers based on + // literal guarantees + for literal_guarantee in &self.literal_guarantees { + let LiteralGuarantee { + column, + guarantee, + literals, + } = literal_guarantee; + if let Some(results) = statistics.contained(column, literals) { + match guarantee { + // `In` means the values in the column must be one of the + // values in the set for the predicate to evaluate to true. + // If `contained` returns false, that means the column is + // not any of the values so we can prune the container + Guarantee::In => builder.combine_array(&results), + // `NotIn` means the values in the column must must not be + // any of the values in the set for the predicate to + // evaluate to true. If contained returns true, it means the + // column is only in the set of values so we can prune the + // container + Guarantee::NotIn => { + builder.combine_array(&arrow::compute::not(&results)?) + } + } + // if all containers are pruned (has rows that DEFINITELY DO NOT pass the predicate) + // can return early without evaluating the rest of predicates. + if builder.check_all_pruned() { + return Ok(builder.build()); + } + } + } + + // Next, try to prove the predicate can't be true for the containers based + // on min/max values + // build a RecordBatch that contains the min/max values in the - // appropriate statistics columns + // appropriate statistics columns for the min/max predicate let statistics_batch = build_statistics_record_batch(statistics, &self.required_columns)?; - // Evaluate the pruning predicate on that record batch. - // - // Use true when the result of evaluating a predicate - // expression on a row group is null (aka `None`). Null can - // arise when the statistics are unknown or some calculation - // in the predicate means we don't know for sure if the row - // group can be filtered out or not. To maintain correctness - // the row group must be kept and thus `true` is returned. - match self.predicate_expr.evaluate(&statistics_batch)? { - ColumnarValue::Array(array) => { - let predicate_array = downcast_value!(array, BooleanArray); + // Evaluate the pruning predicate on that record batch and append any results to the builder + builder.combine_value(self.predicate_expr.evaluate(&statistics_batch)?); - Ok(predicate_array - .into_iter() - .map(|x| x.unwrap_or(true)) // None -> true per comments above - .collect::>()) - } - // result was a column - ColumnarValue::Scalar(ScalarValue::Boolean(v)) => { - let v = v.unwrap_or(true); // None -> true per comments above - Ok(vec![v; statistics.num_containers()]) - } - other => { - internal_err!( - "Unexpected result of pruning predicate evaluation. Expected Boolean array \ - or scalar but got {other:?}" - ) - } - } + Ok(builder.build()) } /// Return a reference to the input schema @@ -209,14 +314,104 @@ impl PruningPredicate { &self.predicate_expr } - /// Returns true if this pruning predicate is "always true" (aka will not prune anything) + /// Returns true if this pruning predicate can not prune anything. + /// + /// This happens if the predicate is a literal `true` and + /// literal_guarantees is empty. pub fn allways_true(&self) -> bool { - is_always_true(&self.predicate_expr) + is_always_true(&self.predicate_expr) && self.literal_guarantees.is_empty() } - pub(crate) fn required_columns(&self) -> &RequiredStatColumns { + pub(crate) fn required_columns(&self) -> &RequiredColumns { &self.required_columns } + + /// Names of the columns that are known to be / not be in a set + /// of literals (constants). These are the columns the that may be passed to + /// [`PruningStatistics::contained`] during pruning. + /// + /// This is useful to avoid fetching statistics for columns that will not be + /// used in the predicate. For example, it can be used to avoid reading + /// uneeded bloom filters (a non trivial operation). + pub fn literal_columns(&self) -> Vec { + let mut seen = HashSet::new(); + self.literal_guarantees + .iter() + .map(|e| &e.column.name) + // avoid duplicates + .filter(|name| seen.insert(*name)) + .map(|s| s.to_string()) + .collect() + } +} + +/// Builds the return `Vec` for [`PruningPredicate::prune`]. +#[derive(Debug)] +struct BoolVecBuilder { + /// One element per container. Each element is + /// * `true`: if the container has row that may pass the predicate + /// * `false`: if the container has rows that DEFINITELY DO NOT pass the predicate + inner: Vec, +} + +impl BoolVecBuilder { + /// Create a new `BoolVecBuilder` with `num_containers` elements + fn new(num_containers: usize) -> Self { + Self { + // assume by default all containers may pass the predicate + inner: vec![true; num_containers], + } + } + + /// Combines result `array` for a conjunct (e.g. `AND` clause) of a + /// predicate into the currently in progress array. + /// + /// Each `array` element is: + /// * `true`: container has row that may pass the predicate + /// * `false`: all container rows DEFINITELY DO NOT pass the predicate + /// * `null`: container may or may not have rows that pass the predicate + fn combine_array(&mut self, array: &BooleanArray) { + assert_eq!(array.len(), self.inner.len()); + for (cur, new) in self.inner.iter_mut().zip(array.iter()) { + // `false` for this conjunct means we know for sure no rows could + // pass the predicate and thus we set the corresponding container + // location to false. + if let Some(false) = new { + *cur = false; + } + } + } + + /// Combines the results in the [`ColumnarValue`] to the currently in + /// progress array, following the same rules as [`Self::combine_array`]. + /// + /// # Panics + /// If `value` is not boolean + fn combine_value(&mut self, value: ColumnarValue) { + match value { + ColumnarValue::Array(array) => { + self.combine_array(array.as_boolean()); + } + ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))) => { + // False means all containers can not pass the predicate + self.inner = vec![false; self.inner.len()]; + } + _ => { + // Null or true means the rows in container may pass this + // conjunct so we can't prune any containers based on that + } + } + } + + /// Convert this builder into a Vec of bools + fn build(self) -> Vec { + self.inner + } + + /// Check all containers has rows that DEFINITELY DO NOT pass the predicate + fn check_all_pruned(&self) -> bool { + self.inner.iter().all(|&x| !x) + } } fn is_always_true(expr: &Arc) -> bool { @@ -226,27 +421,31 @@ fn is_always_true(expr: &Arc) -> bool { .unwrap_or_default() } -/// Records for which columns statistics are necessary to evaluate a -/// pruning predicate. +/// Describes which columns statistics are necessary to evaluate a +/// [`PruningPredicate`]. +/// +/// This structure permits reading and creating the minimum number statistics, +/// which is important since statistics may be non trivial to read (e.g. large +/// strings or when there are 1000s of columns). /// /// Handles creating references to the min/max statistics /// for columns as well as recording which statistics are needed #[derive(Debug, Default, Clone)] -pub(crate) struct RequiredStatColumns { +pub(crate) struct RequiredColumns { /// The statistics required to evaluate this predicate: /// * The unqualified column in the input schema /// * Statistics type (e.g. Min or Max or Null_Count) /// * The field the statistics value should be placed in for - /// pruning predicate evaluation + /// pruning predicate evaluation (e.g. `min_value` or `max_value`) columns: Vec<(phys_expr::Column, StatisticsType, Field)>, } -impl RequiredStatColumns { +impl RequiredColumns { fn new() -> Self { Self::default() } - /// Returns number of unique columns. + /// Returns number of unique columns pub(crate) fn n_columns(&self) -> usize { self.iter() .map(|(c, _s, _f)| c) @@ -300,11 +499,10 @@ impl RequiredStatColumns { // only add statistics column if not previously added if need_to_insert { - let stat_field = Field::new( - stat_column.name(), - field.data_type().clone(), - field.is_nullable(), - ); + // may be null if statistics are not present + let nullable = true; + let stat_field = + Field::new(stat_column.name(), field.data_type().clone(), nullable); self.columns.push((column.clone(), stat_type, stat_field)); } rewrite_column_expr(column_expr.clone(), column, &stat_column) @@ -347,7 +545,7 @@ impl RequiredStatColumns { } } -impl From> for RequiredStatColumns { +impl From> for RequiredColumns { fn from(columns: Vec<(phys_expr::Column, StatisticsType, Field)>) -> Self { Self { columns } } @@ -380,7 +578,7 @@ impl From> for RequiredStatColum /// ``` fn build_statistics_record_batch( statistics: &S, - required_columns: &RequiredStatColumns, + required_columns: &RequiredColumns, ) -> Result { let mut fields = Vec::::new(); let mut arrays = Vec::::new(); @@ -426,7 +624,7 @@ fn build_statistics_record_batch( ); RecordBatch::try_new_with_options(schema, arrays, &options).map_err(|err| { - DataFusionError::Plan(format!("Can not create statistics record batch: {err}")) + plan_datafusion_err!("Can not create statistics record batch: {err}") }) } @@ -436,7 +634,7 @@ struct PruningExpressionBuilder<'a> { op: Operator, scalar_expr: Arc, field: &'a Field, - required_columns: &'a mut RequiredStatColumns, + required_columns: &'a mut RequiredColumns, } impl<'a> PruningExpressionBuilder<'a> { @@ -445,7 +643,7 @@ impl<'a> PruningExpressionBuilder<'a> { right: &'a Arc, op: Operator, schema: &'a Schema, - required_columns: &'a mut RequiredStatColumns, + required_columns: &'a mut RequiredColumns, ) -> Result { // find column name; input could be a more complicated expression let left_columns = collect_columns(left); @@ -660,7 +858,7 @@ fn reverse_operator(op: Operator) -> Result { fn build_single_column_expr( column: &phys_expr::Column, schema: &Schema, - required_columns: &mut RequiredStatColumns, + required_columns: &mut RequiredColumns, is_not: bool, // if true, treat as !col ) -> Option> { let field = schema.field_with_name(column.name()).ok()?; @@ -701,7 +899,7 @@ fn build_single_column_expr( fn build_is_null_column_expr( expr: &Arc, schema: &Schema, - required_columns: &mut RequiredStatColumns, + required_columns: &mut RequiredColumns, ) -> Option> { if let Some(col) = expr.as_any().downcast_ref::() { let field = schema.field_with_name(col.name()).ok()?; @@ -731,7 +929,7 @@ fn build_is_null_column_expr( fn build_predicate_expression( expr: &Arc, schema: &Schema, - required_columns: &mut RequiredStatColumns, + required_columns: &mut RequiredColumns, ) -> Arc { // Returned for unsupported expressions. Such expressions are // converted to TRUE. @@ -909,7 +1107,7 @@ fn build_statistics_expr( _ => { return plan_err!( "expressions other than (neq, eq, gt, gteq, lt, lteq) are not supported" - ) + ); } }; Ok(statistics_expr) @@ -940,7 +1138,7 @@ mod tests { use std::collections::HashMap; use std::ops::{Not, Rem}; - #[derive(Debug)] + #[derive(Debug, Default)] /// Mock statistic provider for tests /// /// Each row represents the statistics for a "container" (which @@ -949,95 +1147,142 @@ mod tests { /// /// Note All `ArrayRefs` must be the same size. struct ContainerStats { - min: ArrayRef, - max: ArrayRef, + min: Option, + max: Option, /// Optional values null_counts: Option, + /// Optional known values (e.g. mimic a bloom filter) + /// (value, contained) + /// If present, all BooleanArrays must be the same size as min/max + contained: Vec<(HashSet, BooleanArray)>, } impl ContainerStats { + fn new() -> Self { + Default::default() + } fn new_decimal128( min: impl IntoIterator>, max: impl IntoIterator>, precision: u8, scale: i8, ) -> Self { - Self { - min: Arc::new( + Self::new() + .with_min(Arc::new( min.into_iter() .collect::() .with_precision_and_scale(precision, scale) .unwrap(), - ), - max: Arc::new( + )) + .with_max(Arc::new( max.into_iter() .collect::() .with_precision_and_scale(precision, scale) .unwrap(), - ), - null_counts: None, - } + )) } fn new_i64( min: impl IntoIterator>, max: impl IntoIterator>, ) -> Self { - Self { - min: Arc::new(min.into_iter().collect::()), - max: Arc::new(max.into_iter().collect::()), - null_counts: None, - } + Self::new() + .with_min(Arc::new(min.into_iter().collect::())) + .with_max(Arc::new(max.into_iter().collect::())) } fn new_i32( min: impl IntoIterator>, max: impl IntoIterator>, ) -> Self { - Self { - min: Arc::new(min.into_iter().collect::()), - max: Arc::new(max.into_iter().collect::()), - null_counts: None, - } + Self::new() + .with_min(Arc::new(min.into_iter().collect::())) + .with_max(Arc::new(max.into_iter().collect::())) } fn new_utf8<'a>( min: impl IntoIterator>, max: impl IntoIterator>, ) -> Self { - Self { - min: Arc::new(min.into_iter().collect::()), - max: Arc::new(max.into_iter().collect::()), - null_counts: None, - } + Self::new() + .with_min(Arc::new(min.into_iter().collect::())) + .with_max(Arc::new(max.into_iter().collect::())) } fn new_bool( min: impl IntoIterator>, max: impl IntoIterator>, ) -> Self { - Self { - min: Arc::new(min.into_iter().collect::()), - max: Arc::new(max.into_iter().collect::()), - null_counts: None, - } + Self::new() + .with_min(Arc::new(min.into_iter().collect::())) + .with_max(Arc::new(max.into_iter().collect::())) } fn min(&self) -> Option { - Some(self.min.clone()) + self.min.clone() } fn max(&self) -> Option { - Some(self.max.clone()) + self.max.clone() } fn null_counts(&self) -> Option { self.null_counts.clone() } + /// return an iterator over all arrays in this statistics + fn arrays(&self) -> Vec { + let contained_arrays = self + .contained + .iter() + .map(|(_values, contained)| Arc::new(contained.clone()) as ArrayRef); + + [ + self.min.as_ref().cloned(), + self.max.as_ref().cloned(), + self.null_counts.as_ref().cloned(), + ] + .into_iter() + .flatten() + .chain(contained_arrays) + .collect() + } + + /// Returns the number of containers represented by this statistics This + /// picks the length of the first array as all arrays must have the same + /// length (which is verified by `assert_invariants`). fn len(&self) -> usize { - assert_eq!(self.min.len(), self.max.len()); - self.min.len() + // pick the first non zero length + self.arrays().iter().map(|a| a.len()).next().unwrap_or(0) + } + + /// Ensure that the lengths of all arrays are consistent + fn assert_invariants(&self) { + let mut prev_len = None; + + for len in self.arrays().iter().map(|a| a.len()) { + // Get a length, if we don't already have one + match prev_len { + None => { + prev_len = Some(len); + } + Some(prev_len) => { + assert_eq!(prev_len, len); + } + } + } + } + + /// Add min values + fn with_min(mut self, min: ArrayRef) -> Self { + self.min = Some(min); + self + } + + /// Add max values + fn with_max(mut self, max: ArrayRef) -> Self { + self.max = Some(max); + self } /// Add null counts. There must be the same number of null counts as @@ -1046,14 +1291,36 @@ mod tests { mut self, counts: impl IntoIterator>, ) -> Self { - // take stats out and update them let null_counts: ArrayRef = Arc::new(counts.into_iter().collect::()); - assert_eq!(null_counts.len(), self.len()); + self.assert_invariants(); self.null_counts = Some(null_counts); self } + + /// Add contained information. + pub fn with_contained( + mut self, + values: impl IntoIterator, + contained: impl IntoIterator>, + ) -> Self { + let contained: BooleanArray = contained.into_iter().collect(); + let values: HashSet<_> = values.into_iter().collect(); + + self.contained.push((values, contained)); + self.assert_invariants(); + self + } + + /// get any contained information for the specified values + fn contained(&self, find_values: &HashSet) -> Option { + // find the one with the matching values + self.contained + .iter() + .find(|(values, _contained)| values == find_values) + .map(|(_values, contained)| contained.clone()) + } } #[derive(Debug, Default)] @@ -1091,13 +1358,34 @@ mod tests { let container_stats = self .stats .remove(&col) - .expect("Can not find stats for column") + .unwrap_or_default() .with_null_counts(counts); // put stats back in self.stats.insert(col, container_stats); self } + + /// Add contained information for the specified columm. + fn with_contained( + mut self, + name: impl Into, + values: impl IntoIterator, + contained: impl IntoIterator>, + ) -> Self { + let col = Column::from_name(name.into()); + + // take stats out and update them + let container_stats = self + .stats + .remove(&col) + .unwrap_or_default() + .with_contained(values, contained); + + // put stats back in + self.stats.insert(col, container_stats); + self + } } impl PruningStatistics for TestStatistics { @@ -1129,6 +1417,16 @@ mod tests { .map(|container_stats| container_stats.null_counts()) .unwrap_or(None) } + + fn contained( + &self, + column: &Column, + values: &HashSet, + ) -> Option { + self.stats + .get(column) + .and_then(|container_stats| container_stats.contained(values)) + } } /// Returns the specified min/max container values @@ -1154,12 +1452,20 @@ mod tests { fn null_counts(&self, _column: &Column) -> Option { None } + + fn contained( + &self, + _column: &Column, + _values: &HashSet, + ) -> Option { + None + } } #[test] fn test_build_statistics_record_batch() { // Request a record batch with of s1_min, s2_max, s3_max, s3_min - let required_columns = RequiredStatColumns::from(vec![ + let required_columns = RequiredColumns::from(vec![ // min of original column s1, named s1_min ( phys_expr::Column::new("s1", 1), @@ -1231,7 +1537,7 @@ mod tests { // which is what Parquet does // Request a record batch with of s1_min as a timestamp - let required_columns = RequiredStatColumns::from(vec![( + let required_columns = RequiredColumns::from(vec![( phys_expr::Column::new("s3", 3), StatisticsType::Min, Field::new( @@ -1263,7 +1569,7 @@ mod tests { #[test] fn test_build_statistics_no_required_stats() { - let required_columns = RequiredStatColumns::new(); + let required_columns = RequiredColumns::new(); let statistics = OneContainerStats { min_values: Some(Arc::new(Int64Array::from(vec![Some(10)]))), @@ -1281,7 +1587,7 @@ mod tests { // Test requesting a Utf8 column when the stats return some other type // Request a record batch with of s1_min as a timestamp - let required_columns = RequiredStatColumns::from(vec![( + let required_columns = RequiredColumns::from(vec![( phys_expr::Column::new("s3", 3), StatisticsType::Min, Field::new("s1_min", DataType::Utf8, true), @@ -1310,7 +1616,7 @@ mod tests { #[test] fn test_build_statistics_inconsistent_length() { // return an inconsistent length to the actual statistics arrays - let required_columns = RequiredStatColumns::from(vec![( + let required_columns = RequiredColumns::from(vec![( phys_expr::Column::new("s1", 3), StatisticsType::Min, Field::new("s1_min", DataType::Int64, true), @@ -1341,20 +1647,14 @@ mod tests { // test column on the left let expr = col("c1").eq(lit(1)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(1).eq(col("c1")); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1367,20 +1667,14 @@ mod tests { // test column on the left let expr = col("c1").not_eq(lit(1)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(1).not_eq(col("c1")); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1393,20 +1687,14 @@ mod tests { // test column on the left let expr = col("c1").gt(lit(1)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(1).lt(col("c1")); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1419,19 +1707,13 @@ mod tests { // test column on the left let expr = col("c1").gt_eq(lit(1)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(1).lt_eq(col("c1")); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1444,20 +1726,14 @@ mod tests { // test column on the left let expr = col("c1").lt(lit(1)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(1).gt(col("c1")); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1470,19 +1746,13 @@ mod tests { // test column on the left let expr = col("c1").lt_eq(lit(1)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(1).gt_eq(col("c1")); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1498,11 +1768,8 @@ mod tests { // test AND operator joining supported c1 < 1 expression and unsupported c2 > c3 expression let expr = col("c1").lt(lit(1)).and(col("c2").lt(col("c3"))); let expected_expr = "c1_min@0 < 1"; - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1517,11 +1784,8 @@ mod tests { // test OR operator joining supported c1 < 1 expression and unsupported c2 % 2 = 0 expression let expr = col("c1").lt(lit(1)).or(col("c2").rem(lit(2)).eq(lit(0))); let expected_expr = "true"; - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1533,11 +1797,8 @@ mod tests { let expected_expr = "true"; let expr = col("c1").not(); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1549,11 +1810,8 @@ mod tests { let expected_expr = "NOT c1_min@0 AND c1_max@1"; let expr = col("c1").not(); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1565,11 +1823,8 @@ mod tests { let expected_expr = "c1_min@0 OR c1_max@1"; let expr = col("c1"); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1583,11 +1838,8 @@ mod tests { // DF doesn't support arithmetic on boolean columns so // this predicate will error when evaluated let expr = col("c1").lt(lit(true)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1599,7 +1851,7 @@ mod tests { Field::new("c1", DataType::Int32, false), Field::new("c2", DataType::Int32, false), ]); - let mut required_columns = RequiredStatColumns::new(); + let mut required_columns = RequiredColumns::new(); // c1 < 1 and (c2 = 2 or c2 = 3) let expr = col("c1") .lt(lit(1)) @@ -1615,7 +1867,7 @@ mod tests { ( phys_expr::Column::new("c1", 0), StatisticsType::Min, - c1_min_field + c1_min_field.with_nullable(true) // could be nullable if stats are not present ) ); // c2 = 2 should add c2_min and c2_max @@ -1625,7 +1877,7 @@ mod tests { ( phys_expr::Column::new("c2", 1), StatisticsType::Min, - c2_min_field + c2_min_field.with_nullable(true) // could be nullable if stats are not present ) ); let c2_max_field = Field::new("c2_max", DataType::Int32, false); @@ -1634,7 +1886,7 @@ mod tests { ( phys_expr::Column::new("c2", 1), StatisticsType::Max, - c2_max_field + c2_max_field.with_nullable(true) // could be nullable if stats are not present ) ); // c2 = 3 shouldn't add any new statistics fields @@ -1656,11 +1908,8 @@ mod tests { false, )); let expected_expr = "c1_min@0 <= 1 AND 1 <= c1_max@1 OR c1_min@0 <= 2 AND 2 <= c1_max@1 OR c1_min@0 <= 3 AND 3 <= c1_max@1"; - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1675,11 +1924,8 @@ mod tests { // test c1 in() let expr = Expr::InList(InList::new(Box::new(col("c1")), vec![], false)); let expected_expr = "true"; - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1700,11 +1946,8 @@ mod tests { let expected_expr = "(c1_min@0 != 1 OR 1 != c1_max@1) \ AND (c1_min@0 != 2 OR 2 != c1_max@1) \ AND (c1_min@0 != 3 OR 3 != c1_max@1)"; - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1718,20 +1961,14 @@ mod tests { // test column on the left let expr = cast(col("c1"), DataType::Int64).eq(lit(ScalarValue::Int64(Some(1)))); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(ScalarValue::Int64(Some(1))).eq(cast(col("c1"), DataType::Int64)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); let expected_expr = "TRY_CAST(c1_max@0 AS Int64) > 1"; @@ -1739,21 +1976,15 @@ mod tests { // test column on the left let expr = try_cast(col("c1"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(1)))); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(ScalarValue::Int64(Some(1))).lt(try_cast(col("c1"), DataType::Int64)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1773,11 +2004,8 @@ mod tests { false, )); let expected_expr = "CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64) OR CAST(c1_min@0 AS Int64) <= 2 AND 2 <= CAST(c1_max@1 AS Int64) OR CAST(c1_min@0 AS Int64) <= 3 AND 3 <= CAST(c1_max@1 AS Int64)"; - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); let expr = Expr::InList(InList::new( @@ -1793,11 +2021,8 @@ mod tests { "(CAST(c1_min@0 AS Int64) != 1 OR 1 != CAST(c1_max@1 AS Int64)) \ AND (CAST(c1_min@0 AS Int64) != 2 OR 2 != CAST(c1_max@1 AS Int64)) \ AND (CAST(c1_min@0 AS Int64) != 3 OR 3 != CAST(c1_max@1 AS Int64))"; - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1811,54 +2036,52 @@ mod tests { DataType::Decimal128(9, 2), true, )])); - // s1 > 5 - let expr = col("s1").gt(lit(ScalarValue::Decimal128(Some(500), 9, 2))); - let expr = logical2physical(&expr, &schema); - // If the data is written by spark, the physical data type is INT32 in the parquet - // So we use the INT32 type of statistic. - let statistics = TestStatistics::new().with( - "s1", - ContainerStats::new_i32( - vec![Some(0), Some(4), None, Some(3)], // min - vec![Some(5), Some(6), Some(4), None], // max + + prune_with_expr( + // s1 > 5 + col("s1").gt(lit(ScalarValue::Decimal128(Some(500), 9, 2))), + &schema, + // If the data is written by spark, the physical data type is INT32 in the parquet + // So we use the INT32 type of statistic. + &TestStatistics::new().with( + "s1", + ContainerStats::new_i32( + vec![Some(0), Some(4), None, Some(3)], // min + vec![Some(5), Some(6), Some(4), None], // max + ), ), + &[false, true, false, true], ); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - let expected = vec![false, true, false, true]; - assert_eq!(result, expected); - // with cast column to other type - let expr = cast(col("s1"), DataType::Decimal128(14, 3)) - .gt(lit(ScalarValue::Decimal128(Some(5000), 14, 3))); - let expr = logical2physical(&expr, &schema); - let statistics = TestStatistics::new().with( - "s1", - ContainerStats::new_i32( - vec![Some(0), Some(4), None, Some(3)], // min - vec![Some(5), Some(6), Some(4), None], // max + prune_with_expr( + // with cast column to other type + cast(col("s1"), DataType::Decimal128(14, 3)) + .gt(lit(ScalarValue::Decimal128(Some(5000), 14, 3))), + &schema, + &TestStatistics::new().with( + "s1", + ContainerStats::new_i32( + vec![Some(0), Some(4), None, Some(3)], // min + vec![Some(5), Some(6), Some(4), None], // max + ), ), + &[false, true, false, true], ); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - let expected = vec![false, true, false, true]; - assert_eq!(result, expected); - // with try cast column to other type - let expr = try_cast(col("s1"), DataType::Decimal128(14, 3)) - .gt(lit(ScalarValue::Decimal128(Some(5000), 14, 3))); - let expr = logical2physical(&expr, &schema); - let statistics = TestStatistics::new().with( - "s1", - ContainerStats::new_i32( - vec![Some(0), Some(4), None, Some(3)], // min - vec![Some(5), Some(6), Some(4), None], // max + prune_with_expr( + // with try cast column to other type + try_cast(col("s1"), DataType::Decimal128(14, 3)) + .gt(lit(ScalarValue::Decimal128(Some(5000), 14, 3))), + &schema, + &TestStatistics::new().with( + "s1", + ContainerStats::new_i32( + vec![Some(0), Some(4), None, Some(3)], // min + vec![Some(5), Some(6), Some(4), None], // max + ), ), + &[false, true, false, true], ); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - let expected = vec![false, true, false, true]; - assert_eq!(result, expected); // decimal(18,2) let schema = Arc::new(Schema::new(vec![Field::new( @@ -1866,22 +2089,21 @@ mod tests { DataType::Decimal128(18, 2), true, )])); - // s1 > 5 - let expr = col("s1").gt(lit(ScalarValue::Decimal128(Some(500), 18, 2))); - let expr = logical2physical(&expr, &schema); - // If the data is written by spark, the physical data type is INT64 in the parquet - // So we use the INT32 type of statistic. - let statistics = TestStatistics::new().with( - "s1", - ContainerStats::new_i64( - vec![Some(0), Some(4), None, Some(3)], // min - vec![Some(5), Some(6), Some(4), None], // max + prune_with_expr( + // s1 > 5 + col("s1").gt(lit(ScalarValue::Decimal128(Some(500), 18, 2))), + &schema, + // If the data is written by spark, the physical data type is INT64 in the parquet + // So we use the INT32 type of statistic. + &TestStatistics::new().with( + "s1", + ContainerStats::new_i64( + vec![Some(0), Some(4), None, Some(3)], // min + vec![Some(5), Some(6), Some(4), None], // max + ), ), + &[false, true, false, true], ); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - let expected = vec![false, true, false, true]; - assert_eq!(result, expected); // decimal(23,2) let schema = Arc::new(Schema::new(vec![Field::new( @@ -1889,22 +2111,22 @@ mod tests { DataType::Decimal128(23, 2), true, )])); - // s1 > 5 - let expr = col("s1").gt(lit(ScalarValue::Decimal128(Some(500), 23, 2))); - let expr = logical2physical(&expr, &schema); - let statistics = TestStatistics::new().with( - "s1", - ContainerStats::new_decimal128( - vec![Some(0), Some(400), None, Some(300)], // min - vec![Some(500), Some(600), Some(400), None], // max - 23, - 2, + + prune_with_expr( + // s1 > 5 + col("s1").gt(lit(ScalarValue::Decimal128(Some(500), 23, 2))), + &schema, + &TestStatistics::new().with( + "s1", + ContainerStats::new_decimal128( + vec![Some(0), Some(400), None, Some(300)], // min + vec![Some(500), Some(600), Some(400), None], // max + 23, + 2, + ), ), + &[false, true, false, true], ); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - let expected = vec![false, true, false, true]; - assert_eq!(result, expected); } #[test] @@ -1914,10 +2136,6 @@ mod tests { Field::new("s2", DataType::Int32, true), ])); - // Prune using s2 > 5 - let expr = col("s2").gt(lit(5)); - let expr = logical2physical(&expr, &schema); - let statistics = TestStatistics::new().with( "s2", ContainerStats::new_i32( @@ -1925,53 +2143,50 @@ mod tests { vec![Some(5), Some(6), None, None], // max ), ); + prune_with_expr( + // Prune using s2 > 5 + col("s2").gt(lit(5)), + &schema, + &statistics, + // s2 [0, 5] ==> no rows should pass + // s2 [4, 6] ==> some rows could pass + // No stats for s2 ==> some rows could pass + // s2 [3, None] (null max) ==> some rows could pass + &[false, true, true, true], + ); - // s2 [0, 5] ==> no rows should pass - // s2 [4, 6] ==> some rows could pass - // No stats for s2 ==> some rows could pass - // s2 [3, None] (null max) ==> some rows could pass - - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - let expected = vec![false, true, true, true]; - assert_eq!(result, expected); - - // filter with cast - let expr = cast(col("s2"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(5)))); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - let expected = vec![false, true, true, true]; - assert_eq!(result, expected); + prune_with_expr( + // filter with cast + cast(col("s2"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(5)))), + &schema, + &statistics, + &[false, true, true, true], + ); } #[test] fn prune_not_eq_data() { let schema = Arc::new(Schema::new(vec![Field::new("s1", DataType::Utf8, true)])); - // Prune using s2 != 'M' - let expr = col("s1").not_eq(lit("M")); - let expr = logical2physical(&expr, &schema); - - let statistics = TestStatistics::new().with( - "s1", - ContainerStats::new_utf8( - vec![Some("A"), Some("A"), Some("N"), Some("M"), None, Some("A")], // min - vec![Some("Z"), Some("L"), Some("Z"), Some("M"), None, None], // max + prune_with_expr( + // Prune using s2 != 'M' + col("s1").not_eq(lit("M")), + &schema, + &TestStatistics::new().with( + "s1", + ContainerStats::new_utf8( + vec![Some("A"), Some("A"), Some("N"), Some("M"), None, Some("A")], // min + vec![Some("Z"), Some("L"), Some("Z"), Some("M"), None, None], // max + ), ), + // s1 [A, Z] ==> might have values that pass predicate + // s1 [A, L] ==> all rows pass the predicate + // s1 [N, Z] ==> all rows pass the predicate + // s1 [M, M] ==> all rows do not pass the predicate + // No stats for s2 ==> some rows could pass + // s2 [3, None] (null max) ==> some rows could pass + &[true, true, true, false, true, true], ); - - // s1 [A, Z] ==> might have values that pass predicate - // s1 [A, L] ==> all rows pass the predicate - // s1 [N, Z] ==> all rows pass the predicate - // s1 [M, M] ==> all rows do not pass the predicate - // No stats for s2 ==> some rows could pass - // s2 [3, None] (null max) ==> some rows could pass - - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - let expected = vec![true, true, true, false, true, true]; - assert_eq!(result, expected); } /// Creates setup for boolean chunk pruning @@ -2010,69 +2225,75 @@ mod tests { fn prune_bool_const_expr() { let (schema, statistics, _, _) = bool_setup(); - // true - let expr = lit(true); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, vec![true, true, true, true, true]); + prune_with_expr( + // true + lit(true), + &schema, + &statistics, + &[true, true, true, true, true], + ); - // false - // constant literals that do NOT refer to any columns are currently not evaluated at all, hence the result is - // "all true" - let expr = lit(false); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, vec![true, true, true, true, true]); + prune_with_expr( + // false + // constant literals that do NOT refer to any columns are currently not evaluated at all, hence the result is + // "all true" + lit(false), + &schema, + &statistics, + &[true, true, true, true, true], + ); } #[test] fn prune_bool_column() { let (schema, statistics, expected_true, _) = bool_setup(); - // b1 - let expr = col("b1"); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_true); + prune_with_expr( + // b1 + col("b1"), + &schema, + &statistics, + &expected_true, + ); } #[test] fn prune_bool_not_column() { let (schema, statistics, _, expected_false) = bool_setup(); - // !b1 - let expr = col("b1").not(); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_false); + prune_with_expr( + // !b1 + col("b1").not(), + &schema, + &statistics, + &expected_false, + ); } #[test] fn prune_bool_column_eq_true() { let (schema, statistics, expected_true, _) = bool_setup(); - // b1 = true - let expr = col("b1").eq(lit(true)); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_true); + prune_with_expr( + // b1 = true + col("b1").eq(lit(true)), + &schema, + &statistics, + &expected_true, + ); } #[test] fn prune_bool_not_column_eq_true() { let (schema, statistics, _, expected_false) = bool_setup(); - // !b1 = true - let expr = col("b1").not().eq(lit(true)); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_false); + prune_with_expr( + // !b1 = true + col("b1").not().eq(lit(true)), + &schema, + &statistics, + &expected_false, + ); } /// Creates a setup for chunk pruning, modeling a int32 column "i" @@ -2107,21 +2328,18 @@ mod tests { // i [-11, -1] ==> no rows can pass (not keep) // i [NULL, NULL] ==> unknown (must keep) // i [1, NULL] ==> unknown (must keep) - let expected_ret = vec![true, true, false, true, true]; + let expected_ret = &[true, true, false, true, true]; // i > 0 - let expr = col("i").gt(lit(0)); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr(col("i").gt(lit(0)), &schema, &statistics, expected_ret); // -i < 0 - let expr = Expr::Negative(Box::new(col("i"))).lt(lit(0)); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + Expr::Negative(Box::new(col("i"))).lt(lit(0)), + &schema, + &statistics, + expected_ret, + ); } #[test] @@ -2134,21 +2352,23 @@ mod tests { // i [-11, -1] ==> all rows must pass (must keep) // i [NULL, NULL] ==> unknown (must keep) // i [1, NULL] ==> no rows can pass (not keep) - let expected_ret = vec![true, false, true, true, false]; + let expected_ret = &[true, false, true, true, false]; - // i <= 0 - let expr = col("i").lt_eq(lit(0)); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // i <= 0 + col("i").lt_eq(lit(0)), + &schema, + &statistics, + expected_ret, + ); - // -i >= 0 - let expr = Expr::Negative(Box::new(col("i"))).gt_eq(lit(0)); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // -i >= 0 + Expr::Negative(Box::new(col("i"))).gt_eq(lit(0)), + &schema, + &statistics, + expected_ret, + ); } #[test] @@ -2161,37 +2381,39 @@ mod tests { // i [-11, -1] ==> no rows could pass in theory (conservatively keep) // i [NULL, NULL] ==> unknown (must keep) // i [1, NULL] ==> no rows can pass (conservatively keep) - let expected_ret = vec![true, true, true, true, true]; + let expected_ret = &[true, true, true, true, true]; - // cast(i as utf8) <= 0 - let expr = cast(col("i"), DataType::Utf8).lt_eq(lit("0")); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // cast(i as utf8) <= 0 + cast(col("i"), DataType::Utf8).lt_eq(lit("0")), + &schema, + &statistics, + expected_ret, + ); - // try_cast(i as utf8) <= 0 - let expr = try_cast(col("i"), DataType::Utf8).lt_eq(lit("0")); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // try_cast(i as utf8) <= 0 + try_cast(col("i"), DataType::Utf8).lt_eq(lit("0")), + &schema, + &statistics, + expected_ret, + ); - // cast(-i as utf8) >= 0 - let expr = - cast(Expr::Negative(Box::new(col("i"))), DataType::Utf8).gt_eq(lit("0")); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // cast(-i as utf8) >= 0 + cast(Expr::Negative(Box::new(col("i"))), DataType::Utf8).gt_eq(lit("0")), + &schema, + &statistics, + expected_ret, + ); - // try_cast(-i as utf8) >= 0 - let expr = - try_cast(Expr::Negative(Box::new(col("i"))), DataType::Utf8).gt_eq(lit("0")); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // try_cast(-i as utf8) >= 0 + try_cast(Expr::Negative(Box::new(col("i"))), DataType::Utf8).gt_eq(lit("0")), + &schema, + &statistics, + expected_ret, + ); } #[test] @@ -2204,14 +2426,15 @@ mod tests { // i [-11, -1] ==> no rows can pass (not keep) // i [NULL, NULL] ==> unknown (must keep) // i [1, NULL] ==> no rows can pass (not keep) - let expected_ret = vec![true, false, false, true, false]; + let expected_ret = &[true, false, false, true, false]; - // i = 0 - let expr = col("i").eq(lit(0)); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // i = 0 + col("i").eq(lit(0)), + &schema, + &statistics, + expected_ret, + ); } #[test] @@ -2224,19 +2447,21 @@ mod tests { // i [-11, -1] ==> no rows can pass (not keep) // i [NULL, NULL] ==> unknown (must keep) // i [1, NULL] ==> no rows can pass (not keep) - let expected_ret = vec![true, false, false, true, false]; + let expected_ret = &[true, false, false, true, false]; - let expr = cast(col("i"), DataType::Int64).eq(lit(0i64)); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + cast(col("i"), DataType::Int64).eq(lit(0i64)), + &schema, + &statistics, + expected_ret, + ); - let expr = try_cast(col("i"), DataType::Int64).eq(lit(0i64)); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + try_cast(col("i"), DataType::Int64).eq(lit(0i64)), + &schema, + &statistics, + expected_ret, + ); } #[test] @@ -2252,13 +2477,14 @@ mod tests { // i [-11, -1] ==> no rows can pass (could keep) // i [NULL, NULL] ==> unknown (keep) // i [1, NULL] ==> no rows can pass (could keep) - let expected_ret = vec![true, true, true, true, true]; + let expected_ret = &[true, true, true, true, true]; - let expr = cast(col("i"), DataType::Utf8).eq(lit("0")); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + cast(col("i"), DataType::Utf8).eq(lit("0")), + &schema, + &statistics, + expected_ret, + ); } #[test] @@ -2271,21 +2497,23 @@ mod tests { // i [-11, -1] ==> no rows can pass (not keep) // i [NULL, NULL] ==> unknown (must keep) // i [1, NULL] ==> all rows must pass (must keep) - let expected_ret = vec![true, true, false, true, true]; + let expected_ret = &[true, true, false, true, true]; - // i > -1 - let expr = col("i").gt(lit(-1)); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // i > -1 + col("i").gt(lit(-1)), + &schema, + &statistics, + expected_ret, + ); - // -i < 1 - let expr = Expr::Negative(Box::new(col("i"))).lt(lit(1)); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // -i < 1 + Expr::Negative(Box::new(col("i"))).lt(lit(1)), + &schema, + &statistics, + expected_ret, + ); } #[test] @@ -2294,14 +2522,15 @@ mod tests { // Expression "i IS NULL" when there are no null statistics, // should all be kept - let expected_ret = vec![true, true, true, true, true]; + let expected_ret = &[true, true, true, true, true]; - // i IS NULL, no null statistics - let expr = col("i").is_null(); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // i IS NULL, no null statistics + col("i").is_null(), + &schema, + &statistics, + expected_ret, + ); // provide null counts for each column let statistics = statistics.with_null_counts( @@ -2315,51 +2544,55 @@ mod tests { ], ); - let expected_ret = vec![false, true, true, true, false]; + let expected_ret = &[false, true, true, true, false]; - // i IS NULL, with actual null statistcs - let expr = col("i").is_null(); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // i IS NULL, with actual null statistcs + col("i").is_null(), + &schema, + &statistics, + expected_ret, + ); } #[test] fn prune_cast_column_scalar() { // The data type of column i is INT32 let (schema, statistics) = int32_setup(); - let expected_ret = vec![true, true, false, true, true]; + let expected_ret = &[true, true, false, true, true]; - // i > int64(0) - let expr = col("i").gt(cast(lit(ScalarValue::Int64(Some(0))), DataType::Int32)); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // i > int64(0) + col("i").gt(cast(lit(ScalarValue::Int64(Some(0))), DataType::Int32)), + &schema, + &statistics, + expected_ret, + ); - // cast(i as int64) > int64(0) - let expr = cast(col("i"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(0)))); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // cast(i as int64) > int64(0) + cast(col("i"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(0)))), + &schema, + &statistics, + expected_ret, + ); - // try_cast(i as int64) > int64(0) - let expr = - try_cast(col("i"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(0)))); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // try_cast(i as int64) > int64(0) + try_cast(col("i"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(0)))), + &schema, + &statistics, + expected_ret, + ); - // `-cast(i as int64) < 0` convert to `cast(i as int64) > -0` - let expr = Expr::Negative(Box::new(cast(col("i"), DataType::Int64))) - .lt(lit(ScalarValue::Int64(Some(0)))); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, schema).unwrap(); - let result = p.prune(&statistics).unwrap(); - assert_eq!(result, expected_ret); + prune_with_expr( + // `-cast(i as int64) < 0` convert to `cast(i as int64) > -0` + Expr::Negative(Box::new(cast(col("i"), DataType::Int64))) + .lt(lit(ScalarValue::Int64(Some(0)))), + &schema, + &statistics, + expected_ret, + ); } #[test] @@ -2440,10 +2673,464 @@ mod tests { // TODO: add other negative test for other case and op } + #[test] + fn prune_with_contained_one_column() { + let schema = Arc::new(Schema::new(vec![Field::new("s1", DataType::Utf8, true)])); + + // Model having information like a bloom filter for s1 + let statistics = TestStatistics::new() + .with_contained( + "s1", + [ScalarValue::from("foo")], + [ + // container 0 known to only contain "foo"", + Some(true), + // container 1 known to not contain "foo" + Some(false), + // container 2 unknown about "foo" + None, + // container 3 known to only contain "foo" + Some(true), + // container 4 known to not contain "foo" + Some(false), + // container 5 unknown about "foo" + None, + // container 6 known to only contain "foo" + Some(true), + // container 7 known to not contain "foo" + Some(false), + // container 8 unknown about "foo" + None, + ], + ) + .with_contained( + "s1", + [ScalarValue::from("bar")], + [ + // containers 0,1,2 known to only contain "bar" + Some(true), + Some(true), + Some(true), + // container 3,4,5 known to not contain "bar" + Some(false), + Some(false), + Some(false), + // container 6,7,8 unknown about "bar" + None, + None, + None, + ], + ) + .with_contained( + // the way the tests are setup, this data is + // consulted if the "foo" and "bar" are being checked at the same time + "s1", + [ScalarValue::from("foo"), ScalarValue::from("bar")], + [ + // container 0,1,2 unknown about ("foo, "bar") + None, + None, + None, + // container 3,4,5 known to contain only either "foo" and "bar" + Some(true), + Some(true), + Some(true), + // container 6,7,8 known to contain neither "foo" and "bar" + Some(false), + Some(false), + Some(false), + ], + ); + + // s1 = 'foo' + prune_with_expr( + col("s1").eq(lit("foo")), + &schema, + &statistics, + // rule out containers ('false) where we know foo is not present + &[true, false, true, true, false, true, true, false, true], + ); + + // s1 = 'bar' + prune_with_expr( + col("s1").eq(lit("bar")), + &schema, + &statistics, + // rule out containers where we know bar is not present + &[true, true, true, false, false, false, true, true, true], + ); + + // s1 = 'baz' (unknown value) + prune_with_expr( + col("s1").eq(lit("baz")), + &schema, + &statistics, + // can't rule out anything + &[true, true, true, true, true, true, true, true, true], + ); + + // s1 = 'foo' AND s1 = 'bar' + prune_with_expr( + col("s1").eq(lit("foo")).and(col("s1").eq(lit("bar"))), + &schema, + &statistics, + // logically this predicate can't possibly be true (the column can't + // take on both values) but we could rule it out if the stats tell + // us that both values are not present + &[true, true, true, true, true, true, true, true, true], + ); + + // s1 = 'foo' OR s1 = 'bar' + prune_with_expr( + col("s1").eq(lit("foo")).or(col("s1").eq(lit("bar"))), + &schema, + &statistics, + // can rule out containers that we know contain neither foo nor bar + &[true, true, true, true, true, true, false, false, false], + ); + + // s1 = 'foo' OR s1 = 'baz' + prune_with_expr( + col("s1").eq(lit("foo")).or(col("s1").eq(lit("baz"))), + &schema, + &statistics, + // can't rule out anything container + &[true, true, true, true, true, true, true, true, true], + ); + + // s1 = 'foo' OR s1 = 'bar' OR s1 = 'baz' + prune_with_expr( + col("s1") + .eq(lit("foo")) + .or(col("s1").eq(lit("bar"))) + .or(col("s1").eq(lit("baz"))), + &schema, + &statistics, + // can rule out any containers based on knowledge of s1 and `foo`, + // `bar` and (`foo`, `bar`) + &[true, true, true, true, true, true, true, true, true], + ); + + // s1 != foo + prune_with_expr( + col("s1").not_eq(lit("foo")), + &schema, + &statistics, + // rule out containers we know for sure only contain foo + &[false, true, true, false, true, true, false, true, true], + ); + + // s1 != bar + prune_with_expr( + col("s1").not_eq(lit("bar")), + &schema, + &statistics, + // rule out when we know for sure s1 has the value bar + &[false, false, false, true, true, true, true, true, true], + ); + + // s1 != foo AND s1 != bar + prune_with_expr( + col("s1") + .not_eq(lit("foo")) + .and(col("s1").not_eq(lit("bar"))), + &schema, + &statistics, + // can rule out any container where we know s1 does not have either 'foo' or 'bar' + &[true, true, true, false, false, false, true, true, true], + ); + + // s1 != foo AND s1 != bar AND s1 != baz + prune_with_expr( + col("s1") + .not_eq(lit("foo")) + .and(col("s1").not_eq(lit("bar"))) + .and(col("s1").not_eq(lit("baz"))), + &schema, + &statistics, + // can't rule out any container based on knowledge of s1,s2 + &[true, true, true, true, true, true, true, true, true], + ); + + // s1 != foo OR s1 != bar + prune_with_expr( + col("s1") + .not_eq(lit("foo")) + .or(col("s1").not_eq(lit("bar"))), + &schema, + &statistics, + // cant' rule out anything based on contains information + &[true, true, true, true, true, true, true, true, true], + ); + + // s1 != foo OR s1 != bar OR s1 != baz + prune_with_expr( + col("s1") + .not_eq(lit("foo")) + .or(col("s1").not_eq(lit("bar"))) + .or(col("s1").not_eq(lit("baz"))), + &schema, + &statistics, + // cant' rule out anything based on contains information + &[true, true, true, true, true, true, true, true, true], + ); + } + + #[test] + fn prune_with_contained_two_columns() { + let schema = Arc::new(Schema::new(vec![ + Field::new("s1", DataType::Utf8, true), + Field::new("s2", DataType::Utf8, true), + ])); + + // Model having information like bloom filters for s1 and s2 + let statistics = TestStatistics::new() + .with_contained( + "s1", + [ScalarValue::from("foo")], + [ + // container 0, s1 known to only contain "foo"", + Some(true), + // container 1, s1 known to not contain "foo" + Some(false), + // container 2, s1 unknown about "foo" + None, + // container 3, s1 known to only contain "foo" + Some(true), + // container 4, s1 known to not contain "foo" + Some(false), + // container 5, s1 unknown about "foo" + None, + // container 6, s1 known to only contain "foo" + Some(true), + // container 7, s1 known to not contain "foo" + Some(false), + // container 8, s1 unknown about "foo" + None, + ], + ) + .with_contained( + "s2", // for column s2 + [ScalarValue::from("bar")], + [ + // containers 0,1,2 s2 known to only contain "bar" + Some(true), + Some(true), + Some(true), + // container 3,4,5 s2 known to not contain "bar" + Some(false), + Some(false), + Some(false), + // container 6,7,8 s2 unknown about "bar" + None, + None, + None, + ], + ); + + // s1 = 'foo' + prune_with_expr( + col("s1").eq(lit("foo")), + &schema, + &statistics, + // rule out containers where we know s1 is not present + &[true, false, true, true, false, true, true, false, true], + ); + + // s1 = 'foo' OR s2 = 'bar' + let expr = col("s1").eq(lit("foo")).or(col("s2").eq(lit("bar"))); + prune_with_expr( + expr, + &schema, + &statistics, + // can't rule out any container (would need to prove that s1 != foo AND s2 != bar) + &[true, true, true, true, true, true, true, true, true], + ); + + // s1 = 'foo' AND s2 != 'bar' + prune_with_expr( + col("s1").eq(lit("foo")).and(col("s2").not_eq(lit("bar"))), + &schema, + &statistics, + // can only rule out container where we know either: + // 1. s1 doesn't have the value 'foo` or + // 2. s2 has only the value of 'bar' + &[false, false, false, true, false, true, true, false, true], + ); + + // s1 != 'foo' AND s2 != 'bar' + prune_with_expr( + col("s1") + .not_eq(lit("foo")) + .and(col("s2").not_eq(lit("bar"))), + &schema, + &statistics, + // Can rule out any container where we know either + // 1. s1 has only the value 'foo' + // 2. s2 has only the value 'bar' + &[false, false, false, false, true, true, false, true, true], + ); + + // s1 != 'foo' AND (s2 = 'bar' OR s2 = 'baz') + prune_with_expr( + col("s1") + .not_eq(lit("foo")) + .and(col("s2").eq(lit("bar")).or(col("s2").eq(lit("baz")))), + &schema, + &statistics, + // Can rule out any container where we know s1 has only the value + // 'foo'. Can't use knowledge of s2 and bar to rule out anything + &[false, true, true, false, true, true, false, true, true], + ); + + // s1 like '%foo%bar%' + prune_with_expr( + col("s1").like(lit("foo%bar%")), + &schema, + &statistics, + // cant rule out anything with information we know + &[true, true, true, true, true, true, true, true, true], + ); + + // s1 like '%foo%bar%' AND s2 = 'bar' + prune_with_expr( + col("s1") + .like(lit("foo%bar%")) + .and(col("s2").eq(lit("bar"))), + &schema, + &statistics, + // can rule out any container where we know s2 does not have the value 'bar' + &[true, true, true, false, false, false, true, true, true], + ); + + // s1 like '%foo%bar%' OR s2 = 'bar' + prune_with_expr( + col("s1").like(lit("foo%bar%")).or(col("s2").eq(lit("bar"))), + &schema, + &statistics, + // can't rule out anything (we would have to prove that both the + // like and the equality must be false) + &[true, true, true, true, true, true, true, true, true], + ); + } + + #[test] + fn prune_with_range_and_contained() { + // Setup mimics range information for i, a bloom filter for s + let schema = Arc::new(Schema::new(vec![ + Field::new("i", DataType::Int32, true), + Field::new("s", DataType::Utf8, true), + ])); + + let statistics = TestStatistics::new() + .with( + "i", + ContainerStats::new_i32( + // Container 0, 3, 6: [-5 to 5] + // Container 1, 4, 7: [10 to 20] + // Container 2, 5, 9: unknown + vec![ + Some(-5), + Some(10), + None, + Some(-5), + Some(10), + None, + Some(-5), + Some(10), + None, + ], // min + vec![ + Some(5), + Some(20), + None, + Some(5), + Some(20), + None, + Some(5), + Some(20), + None, + ], // max + ), + ) + // Add contained information about the s and "foo" + .with_contained( + "s", + [ScalarValue::from("foo")], + [ + // container 0,1,2 known to only contain "foo" + Some(true), + Some(true), + Some(true), + // container 3,4,5 known to not contain "foo" + Some(false), + Some(false), + Some(false), + // container 6,7,8 unknown about "foo" + None, + None, + None, + ], + ); + + // i = 0 and s = 'foo' + prune_with_expr( + col("i").eq(lit(0)).and(col("s").eq(lit("foo"))), + &schema, + &statistics, + // Can rule out container where we know that either: + // 1. 0 is outside the min/max range of i + // 1. s does not contain foo + // (range is false, and contained is false) + &[true, false, true, false, false, false, true, false, true], + ); + + // i = 0 and s != 'foo' + prune_with_expr( + col("i").eq(lit(0)).and(col("s").not_eq(lit("foo"))), + &schema, + &statistics, + // Can rule out containers where either: + // 1. 0 is outside the min/max range of i + // 2. s only contains foo + &[false, false, false, true, false, true, true, false, true], + ); + + // i = 0 OR s = 'foo' + prune_with_expr( + col("i").eq(lit(0)).or(col("s").eq(lit("foo"))), + &schema, + &statistics, + // in theory could rule out containers if we had min/max values for + // s as well. But in this case we don't so we can't rule out anything + &[true, true, true, true, true, true, true, true, true], + ); + } + + /// prunes the specified expr with the specified schema and statistics, and + /// ensures it returns expected. + /// + /// `expected` is a vector of bools, where true means the row group should + /// be kept, and false means it should be pruned. + /// + // TODO refactor other tests to use this to reduce boiler plate + fn prune_with_expr( + expr: Expr, + schema: &SchemaRef, + statistics: &TestStatistics, + expected: &[bool], + ) { + println!("Pruning with expr: {}", expr); + let expr = logical2physical(&expr, schema); + let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); + let result = p.prune(statistics).unwrap(); + assert_eq!(result, expected); + } + fn test_build_predicate_expression( expr: &Expr, schema: &Schema, - required_columns: &mut RequiredStatColumns, + required_columns: &mut RequiredColumns, ) -> Arc { let expr = logical2physical(expr, schema); build_predicate_expression(&expr, schema, required_columns) diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index b0ae199a2da4..e49b358608aa 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -19,19 +19,18 @@ //! order-preserving variants when it is helpful; either in terms of //! performance or to accommodate unbounded streams by fixing the pipeline. +use std::borrow::Cow; use std::sync::Arc; +use super::utils::is_repartition; use crate::error::Result; -use crate::physical_optimizer::utils::{is_coalesce_partitions, is_sort, ExecTree}; +use crate::physical_optimizer::utils::{is_coalesce_partitions, is_sort}; use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; -use super::utils::is_repartition; - use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; -use datafusion_physical_expr::utils::ordering_satisfy; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_physical_plan::unbounded_output; /// For a given `plan`, this object carries the information one needs from its @@ -41,159 +40,157 @@ use datafusion_physical_plan::unbounded_output; #[derive(Debug, Clone)] pub(crate) struct OrderPreservationContext { pub(crate) plan: Arc, - ordering_onwards: Vec>, + ordering_connection: bool, + children_nodes: Vec, } impl OrderPreservationContext { - /// Creates a "default" order-preservation context. + /// Creates an empty context tree. Each node has `false` connections. pub fn new(plan: Arc) -> Self { - let length = plan.children().len(); - OrderPreservationContext { + let children = plan.children(); + Self { plan, - ordering_onwards: vec![None; length], + ordering_connection: false, + children_nodes: children.into_iter().map(Self::new).collect(), } } /// Creates a new order-preservation context from those of children nodes. - pub fn new_from_children_nodes( - children_nodes: Vec, - parent_plan: Arc, - ) -> Result { - let children_plans = children_nodes - .iter() - .map(|item| item.plan.clone()) - .collect(); - let ordering_onwards = children_nodes - .into_iter() - .enumerate() - .map(|(idx, item)| { - // `ordering_onwards` tree keeps track of executors that maintain - // ordering, (or that can maintain ordering with the replacement of - // its variant) - let plan = item.plan; - let ordering_onwards = item.ordering_onwards; - if plan.children().is_empty() { - // Plan has no children, there is nothing to propagate. - None - } else if ordering_onwards[0].is_none() - && ((is_repartition(&plan) && !plan.maintains_input_order()[0]) - || (is_coalesce_partitions(&plan) - && plan.children()[0].output_ordering().is_some())) - { - Some(ExecTree::new(plan, idx, vec![])) - } else { - let children = ordering_onwards - .into_iter() - .flatten() - .filter(|item| { - // Only consider operators that maintains ordering - plan.maintains_input_order()[item.idx] - || is_coalesce_partitions(&plan) - || is_repartition(&plan) - }) - .collect::>(); - if children.is_empty() { - None - } else { - Some(ExecTree::new(plan, idx, children)) - } - } - }) - .collect(); - let plan = with_new_children_if_necessary(parent_plan, children_plans)?.into(); - Ok(OrderPreservationContext { - plan, - ordering_onwards, - }) - } + pub fn update_children(mut self) -> Result { + for node in self.children_nodes.iter_mut() { + let plan = node.plan.clone(); + let children = plan.children(); + let maintains_input_order = plan.maintains_input_order(); + let inspect_child = |idx| { + maintains_input_order[idx] + || is_coalesce_partitions(&plan) + || is_repartition(&plan) + }; + + // We cut the path towards nodes that do not maintain ordering. + for (idx, c) in node.children_nodes.iter_mut().enumerate() { + c.ordering_connection &= inspect_child(idx); + } + + node.ordering_connection = if children.is_empty() { + false + } else if !node.children_nodes[0].ordering_connection + && ((is_repartition(&plan) && !maintains_input_order[0]) + || (is_coalesce_partitions(&plan) + && children[0].output_ordering().is_some())) + { + // We either have a RepartitionExec or a CoalescePartitionsExec + // and they lose their input ordering, so initiate connection: + true + } else { + // Maintain connection if there is a child with a connection, + // and operator can possibly maintain that connection (either + // in its current form or when we replace it with the corresponding + // order preserving operator). + node.children_nodes + .iter() + .enumerate() + .any(|(idx, c)| c.ordering_connection && inspect_child(idx)) + } + } - /// Computes order-preservation contexts for every child of the plan. - pub fn children(&self) -> Vec { - self.plan - .children() - .into_iter() - .map(|child| OrderPreservationContext::new(child)) - .collect() + self.plan = with_new_children_if_necessary( + self.plan, + self.children_nodes.iter().map(|c| c.plan.clone()).collect(), + )? + .into(); + self.ordering_connection = false; + Ok(self) } } impl TreeNode for OrderPreservationContext { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in self.children() { - match op(&child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.children_nodes.iter().map(Cow::Borrowed).collect() } - fn map_children(self, transform: F) -> Result + fn map_children(mut self, transform: F) -> Result where F: FnMut(Self) -> Result, { - let children = self.children(); - if children.is_empty() { - Ok(self) - } else { - let children_nodes = children + if !self.children_nodes.is_empty() { + self.children_nodes = self + .children_nodes .into_iter() .map(transform) - .collect::>>()?; - OrderPreservationContext::new_from_children_nodes(children_nodes, self.plan) + .collect::>()?; + self.plan = with_new_children_if_necessary( + self.plan, + self.children_nodes.iter().map(|c| c.plan.clone()).collect(), + )? + .into(); } + Ok(self) } } -/// Calculates the updated plan by replacing executors that lose ordering -/// inside the `ExecTree` with their order-preserving variants. This will +/// Calculates the updated plan by replacing operators that lose ordering +/// inside `sort_input` with their order-preserving variants. This will /// generate an alternative plan, which will be accepted or rejected later on /// depending on whether it helps us remove a `SortExec`. fn get_updated_plan( - exec_tree: &ExecTree, + mut sort_input: OrderPreservationContext, // Flag indicating that it is desirable to replace `RepartitionExec`s with // `SortPreservingRepartitionExec`s: is_spr_better: bool, // Flag indicating that it is desirable to replace `CoalescePartitionsExec`s // with `SortPreservingMergeExec`s: is_spm_better: bool, -) -> Result> { - let plan = exec_tree.plan.clone(); +) -> Result { + let updated_children = sort_input + .children_nodes + .clone() + .into_iter() + .map(|item| { + // Update children and their descendants in the given tree if the connection is open: + if item.ordering_connection { + get_updated_plan(item, is_spr_better, is_spm_better) + } else { + Ok(item) + } + }) + .collect::>>()?; - let mut children = plan.children(); - // Update children and their descendants in the given tree: - for item in &exec_tree.children { - children[item.idx] = get_updated_plan(item, is_spr_better, is_spm_better)?; - } - // Construct the plan with updated children: - let mut plan = plan.with_new_children(children)?; + sort_input.plan = sort_input + .plan + .with_new_children(updated_children.iter().map(|c| c.plan.clone()).collect())?; + sort_input.ordering_connection = false; + sort_input.children_nodes = updated_children; // When a `RepartitionExec` doesn't preserve ordering, replace it with - // a `SortPreservingRepartitionExec` if appropriate: - if is_repartition(&plan) && !plan.maintains_input_order()[0] && is_spr_better { - let child = plan.children()[0].clone(); - plan = Arc::new( - RepartitionExec::try_new(child, plan.output_partitioning())? - .with_preserve_order(true), - ) as _ - } - // When the input of a `CoalescePartitionsExec` has an ordering, replace it - // with a `SortPreservingMergeExec` if appropriate: - if is_coalesce_partitions(&plan) - && plan.children()[0].output_ordering().is_some() - && is_spm_better + // a sort-preserving variant if appropriate: + if is_repartition(&sort_input.plan) + && !sort_input.plan.maintains_input_order()[0] + && is_spr_better { - let child = plan.children()[0].clone(); - plan = Arc::new(SortPreservingMergeExec::new( - child.output_ordering().unwrap_or(&[]).to_vec(), - child, - )) as _ + let child = sort_input.plan.children().swap_remove(0); + let repartition = + RepartitionExec::try_new(child, sort_input.plan.output_partitioning())? + .with_preserve_order(); + sort_input.plan = Arc::new(repartition) as _; + sort_input.children_nodes[0].ordering_connection = true; + } else if is_coalesce_partitions(&sort_input.plan) && is_spm_better { + // When the input of a `CoalescePartitionsExec` has an ordering, replace it + // with a `SortPreservingMergeExec` if appropriate: + if let Some(ordering) = sort_input.children_nodes[0] + .plan + .output_ordering() + .map(|o| o.to_vec()) + { + // Now we can mutate `new_node.children_nodes` safely + let child = sort_input.children_nodes.clone().swap_remove(0); + sort_input.plan = + Arc::new(SortPreservingMergeExec::new(ordering, child.plan)) as _; + sort_input.children_nodes[0].ordering_connection = true; + } } - Ok(plan) + + Ok(sort_input) } /// The `replace_with_order_preserving_variants` optimizer sub-rule tries to @@ -211,11 +208,11 @@ fn get_updated_plan( /// /// The algorithm flow is simply like this: /// 1. Visit nodes of the physical plan bottom-up and look for `SortExec` nodes. -/// 1_1. During the traversal, build an `ExecTree` to keep track of operators -/// that maintain ordering (or can maintain ordering when replaced by an -/// order-preserving variant) until a `SortExec` is found. +/// 1_1. During the traversal, keep track of operators that maintain ordering +/// (or can maintain ordering when replaced by an order-preserving variant) until +/// a `SortExec` is found. /// 2. When a `SortExec` is found, update the child of the `SortExec` by replacing -/// operators that do not preserve ordering in the `ExecTree` with their order +/// operators that do not preserve ordering in the tree with their order /// preserving variants. /// 3. Check if the `SortExec` is still necessary in the updated plan by comparing /// its input ordering with the output ordering it imposes. We do this because @@ -239,87 +236,148 @@ pub(crate) fn replace_with_order_preserving_variants( is_spm_better: bool, config: &ConfigOptions, ) -> Result> { - let plan = &requirements.plan; - let ordering_onwards = &requirements.ordering_onwards; - if is_sort(plan) { - let exec_tree = if let Some(exec_tree) = &ordering_onwards[0] { - exec_tree - } else { - return Ok(Transformed::No(requirements)); - }; - // For unbounded cases, replace with the order-preserving variant in - // any case, as doing so helps fix the pipeline. - // Also do the replacement if opted-in via config options. - let use_order_preserving_variant = - config.optimizer.bounded_order_preserving_variants || unbounded_output(plan); - let updated_sort_input = get_updated_plan( - exec_tree, - is_spr_better || use_order_preserving_variant, - is_spm_better || use_order_preserving_variant, - )?; - // If this sort is unnecessary, we should remove it and update the plan: - if ordering_satisfy( - updated_sort_input.output_ordering(), - plan.output_ordering(), - || updated_sort_input.equivalence_properties(), - || updated_sort_input.ordering_equivalence_properties(), - ) { - return Ok(Transformed::Yes(OrderPreservationContext { - plan: updated_sort_input, - ordering_onwards: vec![None], - })); - } + let mut requirements = requirements.update_children()?; + if !(is_sort(&requirements.plan) + && requirements.children_nodes[0].ordering_connection) + { + return Ok(Transformed::No(requirements)); } - Ok(Transformed::No(requirements)) + // For unbounded cases, replace with the order-preserving variant in + // any case, as doing so helps fix the pipeline. + // Also do the replacement if opted-in via config options. + let use_order_preserving_variant = + config.optimizer.prefer_existing_sort || unbounded_output(&requirements.plan); + + let mut updated_sort_input = get_updated_plan( + requirements.children_nodes.clone().swap_remove(0), + is_spr_better || use_order_preserving_variant, + is_spm_better || use_order_preserving_variant, + )?; + + // If this sort is unnecessary, we should remove it and update the plan: + if updated_sort_input + .plan + .equivalence_properties() + .ordering_satisfy(requirements.plan.output_ordering().unwrap_or(&[])) + { + for child in updated_sort_input.children_nodes.iter_mut() { + child.ordering_connection = false; + } + Ok(Transformed::Yes(updated_sort_input)) + } else { + for child in requirements.children_nodes.iter_mut() { + child.ordering_connection = false; + } + Ok(Transformed::Yes(requirements)) + } } #[cfg(test)] mod tests { use super::*; - use crate::prelude::SessionConfig; - use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::datasource::listing::PartitionedFile; use crate::datasource::physical_plan::{CsvExec, FileScanConfig}; use crate::physical_plan::coalesce_batches::CoalesceBatchesExec; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; - use crate::physical_plan::filter::FilterExec; use crate::physical_plan::joins::{HashJoinExec, PartitionMode}; use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; - use crate::physical_plan::{displayable, Partitioning}; + use crate::physical_plan::{displayable, get_plan_string, Partitioning}; + use crate::prelude::SessionConfig; + use crate::test::TestStreamPartition; + use arrow::compute::SortOptions; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::tree_node::TreeNode; use datafusion_common::{Result, Statistics}; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_expr::{JoinType, Operator}; use datafusion_physical_expr::expressions::{self, col, Column}; use datafusion_physical_expr::PhysicalSortExpr; + use datafusion_physical_plan::streaming::StreamingTableExec; - use arrow::compute::SortOptions; - use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use rstest::rstest; - /// Runs the `replace_with_order_preserving_variants` sub-rule and asserts the plan - /// against the original and expected plans. + /// Runs the `replace_with_order_preserving_variants` sub-rule and asserts + /// the plan against the original and expected plans for both bounded and + /// unbounded cases. /// - /// `$EXPECTED_PLAN_LINES`: input plan - /// `$EXPECTED_OPTIMIZED_PLAN_LINES`: optimized plan - /// `$PLAN`: the plan to optimized - /// `$ALLOW_BOUNDED`: whether to allow the plan to be optimized for bounded cases - macro_rules! assert_optimized { - ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr) => { + /// # Parameters + /// + /// * `EXPECTED_UNBOUNDED_PLAN_LINES`: Expected input unbounded plan. + /// * `EXPECTED_BOUNDED_PLAN_LINES`: Expected input bounded plan. + /// * `EXPECTED_UNBOUNDED_OPTIMIZED_PLAN_LINES`: Optimized plan, which is + /// the same regardless of the value of the `prefer_existing_sort` flag. + /// * `EXPECTED_BOUNDED_OPTIMIZED_PLAN_LINES`: Optimized plan when the flag + /// `prefer_existing_sort` is `false` for bounded cases. + /// * `EXPECTED_BOUNDED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES`: Optimized plan + /// when the flag `prefer_existing_sort` is `true` for bounded cases. + /// * `$PLAN`: The plan to optimize. + /// * `$SOURCE_UNBOUNDED`: Whether the given plan contains an unbounded source. + macro_rules! assert_optimized_in_all_boundedness_situations { + ($EXPECTED_UNBOUNDED_PLAN_LINES: expr, $EXPECTED_BOUNDED_PLAN_LINES: expr, $EXPECTED_UNBOUNDED_OPTIMIZED_PLAN_LINES: expr, $EXPECTED_BOUNDED_OPTIMIZED_PLAN_LINES: expr, $EXPECTED_BOUNDED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr, $SOURCE_UNBOUNDED: expr) => { + if $SOURCE_UNBOUNDED { + assert_optimized_prefer_sort_on_off!( + $EXPECTED_UNBOUNDED_PLAN_LINES, + $EXPECTED_UNBOUNDED_OPTIMIZED_PLAN_LINES, + $EXPECTED_UNBOUNDED_OPTIMIZED_PLAN_LINES, + $PLAN + ); + } else { + assert_optimized_prefer_sort_on_off!( + $EXPECTED_BOUNDED_PLAN_LINES, + $EXPECTED_BOUNDED_OPTIMIZED_PLAN_LINES, + $EXPECTED_BOUNDED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES, + $PLAN + ); + } + }; + } + + /// Runs the `replace_with_order_preserving_variants` sub-rule and asserts + /// the plan against the original and expected plans. + /// + /// # Parameters + /// + /// * `$EXPECTED_PLAN_LINES`: Expected input plan. + /// * `EXPECTED_OPTIMIZED_PLAN_LINES`: Optimized plan when the flag + /// `prefer_existing_sort` is `false`. + /// * `EXPECTED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES`: Optimized plan when + /// the flag `prefer_existing_sort` is `true`. + /// * `$PLAN`: The plan to optimize. + macro_rules! assert_optimized_prefer_sort_on_off { + ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $EXPECTED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr) => { assert_optimized!( $EXPECTED_PLAN_LINES, $EXPECTED_OPTIMIZED_PLAN_LINES, - $PLAN, + $PLAN.clone(), false ); + assert_optimized!( + $EXPECTED_PLAN_LINES, + $EXPECTED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES, + $PLAN, + true + ); }; - ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr, $ALLOW_BOUNDED: expr) => { + } + + /// Runs the `replace_with_order_preserving_variants` sub-rule and asserts + /// the plan against the original and expected plans. + /// + /// # Parameters + /// + /// * `$EXPECTED_PLAN_LINES`: Expected input plan. + /// * `$EXPECTED_OPTIMIZED_PLAN_LINES`: Expected optimized plan. + /// * `$PLAN`: The plan to optimize. + /// * `$PREFER_EXISTING_SORT`: Value of the `prefer_existing_sort` flag. + macro_rules! assert_optimized { + ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr, $PREFER_EXISTING_SORT: expr) => { let physical_plan = $PLAN; let formatted = displayable(physical_plan.as_ref()).indent(true).to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); @@ -335,8 +393,7 @@ mod tests { let expected_optimized_lines: Vec<&str> = $EXPECTED_OPTIMIZED_PLAN_LINES.iter().map(|s| *s).collect(); // Run the rule top-down - // let optimized_physical_plan = physical_plan.transform_down(&replace_repartition_execs)?; - let config = SessionConfig::new().with_bounded_order_preserving_variants($ALLOW_BOUNDED); + let config = SessionConfig::new().with_prefer_existing_sort($PREFER_EXISTING_SORT); let plan_with_pipeline_fixer = OrderPreservationContext::new(physical_plan); let parallel = plan_with_pipeline_fixer.transform_up(&|plan_with_pipeline_fixer| replace_with_order_preserving_variants(plan_with_pipeline_fixer, false, false, config.options()))?; let optimized_physical_plan = parallel.plan; @@ -350,147 +407,351 @@ mod tests { }; } + #[rstest] #[tokio::test] // Searches for a simple sort and a repartition just after it, the second repartition with 1 input partition should not be affected - async fn test_replace_multiple_input_repartition_1() -> Result<()> { + async fn test_replace_multiple_input_repartition_1( + #[values(false, true)] source_unbounded: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition = repartition_exec_hash(repartition_exec_round_robin(source)); let sort = sort_exec(vec![sort_expr("a", &schema)], repartition, true); let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - let expected_input = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " SortExec: expr=[a@0 ASC NULLS LAST]", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", + ]; + let expected_input_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", + ]; + + // Expected bounded results with and without flag + let expected_optimized_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, + physical_plan, + source_unbounded + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_with_inter_children_change_only() -> Result<()> { + async fn test_with_inter_children_change_only( + #[values(false, true)] source_unbounded: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr_default("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = coalesce_partitions_exec(repartition_hash); let sort = sort_exec( - vec![sort_expr_default("a", &schema)], + vec![sort_expr_default("a", &coalesce_partitions.schema())], coalesce_partitions, false, ); let repartition_rr2 = repartition_exec_round_robin(sort); let repartition_hash2 = repartition_exec_hash(repartition_rr2); - let filter = filter_exec(repartition_hash2, &schema); - let sort2 = sort_exec(vec![sort_expr_default("a", &schema)], filter, true); + let filter = filter_exec(repartition_hash2); + let sort2 = + sort_exec(vec![sort_expr_default("a", &filter.schema())], filter, true); - let physical_plan = - sort_preserving_merge_exec(vec![sort_expr_default("a", &schema)], sort2); + let physical_plan = sort_preserving_merge_exec( + vec![sort_expr_default("a", &sort2.schema())], + sort2, + ); - let expected_input = [ + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ + "SortPreservingMergeExec: [a@0 ASC]", + " SortExec: expr=[a@0 ASC]", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " SortExec: expr=[a@0 ASC]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC]", + ]; + let expected_input_bounded = [ "SortPreservingMergeExec: [a@0 ASC]", " SortExec: expr=[a@0 ASC]", - " FilterExec: c@2 > 3", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " SortExec: expr=[a@0 ASC]", " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC], has_header=true", ]; - let expected_optimized = [ + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ + "SortPreservingMergeExec: [a@0 ASC]", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " SortPreservingMergeExec: [a@0 ASC]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC]", + ]; + + // Expected bounded results with and without flag + let expected_optimized_bounded = [ + "SortPreservingMergeExec: [a@0 ASC]", + " SortExec: expr=[a@0 ASC]", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " SortExec: expr=[a@0 ASC]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = [ "SortPreservingMergeExec: [a@0 ASC]", - " FilterExec: c@2 > 3", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " SortPreservingMergeExec: [a@0 ASC]", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC], has_header=true", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, + physical_plan, + source_unbounded + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_replace_multiple_input_repartition_2() -> Result<()> { + async fn test_replace_multiple_input_repartition_2( + #[values(false, true)] source_unbounded: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition_rr = repartition_exec_round_robin(source); - let filter = filter_exec(repartition_rr, &schema); + let filter = filter_exec(repartition_rr); let repartition_hash = repartition_exec_hash(filter); let sort = sort_exec(vec![sort_expr("a", &schema)], repartition_hash, true); let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - let expected_input = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", + ]; + let expected_input_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " SortExec: expr=[a@0 ASC NULLS LAST]", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", - " FilterExec: c@2 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " FilterExec: c@1 > 3", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", - " FilterExec: c@2 > 3", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", + " FilterExec: c@1 > 3", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", + ]; + + // Expected bounded results with and without flag + let expected_optimized_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, + physical_plan, + source_unbounded + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_replace_multiple_input_repartition_with_extra_steps() -> Result<()> { + async fn test_replace_multiple_input_repartition_with_extra_steps( + #[values(false, true)] source_unbounded: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); - let filter = filter_exec(repartition_hash, &schema); + let filter = filter_exec(repartition_hash); let coalesce_batches_exec: Arc = coalesce_batches_exec(filter); let sort = sort_exec(vec![sort_expr("a", &schema)], coalesce_batches_exec, true); let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - let expected_input = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", + ]; + let expected_input_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " SortExec: expr=[a@0 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@2 > 3", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@2 > 3", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", + ]; + + // Expected bounded results with and without flag + let expected_optimized_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, + physical_plan, + source_unbounded + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_replace_multiple_input_repartition_with_extra_steps_2() -> Result<()> { + async fn test_replace_multiple_input_repartition_with_extra_steps_2( + #[values(false, true)] source_unbounded: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition_rr = repartition_exec_round_robin(source); let coalesce_batches_exec_1 = coalesce_batches_exec(repartition_rr); let repartition_hash = repartition_exec_hash(coalesce_batches_exec_1); - let filter = filter_exec(repartition_hash, &schema); + let filter = filter_exec(repartition_hash); let coalesce_batches_exec_2 = coalesce_batches_exec(filter); let sort = sort_exec(vec![sort_expr("a", &schema)], coalesce_batches_exec_2, true); @@ -498,62 +759,157 @@ mod tests { let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - let expected_input = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " SortExec: expr=[a@0 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@2 > 3", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " CoalesceBatchesExec: target_batch_size=8192", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", + ]; + let expected_input_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " CoalesceBatchesExec: target_batch_size=8192", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@2 > 3", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " CoalesceBatchesExec: target_batch_size=8192", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", + ]; + + // Expected bounded results with and without flag + let expected_optimized_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " CoalesceBatchesExec: target_batch_size=8192", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", + " CoalesceBatchesExec: target_batch_size=8192", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, + physical_plan, + source_unbounded + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_not_replacing_when_no_need_to_preserve_sorting() -> Result<()> { + async fn test_not_replacing_when_no_need_to_preserve_sorting( + #[values(false, true)] source_unbounded: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); - let filter = filter_exec(repartition_hash, &schema); + let filter = filter_exec(repartition_hash); let coalesce_batches_exec: Arc = coalesce_batches_exec(filter); let physical_plan: Arc = coalesce_partitions_exec(coalesce_batches_exec); - let expected_input = ["CoalescePartitionsExec", + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ + "CoalescePartitionsExec", " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@2 > 3", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - let expected_optimized = ["CoalescePartitionsExec", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", + ]; + let expected_input_bounded = [ + "CoalescePartitionsExec", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ + "CoalescePartitionsExec", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", + ]; + + // Expected bounded results same with and without flag, because there is no executor with ordering requirement + let expected_optimized_bounded = [ + "CoalescePartitionsExec", " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@2 > 3", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = expected_optimized_bounded; + + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, + physical_plan, + source_unbounded + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_with_multiple_replacable_repartitions() -> Result<()> { + async fn test_with_multiple_replacable_repartitions( + #[values(false, true)] source_unbounded: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); - let filter = filter_exec(repartition_hash, &schema); + let filter = filter_exec(repartition_hash); let coalesce_batches = coalesce_batches_exec(filter); let repartition_hash_2 = repartition_exec_hash(coalesce_batches); let sort = sort_exec(vec![sort_expr("a", &schema)], repartition_hash_2, true); @@ -561,141 +917,341 @@ mod tests { let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - let expected_input = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", + ]; + let expected_input_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", + ]; + + // Expected bounded results with and without flag + let expected_optimized_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " SortExec: expr=[a@0 ASC NULLS LAST]", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@2 > 3", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@2 > 3", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, + physical_plan, + source_unbounded + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_not_replace_with_different_orderings() -> Result<()> { + async fn test_not_replace_with_different_orderings( + #[values(false, true)] source_unbounded: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let sort = sort_exec( - vec![sort_expr_default("c", &schema)], + vec![sort_expr_default("c", &repartition_hash.schema())], repartition_hash, true, ); - let physical_plan = - sort_preserving_merge_exec(vec![sort_expr_default("c", &schema)], sort); + let physical_plan = sort_preserving_merge_exec( + vec![sort_expr_default("c", &sort.schema())], + sort, + ); + + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ + "SortPreservingMergeExec: [c@1 ASC]", + " SortExec: expr=[c@1 ASC]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", + ]; + let expected_input_bounded = [ + "SortPreservingMergeExec: [c@1 ASC]", + " SortExec: expr=[c@1 ASC]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; - let expected_input = ["SortPreservingMergeExec: [c@2 ASC]", - " SortExec: expr=[c@2 ASC]", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ + "SortPreservingMergeExec: [c@1 ASC]", + " SortExec: expr=[c@1 ASC]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - let expected_optimized = ["SortPreservingMergeExec: [c@2 ASC]", - " SortExec: expr=[c@2 ASC]", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", + ]; + + // Expected bounded results same with and without flag, because ordering requirement of the executor is different than the existing ordering. + let expected_optimized_bounded = [ + "SortPreservingMergeExec: [c@1 ASC]", + " SortExec: expr=[c@1 ASC]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = expected_optimized_bounded; + + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, + physical_plan, + source_unbounded + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_with_lost_ordering() -> Result<()> { + async fn test_with_lost_ordering( + #[values(false, true)] source_unbounded: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = coalesce_partitions_exec(repartition_hash); let physical_plan = sort_exec(vec![sort_expr("a", &schema)], coalesce_partitions, false); - let expected_input = ["SortExec: expr=[a@0 ASC NULLS LAST]", + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ + "SortExec: expr=[a@0 ASC NULLS LAST]", " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", + ]; + let expected_input_bounded = [ + "SortExec: expr=[a@0 ASC NULLS LAST]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", + ]; + + // Expected bounded results with and without flag + let expected_optimized_bounded = [ + "SortExec: expr=[a@0 ASC NULLS LAST]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, + physical_plan, + source_unbounded + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_with_lost_and_kept_ordering() -> Result<()> { + async fn test_with_lost_and_kept_ordering( + #[values(false, true)] source_unbounded: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = coalesce_partitions_exec(repartition_hash); let sort = sort_exec( - vec![sort_expr_default("c", &schema)], + vec![sort_expr_default("c", &coalesce_partitions.schema())], coalesce_partitions, false, ); let repartition_rr2 = repartition_exec_round_robin(sort); let repartition_hash2 = repartition_exec_hash(repartition_rr2); - let filter = filter_exec(repartition_hash2, &schema); - let sort2 = sort_exec(vec![sort_expr_default("c", &schema)], filter, true); + let filter = filter_exec(repartition_hash2); + let sort2 = + sort_exec(vec![sort_expr_default("c", &filter.schema())], filter, true); - let physical_plan = - sort_preserving_merge_exec(vec![sort_expr_default("c", &schema)], sort2); + let physical_plan = sort_preserving_merge_exec( + vec![sort_expr_default("c", &sort2.schema())], + sort2, + ); - let expected_input = [ - "SortPreservingMergeExec: [c@2 ASC]", - " SortExec: expr=[c@2 ASC]", - " FilterExec: c@2 > 3", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ + "SortPreservingMergeExec: [c@1 ASC]", + " SortExec: expr=[c@1 ASC]", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " SortExec: expr=[c@2 ASC]", + " SortExec: expr=[c@1 ASC]", " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", + ]; + let expected_input_bounded = [ + "SortPreservingMergeExec: [c@1 ASC]", + " SortExec: expr=[c@1 ASC]", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " SortExec: expr=[c@1 ASC]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; - let expected_optimized = [ - "SortPreservingMergeExec: [c@2 ASC]", - " FilterExec: c@2 > 3", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ + "SortPreservingMergeExec: [c@1 ASC]", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=c@1 ASC", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " SortExec: expr=[c@1 ASC]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", + ]; + + // Expected bounded results with and without flag + let expected_optimized_bounded = [ + "SortPreservingMergeExec: [c@1 ASC]", + " SortExec: expr=[c@1 ASC]", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " SortExec: expr=[c@1 ASC]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = [ + "SortPreservingMergeExec: [c@1 ASC]", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=c@1 ASC", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " SortExec: expr=[c@2 ASC]", + " SortExec: expr=[c@1 ASC]", " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, + physical_plan, + source_unbounded + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_with_multiple_child_trees() -> Result<()> { + async fn test_with_multiple_child_trees( + #[values(false, true)] source_unbounded: bool, + ) -> Result<()> { let schema = create_test_schema()?; let left_sort_exprs = vec![sort_expr("a", &schema)]; - let left_source = csv_exec_sorted(&schema, left_sort_exprs, true); + let left_source = if source_unbounded { + stream_exec_ordered(&schema, left_sort_exprs) + } else { + csv_exec_sorted(&schema, left_sort_exprs) + }; let left_repartition_rr = repartition_exec_round_robin(left_source); let left_repartition_hash = repartition_exec_hash(left_repartition_rr); let left_coalesce_partitions = Arc::new(CoalesceBatchesExec::new(left_repartition_hash, 4096)); let right_sort_exprs = vec![sort_expr("a", &schema)]; - let right_source = csv_exec_sorted(&schema, right_sort_exprs, true); + let right_source = if source_unbounded { + stream_exec_ordered(&schema, right_sort_exprs) + } else { + csv_exec_sorted(&schema, right_sort_exprs) + }; let right_repartition_rr = repartition_exec_round_robin(right_source); let right_repartition_hash = repartition_exec_hash(right_repartition_rr); let right_coalesce_partitions = @@ -703,63 +1259,86 @@ mod tests { let hash_join_exec = hash_join_exec(left_coalesce_partitions, right_coalesce_partitions); - let sort = sort_exec(vec![sort_expr_default("a", &schema)], hash_join_exec, true); + let sort = sort_exec( + vec![sort_expr_default("a", &hash_join_exec.schema())], + hash_join_exec, + true, + ); - let physical_plan = - sort_preserving_merge_exec(vec![sort_expr_default("a", &schema)], sort); + let physical_plan = sort_preserving_merge_exec( + vec![sort_expr_default("a", &sort.schema())], + sort, + ); - let expected_input = [ + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ "SortPreservingMergeExec: [a@0 ASC]", " SortExec: expr=[a@0 ASC]", " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)]", " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - - let expected_optimized = [ + let expected_input_bounded = [ "SortPreservingMergeExec: [a@0 ASC]", " SortExec: expr=[a@0 ASC]", " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)]", " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); - Ok(()) - } - #[tokio::test] - async fn test_with_bounded_input() -> Result<()> { - let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, false); - let repartition = repartition_exec_hash(repartition_exec_round_robin(source)); - let sort = sort_exec(vec![sort_expr("a", &schema)], repartition, true); - - let physical_plan = - sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ + "SortPreservingMergeExec: [a@0 ASC]", + " SortExec: expr=[a@0 ASC]", + " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)]", + " CoalesceBatchesExec: target_batch_size=4096", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", + " CoalesceBatchesExec: target_batch_size=4096", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", + ]; - let expected_input = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST]", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + // Expected bounded results same with and without flag, because ordering get lost during intermediate executor anyway. Hence no need to preserve + // existing ordering. + let expected_optimized_bounded = [ + "SortPreservingMergeExec: [a@0 ASC]", + " SortExec: expr=[a@0 ASC]", + " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)]", + " CoalesceBatchesExec: target_batch_size=4096", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CoalesceBatchesExec: target_batch_size=4096", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = expected_optimized_bounded; + + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, + physical_plan, + source_unbounded + ); Ok(()) } @@ -819,24 +1398,23 @@ mod tests { } fn repartition_exec_hash(input: Arc) -> Arc { + let input_schema = input.schema(); Arc::new( RepartitionExec::try_new( input, - Partitioning::Hash(vec![Arc::new(Column::new("c1", 0))], 8), + Partitioning::Hash(vec![col("c", &input_schema).unwrap()], 8), ) .unwrap(), ) } - fn filter_exec( - input: Arc, - schema: &SchemaRef, - ) -> Arc { + fn filter_exec(input: Arc) -> Arc { + let input_schema = input.schema(); let predicate = expressions::binary( - col("c", schema).unwrap(), + col("c", &input_schema).unwrap(), Operator::Gt, expressions::lit(3i32), - schema, + &input_schema, ) .unwrap(); Arc::new(FilterExec::try_new(predicate, input).unwrap()) @@ -854,11 +1432,15 @@ mod tests { left: Arc, right: Arc, ) -> Arc { + let left_on = col("c", &left.schema()).unwrap(); + let right_on = col("c", &right.schema()).unwrap(); + let left_col = left_on.as_any().downcast_ref::().unwrap(); + let right_col = right_on.as_any().downcast_ref::().unwrap(); Arc::new( HashJoinExec::try_new( left, right, - vec![(Column::new("c", 1), Column::new("c", 1))], + vec![(left_col.clone(), right_col.clone())], None, &JoinType::Inner, PartitionMode::Partitioned, @@ -878,12 +1460,33 @@ mod tests { Ok(schema) } + // creates a stream exec source for the test purposes + fn stream_exec_ordered( + schema: &SchemaRef, + sort_exprs: impl IntoIterator, + ) -> Arc { + let sort_exprs = sort_exprs.into_iter().collect(); + let projection: Vec = vec![0, 2, 3]; + + Arc::new( + StreamingTableExec::try_new( + schema.clone(), + vec![Arc::new(TestStreamPartition { + schema: schema.clone(), + }) as _], + Some(&projection), + vec![sort_exprs], + true, + ) + .unwrap(), + ) + } + // creates a csv exec source for the test purposes // projection and has_header parameters are given static due to testing needs fn csv_exec_sorted( schema: &SchemaRef, sort_exprs: impl IntoIterator, - infinite_source: bool, ) -> Arc { let sort_exprs = sort_exprs.into_iter().collect(); let projection: Vec = vec![0, 2, 3]; @@ -896,12 +1499,11 @@ mod tests { "file_path".to_string(), 100, )]], - statistics: Statistics::default(), + statistics: Statistics::new_unknown(schema), projection: Some(projection), limit: None, table_partition_cols: vec![], output_ordering: vec![sort_exprs], - infinite_source, }, true, 0, @@ -910,11 +1512,4 @@ mod tests { FileCompressionType::UNCOMPRESSED, )) } - - // Util function to get string representation of a physical plan - fn get_plan_string(plan: &Arc) -> Vec { - let formatted = displayable(plan.as_ref()).indent(true).to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - actual.iter().map(|elem| elem.to_string()).collect() - } } diff --git a/datafusion/core/src/physical_optimizer/sort_pushdown.rs b/datafusion/core/src/physical_optimizer/sort_pushdown.rs index 629011cb0faa..f0a8c8cfd3cb 100644 --- a/datafusion/core/src/physical_optimizer/sort_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/sort_pushdown.rs @@ -15,29 +15,27 @@ // specific language governing permissions and limitations // under the License. +use std::borrow::Cow; use std::sync::Arc; use crate::physical_optimizer::utils::{ add_sort_above, is_limit, is_sort_preserving_merge, is_union, is_window, }; use crate::physical_plan::filter::FilterExec; -use crate::physical_plan::joins::utils::{calculate_join_output_ordering, JoinSide}; +use crate::physical_plan::joins::utils::calculate_join_output_ordering; use crate::physical_plan::joins::{HashJoinExec, SortMergeJoinExec}; use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; -use datafusion_common::{plan_err, DataFusionError, Result}; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{plan_err, DataFusionError, JoinSide, Result}; use datafusion_expr::JoinType; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::utils::{ - ordering_satisfy, ordering_satisfy_requirement, requirements_compatible, +use datafusion_physical_expr::{ + LexRequirementRef, PhysicalSortExpr, PhysicalSortRequirement, }; -use datafusion_physical_expr::{PhysicalSortExpr, PhysicalSortRequirement}; - -use itertools::izip; /// This is a "data class" we use within the [`EnforceSorting`] rule to push /// down [`SortExec`] in the plan. In some cases, we can reduce the total @@ -50,164 +48,142 @@ pub(crate) struct SortPushDown { pub plan: Arc, /// Parent required sort ordering required_ordering: Option>, - /// The adjusted request sort ordering to children. - /// By default they are the same as the plan's required input ordering, but can be adjusted based on parent required sort ordering properties. - adjusted_request_ordering: Vec>>, + children_nodes: Vec, } impl SortPushDown { - pub fn init(plan: Arc) -> Self { - let request_ordering = plan.required_input_ordering(); - SortPushDown { + /// Creates an empty tree with empty `required_ordering`'s. + pub fn new(plan: Arc) -> Self { + let children = plan.children(); + Self { plan, required_ordering: None, - adjusted_request_ordering: request_ordering, + children_nodes: children.into_iter().map(Self::new).collect(), } } - pub fn children(&self) -> Vec { - izip!( - self.plan.children().into_iter(), - self.adjusted_request_ordering.clone().into_iter(), - ) - .map(|(child, from_parent)| { - let child_request_ordering = child.required_input_ordering(); - SortPushDown { - plan: child, - required_ordering: from_parent, - adjusted_request_ordering: child_request_ordering, - } - }) - .collect() + /// Assigns the ordering requirement of the root node to the its children. + pub fn assign_initial_requirements(&mut self) { + let reqs = self.plan.required_input_ordering(); + for (child, requirement) in self.children_nodes.iter_mut().zip(reqs) { + child.required_ordering = requirement; + } } } impl TreeNode for SortPushDown { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - let children = self.children(); - for child in children { - match op(&child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.children_nodes.iter().map(Cow::Borrowed).collect() } fn map_children(mut self, transform: F) -> Result where F: FnMut(Self) -> Result, { - let children = self.children(); - if !children.is_empty() { - let children_plans = children + if !self.children_nodes.is_empty() { + self.children_nodes = self + .children_nodes .into_iter() .map(transform) - .map(|r| r.map(|s| s.plan)) - .collect::>>()?; - - match with_new_children_if_necessary(self.plan, children_plans)? { - Transformed::Yes(plan) | Transformed::No(plan) => { - self.plan = plan; - } - } - }; + .collect::>()?; + self.plan = with_new_children_if_necessary( + self.plan, + self.children_nodes.iter().map(|c| c.plan.clone()).collect(), + )? + .into(); + } Ok(self) } } pub(crate) fn pushdown_sorts( - requirements: SortPushDown, + mut requirements: SortPushDown, ) -> Result> { let plan = &requirements.plan; - let parent_required = requirements.required_ordering.as_deref(); - const ERR_MSG: &str = "Expects parent requirement to contain something"; - let err = || DataFusionError::Plan(ERR_MSG.to_string()); + let parent_required = requirements.required_ordering.as_deref().unwrap_or(&[]); + if let Some(sort_exec) = plan.as_any().downcast_ref::() { - let mut new_plan = plan.clone(); - if !ordering_satisfy_requirement( - plan.output_ordering(), - parent_required, - || plan.equivalence_properties(), - || plan.ordering_equivalence_properties(), - ) { + if !plan + .equivalence_properties() + .ordering_satisfy_requirement(parent_required) + { // If the current plan is a SortExec, modify it to satisfy parent requirements: - let parent_required_expr = PhysicalSortRequirement::to_sort_exprs( - parent_required.ok_or_else(err)?.iter().cloned(), - ); - new_plan = sort_exec.input().clone(); - add_sort_above(&mut new_plan, parent_required_expr, sort_exec.fetch())?; + let mut new_plan = sort_exec.input().clone(); + add_sort_above(&mut new_plan, parent_required, sort_exec.fetch()); + requirements.plan = new_plan; }; - let required_ordering = new_plan + + let required_ordering = requirements + .plan .output_ordering() - .map(PhysicalSortRequirement::from_sort_exprs); + .map(PhysicalSortRequirement::from_sort_exprs) + .unwrap_or_default(); // Since new_plan is a SortExec, we can safely get the 0th index. - let child = &new_plan.children()[0]; + let mut child = requirements.children_nodes.swap_remove(0); if let Some(adjusted) = - pushdown_requirement_to_children(child, required_ordering.as_deref())? + pushdown_requirement_to_children(&child.plan, &required_ordering)? { + for (c, o) in child.children_nodes.iter_mut().zip(adjusted) { + c.required_ordering = o; + } // Can push down requirements - Ok(Transformed::Yes(SortPushDown { - plan: child.clone(), - required_ordering: None, - adjusted_request_ordering: adjusted, - })) + child.required_ordering = None; + Ok(Transformed::Yes(child)) } else { // Can not push down requirements - Ok(Transformed::Yes(SortPushDown::init(new_plan))) + let mut empty_node = SortPushDown::new(requirements.plan); + empty_node.assign_initial_requirements(); + Ok(Transformed::Yes(empty_node)) } } else { // Executors other than SortExec - if ordering_satisfy_requirement( - plan.output_ordering(), - parent_required, - || plan.equivalence_properties(), - || plan.ordering_equivalence_properties(), - ) { + if plan + .equivalence_properties() + .ordering_satisfy_requirement(parent_required) + { // Satisfies parent requirements, immediately return. - return Ok(Transformed::Yes(SortPushDown { - required_ordering: None, - ..requirements - })); + let reqs = requirements.plan.required_input_ordering(); + for (child, order) in requirements.children_nodes.iter_mut().zip(reqs) { + child.required_ordering = order; + } + return Ok(Transformed::Yes(requirements)); } // Can not satisfy the parent requirements, check whether the requirements can be pushed down: if let Some(adjusted) = pushdown_requirement_to_children(plan, parent_required)? { - Ok(Transformed::Yes(SortPushDown { - plan: plan.clone(), - required_ordering: None, - adjusted_request_ordering: adjusted, - })) + for (c, o) in requirements.children_nodes.iter_mut().zip(adjusted) { + c.required_ordering = o; + } + requirements.required_ordering = None; + Ok(Transformed::Yes(requirements)) } else { // Can not push down requirements, add new SortExec: - let parent_required_expr = PhysicalSortRequirement::to_sort_exprs( - parent_required.ok_or_else(err)?.iter().cloned(), - ); - let mut new_plan = plan.clone(); - add_sort_above(&mut new_plan, parent_required_expr, None)?; - Ok(Transformed::Yes(SortPushDown::init(new_plan))) + let mut new_plan = requirements.plan; + add_sort_above(&mut new_plan, parent_required, None); + let mut new_empty = SortPushDown::new(new_plan); + new_empty.assign_initial_requirements(); + // Can not push down requirements + Ok(Transformed::Yes(new_empty)) } } } fn pushdown_requirement_to_children( plan: &Arc, - parent_required: Option<&[PhysicalSortRequirement]>, + parent_required: LexRequirementRef, ) -> Result>>>> { - const ERR_MSG: &str = "Expects parent requirement to contain something"; - let err = || DataFusionError::Plan(ERR_MSG.to_string()); let maintains_input_order = plan.maintains_input_order(); if is_window(plan) { let required_input_ordering = plan.required_input_ordering(); - let request_child = required_input_ordering[0].as_deref(); - let child_plan = plan.children()[0].clone(); + let request_child = required_input_ordering[0].as_deref().unwrap_or(&[]); + let child_plan = plan.children().swap_remove(0); match determine_children_requirement(parent_required, request_child, child_plan) { RequirementsCompatibility::Satisfy => { - Ok(Some(vec![request_child.map(|r| r.to_vec())])) + let req = if request_child.is_empty() { + None + } else { + Some(request_child.to_vec()) + }; + Ok(Some(vec![req])) } RequirementsCompatibility::Compatible(adjusted) => Ok(Some(vec![adjusted])), RequirementsCompatibility::NonCompatible => Ok(None), @@ -215,16 +191,17 @@ fn pushdown_requirement_to_children( } else if is_union(plan) { // UnionExec does not have real sort requirements for its input. Here we change the adjusted_request_ordering to UnionExec's output ordering and // propagate the sort requirements down to correct the unnecessary descendant SortExec under the UnionExec - Ok(Some(vec![ - parent_required.map(|elem| elem.to_vec()); - plan.children().len() - ])) + let req = if parent_required.is_empty() { + None + } else { + Some(parent_required.to_vec()) + }; + Ok(Some(vec![req; plan.children().len()])) } else if let Some(smj) = plan.as_any().downcast_ref::() { // If the current plan is SortMergeJoinExec let left_columns_len = smj.left().schema().fields().len(); - let parent_required_expr = PhysicalSortRequirement::to_sort_exprs( - parent_required.ok_or_else(err)?.iter().cloned(), - ); + let parent_required_expr = + PhysicalSortRequirement::to_sort_exprs(parent_required.iter().cloned()); let expr_source_side = expr_source_sides(&parent_required_expr, smj.join_type(), left_columns_len); match expr_source_side { @@ -238,10 +215,9 @@ fn pushdown_requirement_to_children( let right_offset = smj.schema().fields.len() - smj.right().schema().fields.len(); let new_right_required = - shift_right_required(parent_required.ok_or_else(err)?, right_offset)?; - let new_right_required_expr = PhysicalSortRequirement::to_sort_exprs( - new_right_required.iter().cloned(), - ); + shift_right_required(parent_required, right_offset)?; + let new_right_required_expr = + PhysicalSortRequirement::to_sort_exprs(new_right_required); try_pushdown_requirements_to_join( smj, parent_required, @@ -262,64 +238,71 @@ fn pushdown_requirement_to_children( || plan.as_any().is::() || is_limit(plan) || plan.as_any().is::() - // Do not push-down through SortPreservingMergeExec when - // ordering requirement invalidates requirement of sort preserving merge exec. - || (is_sort_preserving_merge(plan) && !ordering_satisfy( - parent_required - .map(|req| PhysicalSortRequirement::to_sort_exprs(req.to_vec())) - .as_deref(), - plan.output_ordering(), - || plan.equivalence_properties(), - || plan.ordering_equivalence_properties(), - ) - ) { // If the current plan is a leaf node or can not maintain any of the input ordering, can not pushed down requirements. // For RepartitionExec, we always choose to not push down the sort requirements even the RepartitionExec(input_partition=1) could maintain input ordering. // Pushing down is not beneficial Ok(None) + } else if is_sort_preserving_merge(plan) { + let new_ordering = + PhysicalSortRequirement::to_sort_exprs(parent_required.to_vec()); + let mut spm_eqs = plan.equivalence_properties(); + // Sort preserving merge will have new ordering, one requirement above is pushed down to its below. + spm_eqs = spm_eqs.with_reorder(new_ordering); + // Do not push-down through SortPreservingMergeExec when + // ordering requirement invalidates requirement of sort preserving merge exec. + if !spm_eqs.ordering_satisfy(plan.output_ordering().unwrap_or(&[])) { + Ok(None) + } else { + // Can push-down through SortPreservingMergeExec, because parent requirement is finer + // than SortPreservingMergeExec output ordering. + let req = if parent_required.is_empty() { + None + } else { + Some(parent_required.to_vec()) + }; + Ok(Some(vec![req])) + } } else { Ok(Some( maintains_input_order - .iter() + .into_iter() .map(|flag| { - if *flag { - parent_required.map(|elem| elem.to_vec()) + if flag && !parent_required.is_empty() { + Some(parent_required.to_vec()) } else { None } }) - .collect::>(), + .collect(), )) } // TODO: Add support for Projection push down } -/// Determine the children requirements -/// If the children requirements are more specific, do not push down the parent requirements -/// If the the parent requirements are more specific, push down the parent requirements -/// If they are not compatible, need to add Sort. +/// Determine children requirements: +/// - If children requirements are more specific, do not push down parent +/// requirements. +/// - If parent requirements are more specific, push down parent requirements. +/// - If they are not compatible, need to add a sort. fn determine_children_requirement( - parent_required: Option<&[PhysicalSortRequirement]>, - request_child: Option<&[PhysicalSortRequirement]>, + parent_required: LexRequirementRef, + request_child: LexRequirementRef, child_plan: Arc, ) -> RequirementsCompatibility { - if requirements_compatible( - request_child, - parent_required, - || child_plan.ordering_equivalence_properties(), - || child_plan.equivalence_properties(), - ) { - // request child requirements are more specific, no need to push down the parent requirements + if child_plan + .equivalence_properties() + .requirements_compatible(request_child, parent_required) + { + // Child requirements are more specific, no need to push down. RequirementsCompatibility::Satisfy - } else if requirements_compatible( - parent_required, - request_child, - || child_plan.ordering_equivalence_properties(), - || child_plan.equivalence_properties(), - ) { - // parent requirements are more specific, adjust the request child requirements and push down the new requirements - let adjusted = parent_required.map(|r| r.to_vec()); + } else if child_plan + .equivalence_properties() + .requirements_compatible(parent_required, request_child) + { + // Parent requirements are more specific, adjust child's requirements + // and push down the new requirements: + let adjusted = (!parent_required.is_empty()).then(|| parent_required.to_vec()); RequirementsCompatibility::Compatible(adjusted) } else { RequirementsCompatibility::NonCompatible @@ -327,7 +310,7 @@ fn determine_children_requirement( } fn try_pushdown_requirements_to_join( smj: &SortMergeJoinExec, - parent_required: Option<&[PhysicalSortRequirement]>, + parent_required: LexRequirementRef, sort_expr: Vec, push_side: JoinSide, ) -> Result>>>> { @@ -337,32 +320,33 @@ fn try_pushdown_requirements_to_join( JoinSide::Left => (sort_expr.as_slice(), right_ordering), JoinSide::Right => (left_ordering, sort_expr.as_slice()), }; + let join_type = smj.join_type(); + let probe_side = SortMergeJoinExec::probe_side(&join_type); let new_output_ordering = calculate_join_output_ordering( new_left_ordering, new_right_ordering, - smj.join_type(), + join_type, smj.on(), smj.left().schema().fields.len(), &smj.maintains_input_order(), - Some(SortMergeJoinExec::probe_side(&smj.join_type())), - )?; - Ok(ordering_satisfy_requirement( - new_output_ordering.as_deref(), - parent_required, - || smj.equivalence_properties(), - || smj.ordering_equivalence_properties(), - ) - .then(|| { - let required_input_ordering = smj.required_input_ordering(); + Some(probe_side), + ); + let mut smj_eqs = smj.equivalence_properties(); + // smj will have this ordering when its input changes. + smj_eqs = smj_eqs.with_reorder(new_output_ordering.unwrap_or_default()); + let should_pushdown = smj_eqs.ordering_satisfy_requirement(parent_required); + Ok(should_pushdown.then(|| { + let mut required_input_ordering = smj.required_input_ordering(); let new_req = Some(PhysicalSortRequirement::from_sort_exprs(&sort_expr)); match push_side { JoinSide::Left => { - vec![new_req, required_input_ordering[1].clone()] + required_input_ordering[0] = new_req; } JoinSide::Right => { - vec![required_input_ordering[0].clone(), new_req] + required_input_ordering[1] = new_req; } } + required_input_ordering })) } @@ -415,15 +399,13 @@ fn expr_source_sides( } fn shift_right_required( - parent_required: &[PhysicalSortRequirement], + parent_required: LexRequirementRef, left_columns_len: usize, ) -> Result> { let new_right_required: Vec = parent_required .iter() .filter_map(|r| { - let Some(col) = r.expr.as_any().downcast_ref::() else { - return None; - }; + let col = r.expr.as_any().downcast_ref::()?; if col.index() < left_columns_len { return None; diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs index e021cda2c868..debafefe39ab 100644 --- a/datafusion/core/src/physical_optimizer/test_utils.rs +++ b/datafusion/core/src/physical_optimizer/test_utils.rs @@ -35,16 +35,17 @@ use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use crate::physical_plan::union::UnionExec; use crate::physical_plan::windows::create_window_expr; -use crate::physical_plan::{ExecutionPlan, Partitioning}; +use crate::physical_plan::{ExecutionPlan, InputOrderMode, Partitioning}; use crate::prelude::{CsvReadOptions, SessionContext}; use arrow_schema::{Schema, SchemaRef, SortOptions}; use datafusion_common::{JoinType, Statistics}; use datafusion_execution::object_store::ObjectStoreUrl; -use datafusion_expr::{AggregateFunction, WindowFrame, WindowFunction}; +use datafusion_expr::{AggregateFunction, WindowFrame, WindowFunctionDefinition}; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; +use crate::datasource::stream::{StreamConfig, StreamTable}; use async_trait::async_trait; async fn register_current_csv( @@ -54,14 +55,19 @@ async fn register_current_csv( ) -> Result<()> { let testdata = crate::test_util::arrow_test_data(); let schema = crate::test_util::aggr_test_schema(); - ctx.register_csv( - table_name, - &format!("{testdata}/csv/aggregate_test_100.csv"), - CsvReadOptions::new() - .schema(&schema) - .mark_infinite(infinite), - ) - .await?; + let path = format!("{testdata}/csv/aggregate_test_100.csv"); + + match infinite { + true => { + let config = StreamConfig::new_file(schema, path.into()); + ctx.register_table(table_name, Arc::new(StreamTable::new(Arc::new(config))))?; + } + false => { + ctx.register_csv(table_name, &path, CsvReadOptions::new().schema(&schema)) + .await?; + } + } + Ok(()) } @@ -140,14 +146,13 @@ impl QueryCase { async fn run_case(&self, ctx: SessionContext, error: Option<&String>) -> Result<()> { let dataframe = ctx.sql(self.sql.as_str()).await?; let plan = dataframe.create_physical_plan().await; - if error.is_some() { + if let Some(error) = error { let plan_error = plan.unwrap_err(); - let initial = error.unwrap().to_string(); assert!( - plan_error.to_string().contains(initial.as_str()), + plan_error.to_string().contains(error.as_str()), "plan_error: {:?} doesn't contain message: {:?}", plan_error, - initial.as_str() + error.as_str() ); } else { assert!(plan.is_ok()) @@ -229,7 +234,7 @@ pub fn bounded_window_exec( Arc::new( crate::physical_plan::windows::BoundedWindowAggExec::try_new( vec![create_window_expr( - &WindowFunction::AggregateFunction(AggregateFunction::Count), + &WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), "count".to_owned(), &[col(col_name, &schema).unwrap()], &[], @@ -239,9 +244,8 @@ pub fn bounded_window_exec( ) .unwrap()], input.clone(), - input.schema(), vec![], - crate::physical_plan::windows::PartitionSearchMode::Sorted, + InputOrderMode::Sorted, ) .unwrap(), ) @@ -269,12 +273,11 @@ pub fn parquet_exec(schema: &SchemaRef) -> Arc { object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), file_schema: schema.clone(), file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], - statistics: Statistics::default(), + statistics: Statistics::new_unknown(schema), projection: None, limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, None, None, @@ -293,12 +296,11 @@ pub fn parquet_exec_sorted( object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), file_schema: schema.clone(), file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], - statistics: Statistics::default(), + statistics: Statistics::new_unknown(schema), projection: None, limit: None, table_partition_cols: vec![], output_ordering: vec![sort_exprs], - infinite_source: false, }, None, None, @@ -325,6 +327,14 @@ pub fn repartition_exec(input: Arc) -> Arc Arc::new(RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(10)).unwrap()) } +pub fn spr_repartition_exec(input: Arc) -> Arc { + Arc::new( + RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(10)) + .unwrap() + .with_preserve_order(), + ) +} + pub fn aggregate_exec(input: Arc) -> Arc { let schema = input.schema(); Arc::new( @@ -333,7 +343,6 @@ pub fn aggregate_exec(input: Arc) -> Arc { PhysicalGroupBy::default(), vec![], vec![], - vec![], input, schema, ) diff --git a/datafusion/core/src/physical_optimizer/topk_aggregation.rs b/datafusion/core/src/physical_optimizer/topk_aggregation.rs index 572e796a8ba7..dd0261420304 100644 --- a/datafusion/core/src/physical_optimizer/topk_aggregation.rs +++ b/datafusion/core/src/physical_optimizer/topk_aggregation.rs @@ -73,9 +73,8 @@ impl TopKAggregation { aggr.group_by().clone(), aggr.aggr_expr().to_vec(), aggr.filter_expr().to_vec(), - aggr.order_by_expr().to_vec(), aggr.input().clone(), - aggr.input_schema().clone(), + aggr.input_schema(), ) .expect("Unable to copy Aggregate!") .with_limit(Some(limit)); @@ -118,7 +117,7 @@ impl TopKAggregation { } Ok(Transformed::No(plan)) }; - let child = transform_down_mut(child.clone(), &mut closure).ok()?; + let child = child.clone().transform_down_mut(&mut closure).ok()?; let sort = SortExec::new(sort.expr().to_vec(), child) .with_fetch(sort.fetch()) .with_preserve_partitioning(sort.preserve_partitioning()); @@ -126,17 +125,6 @@ impl TopKAggregation { } } -fn transform_down_mut( - me: Arc, - op: &mut F, -) -> Result> -where - F: FnMut(Arc) -> Result>>, -{ - let after_op = op(me)?.into(); - after_op.map_children(|node| transform_down_mut(node, op)) -} - impl Default for TopKAggregation { fn default() -> Self { Self::new() diff --git a/datafusion/core/src/physical_optimizer/utils.rs b/datafusion/core/src/physical_optimizer/utils.rs index 21c976e07a15..f8063e969422 100644 --- a/datafusion/core/src/physical_optimizer/utils.rs +++ b/datafusion/core/src/physical_optimizer/utils.rs @@ -17,76 +17,32 @@ //! Collection of utility functions that are leveraged by the query optimizer rules -use std::fmt; -use std::fmt::Formatter; use std::sync::Arc; -use crate::error::Result; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; -use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use crate::physical_plan::union::UnionExec; use crate::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; -use crate::physical_plan::{displayable, ExecutionPlan}; +use crate::physical_plan::ExecutionPlan; -use datafusion_physical_expr::utils::ordering_satisfy; -use datafusion_physical_expr::PhysicalSortExpr; - -/// This object implements a tree that we use while keeping track of paths -/// leading to [`SortExec`]s. -#[derive(Debug, Clone)] -pub(crate) struct ExecTree { - /// The `ExecutionPlan` associated with this node - pub plan: Arc, - /// Child index of the plan in its parent - pub idx: usize, - /// Children of the plan that would need updating if we remove leaf executors - pub children: Vec, -} - -impl fmt::Display for ExecTree { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - let plan_string = get_plan_string(&self.plan); - write!(f, "\nidx: {:?}", self.idx)?; - write!(f, "\nplan: {:?}", plan_string)?; - for child in self.children.iter() { - write!(f, "\nexec_tree:{}", child)?; - } - writeln!(f) - } -} - -impl ExecTree { - /// Create new Exec tree - pub fn new( - plan: Arc, - idx: usize, - children: Vec, - ) -> Self { - ExecTree { - plan, - idx, - children, - } - } -} +use datafusion_physical_expr::{LexRequirementRef, PhysicalSortRequirement}; +use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; /// This utility function adds a `SortExec` above an operator according to the /// given ordering requirements while preserving the original partitioning. pub fn add_sort_above( node: &mut Arc, - sort_expr: Vec, + sort_requirement: LexRequirementRef, fetch: Option, -) -> Result<()> { +) { // If the ordering requirement is already satisfied, do not add a sort. - if !ordering_satisfy( - node.output_ordering(), - Some(&sort_expr), - || node.equivalence_properties(), - || node.ordering_equivalence_properties(), - ) { + if !node + .equivalence_properties() + .ordering_satisfy_requirement(sort_requirement) + { + let sort_expr = PhysicalSortRequirement::to_sort_exprs(sort_requirement.to_vec()); let new_sort = SortExec::new(sort_expr, node.clone()).with_fetch(fetch); *node = Arc::new(if node.output_partitioning().partition_count() > 1 { @@ -95,7 +51,6 @@ pub fn add_sort_above( new_sort }) as _ } - Ok(()) } /// Checks whether the given operator is a limit; @@ -134,10 +89,3 @@ pub fn is_union(plan: &Arc) -> bool { pub fn is_repartition(plan: &Arc) -> bool { plan.as_any().is::() } - -/// Utility function yielding a string representation of the given [`ExecutionPlan`]. -pub fn get_plan_string(plan: &Arc) -> Vec { - let formatted = displayable(plan.as_ref()).indent(true).to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - actual.iter().map(|elem| elem.to_string()).collect() -} diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 2328ffce235d..d696c55a8c13 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -17,16 +17,21 @@ //! Planner for [`LogicalPlan`] to [`ExecutionPlan`] +use std::collections::HashMap; +use std::fmt::Write; +use std::sync::Arc; + use crate::datasource::file_format::arrow::ArrowFormat; use crate::datasource::file_format::avro::AvroFormat; use crate::datasource::file_format::csv::CsvFormat; use crate::datasource::file_format::json::JsonFormat; +#[cfg(feature = "parquet")] use crate::datasource::file_format::parquet::ParquetFormat; -use crate::datasource::file_format::write::FileWriterMode; use crate::datasource::file_format::FileFormat; use crate::datasource::listing::ListingTableUrl; use crate::datasource::physical_plan::FileSinkConfig; use crate::datasource::source_as_provider; +use crate::error::{DataFusionError, Result}; use crate::execution::context::{ExecutionProps, SessionState}; use crate::logical_expr::utils::generate_sort_key; use crate::logical_expr::{ @@ -37,67 +42,63 @@ use crate::logical_expr::{ CrossJoin, Expr, LogicalPlan, Partitioning as LogicalPartitioning, PlanType, Repartition, Union, UserDefinedLogicalNode, }; -use crate::physical_plan::memory::MemoryExec; -use arrow_array::builder::StringBuilder; -use arrow_array::RecordBatch; -use datafusion_common::display::ToStringifiedPlan; -use datafusion_common::file_options::FileTypeWriterOptions; -use datafusion_common::FileType; -use datafusion_expr::dml::{CopyOptions, CopyTo}; - use crate::logical_expr::{Limit, Values}; use crate::physical_expr::create_physical_expr; use crate::physical_optimizer::optimizer::PhysicalOptimizerRule; use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use crate::physical_plan::analyze::AnalyzeExec; +use crate::physical_plan::empty::EmptyExec; use crate::physical_plan::explain::ExplainExec; use crate::physical_plan::expressions::{Column, PhysicalSortExpr}; use crate::physical_plan::filter::FilterExec; -use crate::physical_plan::joins::HashJoinExec; -use crate::physical_plan::joins::SortMergeJoinExec; -use crate::physical_plan::joins::{CrossJoinExec, NestedLoopJoinExec}; +use crate::physical_plan::joins::utils as join_utils; +use crate::physical_plan::joins::{ + CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode, SortMergeJoinExec, +}; use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; +use crate::physical_plan::memory::MemoryExec; use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; +use crate::physical_plan::union::UnionExec; use crate::physical_plan::unnest::UnnestExec; -use crate::physical_plan::windows::{ - BoundedWindowAggExec, PartitionSearchMode, WindowAggExec, -}; +use crate::physical_plan::values::ValuesExec; +use crate::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use crate::physical_plan::{ - aggregates, empty::EmptyExec, joins::PartitionMode, udaf, union::UnionExec, - values::ValuesExec, windows, -}; -use crate::physical_plan::{joins::utils as join_utils, Partitioning}; -use crate::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr, WindowExpr}; -use crate::{ - error::{DataFusionError, Result}, - physical_plan::displayable, + aggregates, displayable, udaf, windows, AggregateExpr, ExecutionPlan, InputOrderMode, + Partitioning, PhysicalExpr, WindowExpr, }; + use arrow::compute::SortOptions; use arrow::datatypes::{Schema, SchemaRef}; -use async_trait::async_trait; +use arrow_array::builder::StringBuilder; +use arrow_array::RecordBatch; +use datafusion_common::display::ToStringifiedPlan; +use datafusion_common::file_options::FileTypeWriterOptions; use datafusion_common::{ - exec_err, internal_err, not_impl_err, plan_err, DFSchema, ScalarValue, + exec_err, internal_err, not_impl_err, plan_err, DFSchema, FileType, ScalarValue, }; +use datafusion_expr::dml::{CopyOptions, CopyTo}; use datafusion_expr::expr::{ - self, AggregateFunction, AggregateUDF, Alias, Between, BinaryExpr, Cast, - GetFieldAccess, GetIndexedField, GroupingSet, InList, Like, ScalarUDF, TryCast, + self, AggregateFunction, AggregateFunctionDefinition, Alias, Between, BinaryExpr, + Cast, GetFieldAccess, GetIndexedField, GroupingSet, InList, Like, TryCast, WindowFunction, }; -use datafusion_expr::expr_rewriter::{unalias, unnormalize_cols}; +use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; -use datafusion_expr::{DescribeTable, DmlStatement, StringifiedPlan, WriteOp}; -use datafusion_expr::{WindowFrame, WindowFrameBound}; +use datafusion_expr::{ + DescribeTable, DmlStatement, ScalarFunctionDefinition, StringifiedPlan, WindowFrame, + WindowFrameBound, WriteOp, +}; use datafusion_physical_expr::expressions::Literal; +use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; use datafusion_sql::utils::window_expr_common_partition_keys; + +use async_trait::async_trait; use futures::future::BoxFuture; use futures::{FutureExt, StreamExt, TryStreamExt}; use itertools::{multiunzip, Itertools}; use log::{debug, trace}; -use std::collections::HashMap; -use std::fmt::Write; -use std::sync::Arc; fn create_function_physical_name( fun: &str, @@ -216,40 +217,49 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { Ok(name) } - Expr::ScalarFunction(func) => { - create_function_physical_name(&func.fun.to_string(), false, &func.args) - } - Expr::ScalarUDF(ScalarUDF { fun, args }) => { - create_function_physical_name(&fun.name, false, args) + Expr::ScalarFunction(fun) => { + // function should be resolved during `AnalyzerRule`s + if let ScalarFunctionDefinition::Name(_) = fun.func_def { + return internal_err!("Function `Expr` with name should be resolved."); + } + + create_function_physical_name(fun.name(), false, &fun.args) } Expr::WindowFunction(WindowFunction { fun, args, .. }) => { create_function_physical_name(&fun.to_string(), false, args) } Expr::AggregateFunction(AggregateFunction { - fun, + func_def, distinct, args, - .. - }) => create_function_physical_name(&fun.to_string(), *distinct, args), - Expr::AggregateUDF(AggregateUDF { - fun, - args, filter, order_by, - }) => { - // TODO: Add support for filter and order by in AggregateUDF - if filter.is_some() { - return exec_err!("aggregate expression with filter is not supported"); + }) => match func_def { + AggregateFunctionDefinition::BuiltIn(..) => { + create_function_physical_name(func_def.name(), *distinct, args) } - if order_by.is_some() { - return exec_err!("aggregate expression with order_by is not supported"); + AggregateFunctionDefinition::UDF(fun) => { + // TODO: Add support for filter and order by in AggregateUDF + if filter.is_some() { + return exec_err!( + "aggregate expression with filter is not supported" + ); + } + if order_by.is_some() { + return exec_err!( + "aggregate expression with order_by is not supported" + ); + } + let names = args + .iter() + .map(|e| create_physical_name(e, false)) + .collect::>>()?; + Ok(format!("{}({})", fun.name(), names.join(","))) } - let mut names = Vec::with_capacity(args.len()); - for e in args { - names.push(create_physical_name(e, false)?); + AggregateFunctionDefinition::Name(_) => { + internal_err!("Aggregate function `Expr` with name should be resolved.") } - Ok(format!("{}({})", fun.name, names.join(","))) - } + }, Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => Ok(format!( "ROLLUP ({})", @@ -362,9 +372,8 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { Expr::Sort { .. } => { internal_err!("Create physical name does not support sort expression") } - Expr::Wildcard => internal_err!("Create physical name does not support wildcard"), - Expr::QualifiedWildcard { .. } => { - internal_err!("Create physical name does not support qualified wildcard") + Expr::Wildcard { .. } => { + internal_err!("Create physical name does not support wildcard") } Expr::Placeholder(_) => { internal_err!("Create physical name does not support placeholder") @@ -552,8 +561,7 @@ impl DefaultPhysicalPlanner { // doesn't know (nor should care) how the relation was // referred to in the query let filters = unnormalize_cols(filters.iter().cloned()); - let unaliased: Vec = filters.into_iter().map(unalias).collect(); - source.scan(session_state, projection.as_ref(), &unaliased, *fetch).await + source.scan(session_state, projection.as_ref(), &filters, *fetch).await } LogicalPlan::Copy(CopyTo{ input, @@ -563,11 +571,7 @@ impl DefaultPhysicalPlanner { copy_options, }) => { let input_exec = self.create_initial_plan(input, session_state).await?; - - // TODO: make this behavior configurable via options (should copy to create path/file as needed?) - // TODO: add additional configurable options for if existing files should be overwritten or - // appended to - let parsed_url = ListingTableUrl::parse_create_local_if_not_exists(output_url, !*single_file_output)?; + let parsed_url = ListingTableUrl::parse(output_url)?; let object_store_url = parsed_url.object_store(); let schema: Schema = (**input.schema()).clone().into(); @@ -589,8 +593,6 @@ impl DefaultPhysicalPlanner { file_groups: vec![], output_schema: Arc::new(schema), table_partition_cols: vec![], - unbounded_input: false, - writer_mode: FileWriterMode::PutMultipart, single_file_output: *single_file_output, overwrite: false, file_type_writer_options @@ -598,13 +600,14 @@ impl DefaultPhysicalPlanner { let sink_format: Arc = match file_format { FileType::CSV => Arc::new(CsvFormat::default()), + #[cfg(feature = "parquet")] FileType::PARQUET => Arc::new(ParquetFormat::default()), FileType::JSON => Arc::new(JsonFormat::default()), FileType::AVRO => Arc::new(AvroFormat {} ), FileType::ARROW => Arc::new(ArrowFormat {}), }; - sink_format.create_writer_physical_plan(input_exec, session_state, config).await + sink_format.create_writer_physical_plan(input_exec, session_state, config, None).await } LogicalPlan::Dml(DmlStatement { table_name, @@ -751,15 +754,13 @@ impl DefaultPhysicalPlanner { Arc::new(BoundedWindowAggExec::try_new( window_expr, input_exec, - physical_input_schema, physical_partition_keys, - PartitionSearchMode::Sorted, + InputOrderMode::Sorted, )?) } else { Arc::new(WindowAggExec::try_new( window_expr, input_exec, - physical_input_schema, physical_partition_keys, )?) }) @@ -793,14 +794,13 @@ impl DefaultPhysicalPlanner { }) .collect::>>()?; - let (aggregates, filters, order_bys) : (Vec<_>, Vec<_>, Vec<_>) = multiunzip(agg_filter); + let (aggregates, filters, _order_bys) : (Vec<_>, Vec<_>, Vec<_>) = multiunzip(agg_filter); let initial_aggr = Arc::new(AggregateExec::try_new( AggregateMode::Partial, groups.clone(), aggregates.clone(), filters.clone(), - order_bys, input_exec, physical_input_schema.clone(), )?); @@ -818,18 +818,14 @@ impl DefaultPhysicalPlanner { // To reflect such changes to subsequent stages, use the updated // `AggregateExpr`/`PhysicalSortExpr` objects. let updated_aggregates = initial_aggr.aggr_expr().to_vec(); - let updated_order_bys = initial_aggr.order_by_expr().to_vec(); - let (initial_aggr, next_partition_mode): ( - Arc, - AggregateMode, - ) = if can_repartition { + let next_partition_mode = if can_repartition { // construct a second aggregation with 'AggregateMode::FinalPartitioned' - (initial_aggr, AggregateMode::FinalPartitioned) + AggregateMode::FinalPartitioned } else { // construct a second aggregation, keeping the final column name equal to the // first aggregation and the expressions corresponding to the respective aggregate - (initial_aggr, AggregateMode::Final) + AggregateMode::Final }; let final_grouping_set = PhysicalGroupBy::new_single( @@ -845,7 +841,6 @@ impl DefaultPhysicalPlanner { final_grouping_set, updated_aggregates, filters, - updated_order_bys, initial_aggr, physical_input_schema.clone(), )?)) @@ -913,19 +908,14 @@ impl DefaultPhysicalPlanner { &input_schema, session_state, )?; - Ok(Arc::new(FilterExec::try_new(runtime_expr, physical_input)?)) + let selectivity = session_state.config().options().optimizer.default_filter_selectivity; + let filter = FilterExec::try_new(runtime_expr, physical_input)?; + Ok(Arc::new(filter.with_default_selectivity(selectivity)?)) } - LogicalPlan::Union(Union { inputs, schema }) => { + LogicalPlan::Union(Union { inputs, .. }) => { let physical_plans = self.create_initial_plan_multi(inputs.iter().map(|lp| lp.as_ref()), session_state).await?; - if schema.fields().len() < physical_plans[0].schema().fields().len() { - // `schema` could be a subset of the child schema. For example - // for query "select count(*) from (select a from t union all select a from t)" - // `schema` is empty but child schema contains one field `a`. - Ok(Arc::new(UnionExec::try_new_with_schema(physical_plans, schema.clone())?)) - } else { - Ok(Arc::new(UnionExec::new(physical_plans))) - } + Ok(Arc::new(UnionExec::new(physical_plans))) } LogicalPlan::Repartition(Repartition { input, @@ -1199,10 +1189,15 @@ impl DefaultPhysicalPlanner { } LogicalPlan::Subquery(_) => todo!(), LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row, + produce_one_row: false, schema, }) => Ok(Arc::new(EmptyExec::new( - *produce_one_row, + SchemaRef::new(schema.as_ref().to_owned().into()), + ))), + LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: true, + schema, + }) => Ok(Arc::new(PlaceholderRowExec::new( SchemaRef::new(schema.as_ref().to_owned().into()), ))), LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => { @@ -1251,10 +1246,10 @@ impl DefaultPhysicalPlanner { "Unsupported logical plan: Prepare" ) } - LogicalPlan::Dml(_) => { + LogicalPlan::Dml(dml) => { // DataFusion is a read-only query engine, but also a library, so consumers may implement this not_impl_err!( - "Unsupported logical plan: Dml" + "Unsupported logical plan: Dml({0})", dml.op ) } LogicalPlan::Statement(statement) => { @@ -1708,7 +1703,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( ) -> Result { match e { Expr::AggregateFunction(AggregateFunction { - fun, + func_def, distinct, args, filter, @@ -1749,63 +1744,35 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( ), None => None, }; - let ordering_reqs = order_by.clone().unwrap_or(vec![]); - let agg_expr = aggregates::create_aggregate_expr( - fun, - *distinct, - &args, - &ordering_reqs, - physical_input_schema, - name, - )?; - Ok((agg_expr, filter, order_by)) - } - Expr::AggregateUDF(AggregateUDF { - fun, - args, - filter, - order_by, - }) => { - let args = args - .iter() - .map(|e| { - create_physical_expr( - e, - logical_input_schema, + let (agg_expr, filter, order_by) = match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + let ordering_reqs = order_by.clone().unwrap_or(vec![]); + let agg_expr = aggregates::create_aggregate_expr( + fun, + *distinct, + &args, + &ordering_reqs, physical_input_schema, - execution_props, + name, + )?; + (agg_expr, filter, order_by) + } + AggregateFunctionDefinition::UDF(fun) => { + let agg_expr = udaf::create_aggregate_expr( + fun, + &args, + physical_input_schema, + name, + ); + (agg_expr?, filter, order_by) + } + AggregateFunctionDefinition::Name(_) => { + return internal_err!( + "Aggregate function name should have been resolved" ) - }) - .collect::>>()?; - - let filter = match filter { - Some(e) => Some(create_physical_expr( - e, - logical_input_schema, - physical_input_schema, - execution_props, - )?), - None => None, - }; - let order_by = match order_by { - Some(e) => Some( - e.iter() - .map(|expr| { - create_physical_sort_expr( - expr, - logical_input_schema, - physical_input_schema, - execution_props, - ) - }) - .collect::>>()?, - ), - None => None, + } }; - - let agg_expr = - udaf::create_aggregate_expr(fun, &args, physical_input_schema, name); - Ok((agg_expr?, filter, order_by)) + Ok((agg_expr, filter, order_by)) } other => internal_err!("Invalid aggregate expression '{other:?}'"), } @@ -1893,13 +1860,26 @@ impl DefaultPhysicalPlanner { .await { Ok(input) => { + // This plan will includes statistics if show_statistics is on stringified_plans.push( displayable(input.as_ref()) .set_show_statistics(config.show_statistics) .to_stringified(e.verbose, InitialPhysicalPlan), ); - match self.optimize_internal( + // If the show_statisitcs is off, add another line to show statsitics in the case of explain verbose + if e.verbose && !config.show_statistics { + stringified_plans.push( + displayable(input.as_ref()) + .set_show_statistics(true) + .to_stringified( + e.verbose, + InitialPhysicalPlanWithStats, + ), + ); + } + + let optimized_plan = self.optimize_internal( input, session_state, |plan, optimizer| { @@ -1911,12 +1891,28 @@ impl DefaultPhysicalPlanner { .to_stringified(e.verbose, plan_type), ); }, - ) { - Ok(input) => stringified_plans.push( - displayable(input.as_ref()) - .set_show_statistics(config.show_statistics) - .to_stringified(e.verbose, FinalPhysicalPlan), - ), + ); + match optimized_plan { + Ok(input) => { + // This plan will includes statistics if show_statistics is on + stringified_plans.push( + displayable(input.as_ref()) + .set_show_statistics(config.show_statistics) + .to_stringified(e.verbose, FinalPhysicalPlan), + ); + + // If the show_statisitcs is off, add another line to show statsitics in the case of explain verbose + if e.verbose && !config.show_statistics { + stringified_plans.push( + displayable(input.as_ref()) + .set_show_statistics(true) + .to_stringified( + e.verbose, + FinalPhysicalPlanWithStats, + ), + ); + } + } Err(DataFusionError::Context(optimizer_name, e)) => { let plan_type = OptimizedPhysicalPlan { optimizer_name }; stringified_plans @@ -2015,7 +2011,7 @@ impl DefaultPhysicalPlanner { let mut column_names = StringBuilder::new(); let mut data_types = StringBuilder::new(); let mut is_nullables = StringBuilder::new(); - for (_, field) in table_schema.fields().iter().enumerate() { + for field in table_schema.fields() { column_names.append_value(field.name()); // "System supplied type" --> Use debug format of the datatype @@ -2058,9 +2054,7 @@ mod tests { use super::*; use crate::datasource::file_format::options::CsvReadOptions; use crate::datasource::MemTable; - use crate::physical_plan::{ - expressions, DisplayFormatType, Partitioning, Statistics, - }; + use crate::physical_plan::{expressions, DisplayFormatType, Partitioning}; use crate::physical_plan::{DisplayAs, SendableRecordBatchStream}; use crate::physical_planner::PhysicalPlanner; use crate::prelude::{SessionConfig, SessionContext}; @@ -2087,7 +2081,7 @@ mod tests { let runtime = Arc::new(RuntimeEnv::default()); let config = SessionConfig::new().with_target_partitions(4); let config = config.set_bool("datafusion.optimizer.skip_failed_rules", false); - SessionState::with_config_rt(config, runtime) + SessionState::new_with_config_rt(config, runtime) } async fn plan(logical_plan: &LogicalPlan) -> Result> { @@ -2517,6 +2511,27 @@ mod tests { Ok(()) } + #[tokio::test] + async fn aggregate_with_alias() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Utf8, false), + Field::new("c2", DataType::UInt32, false), + ])); + + let logical_plan = scan_empty(None, schema.as_ref(), None)? + .aggregate(vec![col("c1")], vec![sum(col("c2"))])? + .project(vec![col("c1"), sum(col("c2")).alias("total_salary")])? + .build()?; + + let physical_plan = plan(&logical_plan).await?; + assert_eq!("c1", physical_plan.schema().field(0).name().as_str()); + assert_eq!( + "total_salary", + physical_plan.schema().field(1).name().as_str() + ); + Ok(()) + } + #[tokio::test] async fn test_explain() { let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]); @@ -2671,10 +2686,6 @@ mod tests { ) -> Result { unimplemented!("NoOpExecutionPlan::execute"); } - - fn statistics(&self) -> Statistics { - unimplemented!("NoOpExecutionPlan::statistics"); - } } // Produces an execution plan where the schema is mismatched from @@ -2755,7 +2766,7 @@ mod tests { digraph { 1[shape=box label="ProjectionExec: expr=[id@0 + 2 as employee.id + Int32(2)]", tooltip=""] - 2[shape=box label="EmptyExec: produce_one_row=false", tooltip=""] + 2[shape=box label="EmptyExec", tooltip=""] 1 -> 2 [arrowhead=none, arrowtail=normal, dir=back] } // End DataFusion GraphViz Plan diff --git a/datafusion/core/src/prelude.rs b/datafusion/core/src/prelude.rs index 3782feca191a..5cd8b3870f81 100644 --- a/datafusion/core/src/prelude.rs +++ b/datafusion/core/src/prelude.rs @@ -13,7 +13,7 @@ // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations -// under the License.pub}, +// under the License. //! DataFusion "prelude" to simplify importing common types. //! @@ -38,3 +38,8 @@ pub use datafusion_expr::{ logical_plan::{JoinType, Partitioning}, Expr, }; + +pub use std::ops::Not; +pub use std::ops::{Add, Div, Mul, Neg, Rem, Sub}; +pub use std::ops::{BitAnd, BitOr, BitXor}; +pub use std::ops::{Shl, Shr}; diff --git a/datafusion/core/src/test/mod.rs b/datafusion/core/src/test/mod.rs index 903542ca3fad..ed5aa15e291b 100644 --- a/datafusion/core/src/test/mod.rs +++ b/datafusion/core/src/test/mod.rs @@ -17,6 +17,13 @@ //! Common unit test utility methods +use std::any::Any; +use std::fs::File; +use std::io::prelude::*; +use std::io::{BufReader, BufWriter}; +use std::path::Path; +use std::sync::Arc; + use crate::datasource::file_format::file_compression_type::{ FileCompressionType, FileTypeExt, }; @@ -29,29 +36,24 @@ use crate::logical_expr::LogicalPlan; use crate::physical_plan::ExecutionPlan; use crate::test::object_store::local_unpartitioned_file; use crate::test_util::{aggr_test_schema, arrow_test_data}; -use array::ArrayRef; -use arrow::array::{self, Array, Decimal128Builder, Int32Array}; + +use arrow::array::{self, Array, ArrayRef, Decimal128Builder, Int32Array}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; +use datafusion_common::{DataFusionError, FileType, Statistics}; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_physical_expr::{Partitioning, PhysicalSortExpr}; +use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; +use datafusion_physical_plan::{DisplayAs, DisplayFormatType}; + #[cfg(feature = "compression")] use bzip2::write::BzEncoder; #[cfg(feature = "compression")] use bzip2::Compression as BzCompression; -use datafusion_common::FileType; -use datafusion_common::{DataFusionError, Statistics}; -use datafusion_execution::{SendableRecordBatchStream, TaskContext}; -use datafusion_physical_expr::{Partitioning, PhysicalSortExpr}; -use datafusion_physical_plan::{DisplayAs, DisplayFormatType}; #[cfg(feature = "compression")] use flate2::write::GzEncoder; #[cfg(feature = "compression")] use flate2::Compression as GzCompression; -use std::any::Any; -use std::fs::File; -use std::io::prelude::*; -use std::io::{BufReader, BufWriter}; -use std::path::Path; -use std::sync::Arc; #[cfg(feature = "compression")] use xz2::write::XzEncoder; #[cfg(feature = "compression")] @@ -195,14 +197,13 @@ pub fn partitioned_csv_config( ) -> Result { Ok(FileScanConfig { object_store_url: ObjectStoreUrl::local_filesystem(), - file_schema: schema, + file_schema: schema.clone(), file_groups, - statistics: Default::default(), + statistics: Statistics::new_unknown(&schema), projection: None, limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }) } @@ -276,7 +277,6 @@ fn make_decimal() -> RecordBatch { pub fn csv_exec_sorted( schema: &SchemaRef, sort_exprs: impl IntoIterator, - infinite_source: bool, ) -> Arc { let sort_exprs = sort_exprs.into_iter().collect(); @@ -285,12 +285,11 @@ pub fn csv_exec_sorted( object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), file_schema: schema.clone(), file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], - statistics: Statistics::default(), + statistics: Statistics::new_unknown(schema), projection: None, limit: None, table_partition_cols: vec![], output_ordering: vec![sort_exprs], - infinite_source, }, false, 0, @@ -300,6 +299,67 @@ pub fn csv_exec_sorted( )) } +// construct a stream partition for test purposes +pub(crate) struct TestStreamPartition { + pub schema: SchemaRef, +} + +impl PartitionStream for TestStreamPartition { + fn schema(&self) -> &SchemaRef { + &self.schema + } + fn execute(&self, _ctx: Arc) -> SendableRecordBatchStream { + unreachable!() + } +} + +/// Create an unbounded stream exec +pub fn stream_exec_ordered( + schema: &SchemaRef, + sort_exprs: impl IntoIterator, +) -> Arc { + let sort_exprs = sort_exprs.into_iter().collect(); + + Arc::new( + StreamingTableExec::try_new( + schema.clone(), + vec![Arc::new(TestStreamPartition { + schema: schema.clone(), + }) as _], + None, + vec![sort_exprs], + true, + ) + .unwrap(), + ) +} + +/// Create a csv exec for tests +pub fn csv_exec_ordered( + schema: &SchemaRef, + sort_exprs: impl IntoIterator, +) -> Arc { + let sort_exprs = sort_exprs.into_iter().collect(); + + Arc::new(CsvExec::new( + FileScanConfig { + object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), + file_schema: schema.clone(), + file_groups: vec![vec![PartitionedFile::new("file_path".to_string(), 100)]], + statistics: Statistics::new_unknown(schema), + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering: vec![sort_exprs], + }, + true, + 0, + b'"', + None, + FileCompressionType::UNCOMPRESSED, + )) +} + /// A mock execution plan that simply returns the provided statistics #[derive(Debug, Clone)] pub struct StatisticsExec { @@ -308,12 +368,8 @@ pub struct StatisticsExec { } impl StatisticsExec { pub fn new(stats: Statistics, schema: Schema) -> Self { - assert!( - stats - .column_statistics - .as_ref() - .map(|cols| cols.len() == schema.fields().len()) - .unwrap_or(true), + assert_eq!( + stats.column_statistics.len(), schema.fields().len(), "if defined, the column statistics vector length should be the number of fields" ); Self { @@ -378,8 +434,8 @@ impl ExecutionPlan for StatisticsExec { unimplemented!("This plan only serves for testing statistics") } - fn statistics(&self) -> Statistics { - self.stats.clone() + fn statistics(&self) -> Result { + Ok(self.stats.clone()) } } diff --git a/datafusion/core/src/test/object_store.rs b/datafusion/core/src/test/object_store.rs index 425d0724ea4f..d6f324a7f1f9 100644 --- a/datafusion/core/src/test/object_store.rs +++ b/datafusion/core/src/test/object_store.rs @@ -15,7 +15,10 @@ // specific language governing permissions and limitations // under the License. //! Object store implementation used for testing +use crate::execution::context::SessionState; use crate::prelude::SessionContext; +use datafusion_execution::config::SessionConfig; +use datafusion_execution::runtime_env::RuntimeEnv; use futures::FutureExt; use object_store::{memory::InMemory, path::Path, ObjectMeta, ObjectStore}; use std::sync::Arc; @@ -25,11 +28,11 @@ use url::Url; pub fn register_test_store(ctx: &SessionContext, files: &[(&str, u64)]) { let url = Url::parse("test://").unwrap(); ctx.runtime_env() - .register_object_store(&url, make_test_store(files)); + .register_object_store(&url, make_test_store_and_state(files).0); } /// Create a test object store with the provided files -pub fn make_test_store(files: &[(&str, u64)]) -> Arc { +pub fn make_test_store_and_state(files: &[(&str, u64)]) -> (Arc, SessionState) { let memory = InMemory::new(); for (name, size) in files { @@ -40,7 +43,13 @@ pub fn make_test_store(files: &[(&str, u64)]) -> Arc { .unwrap(); } - Arc::new(memory) + ( + Arc::new(memory), + SessionState::new_with_config_rt( + SessionConfig::default(), + Arc::new(RuntimeEnv::default()), + ), + ) } /// Helper method to fetch the file size and date at given path and create a `ObjectMeta` @@ -52,5 +61,6 @@ pub fn local_unpartitioned_file(path: impl AsRef) -> ObjectMeta last_modified: metadata.modified().map(chrono::DateTime::from).unwrap(), size: metadata.len() as usize, e_tag: None, + version: None, } } diff --git a/datafusion/core/src/test/variable.rs b/datafusion/core/src/test/variable.rs index a55513841561..38207b42cb7b 100644 --- a/datafusion/core/src/test/variable.rs +++ b/datafusion/core/src/test/variable.rs @@ -37,7 +37,7 @@ impl VarProvider for SystemVar { /// get system variable value fn get_value(&self, var_names: Vec) -> Result { let s = format!("{}-{}", "system-var", var_names.concat()); - Ok(ScalarValue::Utf8(Some(s))) + Ok(ScalarValue::from(s)) } fn get_type(&self, _: &[String]) -> Option { @@ -61,7 +61,7 @@ impl VarProvider for UserDefinedVar { fn get_value(&self, var_names: Vec) -> Result { if var_names[0] != "@integer" { let s = format!("{}-{}", "user-defined-var", var_names.concat()); - Ok(ScalarValue::Utf8(Some(s))) + Ok(ScalarValue::from(s)) } else { Ok(ScalarValue::Int32(Some(41))) } diff --git a/datafusion/core/src/test_util/mod.rs b/datafusion/core/src/test_util/mod.rs index bd52c3eedaa4..282b0f7079ee 100644 --- a/datafusion/core/src/test_util/mod.rs +++ b/datafusion/core/src/test_util/mod.rs @@ -17,41 +17,48 @@ //! Utility functions to make testing DataFusion based crates easier +#[cfg(feature = "parquet")] pub mod parquet; use std::any::Any; use std::collections::HashMap; +use std::fs::File; +use std::io::Write; use std::path::Path; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; +use tempfile::TempDir; + +use crate::dataframe::DataFrame; use crate::datasource::provider::TableProviderFactory; use crate::datasource::{empty::EmptyTable, provider_as_source, TableProvider}; use crate::error::Result; use crate::execution::context::{SessionState, TaskContext}; -use crate::execution::options::ReadOptions; use crate::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE}; use crate::physical_plan::{ DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, }; use crate::prelude::{CsvReadOptions, SessionContext}; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; -use async_trait::async_trait; -use datafusion_common::{Statistics, TableReference}; +use datafusion_common::TableReference; use datafusion_expr::{CreateExternalTable, Expr, TableType}; use datafusion_physical_expr::PhysicalSortExpr; + +use async_trait::async_trait; use futures::Stream; // backwards compatibility -pub use datafusion_common::test_util::{ - arrow_test_data, get_data_dir, parquet_test_data, -}; +#[cfg(feature = "parquet")] +pub use datafusion_common::test_util::parquet_test_data; +pub use datafusion_common::test_util::{arrow_test_data, get_data_dir}; -pub use datafusion_common::assert_batches_eq; -pub use datafusion_common::assert_batches_sorted_eq; +use crate::datasource::stream::{StreamConfig, StreamTable}; +pub use datafusion_common::{assert_batches_eq, assert_batches_sorted_eq}; /// Scan an empty data source, mainly used in tests pub fn scan_empty( @@ -101,6 +108,71 @@ pub fn aggr_test_schema() -> SchemaRef { Arc::new(schema) } +/// Register session context for the aggregate_test_100.csv file +pub async fn register_aggregate_csv( + ctx: &mut SessionContext, + table_name: &str, +) -> Result<()> { + let schema = aggr_test_schema(); + let testdata = arrow_test_data(); + ctx.register_csv( + table_name, + &format!("{testdata}/csv/aggregate_test_100.csv"), + CsvReadOptions::new().schema(schema.as_ref()), + ) + .await?; + Ok(()) +} + +/// Create a table from the aggregate_test_100.csv file with the specified name +pub async fn test_table_with_name(name: &str) -> Result { + let mut ctx = SessionContext::new(); + register_aggregate_csv(&mut ctx, name).await?; + ctx.table(name).await +} + +/// Create a table from the aggregate_test_100.csv file with the name "aggregate_test_100" +pub async fn test_table() -> Result { + test_table_with_name("aggregate_test_100").await +} + +/// Execute SQL and return results +pub async fn plan_and_collect( + ctx: &SessionContext, + sql: &str, +) -> Result> { + ctx.sql(sql).await?.collect().await +} + +/// Generate CSV partitions within the supplied directory +pub fn populate_csv_partitions( + tmp_dir: &TempDir, + partition_count: usize, + file_extension: &str, +) -> Result { + // define schema for data source (csv file) + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::UInt32, false), + Field::new("c2", DataType::UInt64, false), + Field::new("c3", DataType::Boolean, false), + ])); + + // generate a partitioned file + for partition in 0..partition_count { + let filename = format!("partition-{partition}.{file_extension}"); + let file_path = tmp_dir.path().join(filename); + let mut file = File::create(file_path)?; + + // generate some data + for i in 0..=10 { + let data = format!("{},{},{}\n", partition, i, i % 2 == 0); + file.write_all(data.as_bytes())?; + } + } + + Ok(schema) +} + /// TableFactory for tests pub struct TestTableFactory {} @@ -237,10 +309,6 @@ impl ExecutionPlan for UnboundedExec { batch: self.batch.clone(), })) } - - fn statistics(&self) -> Statistics { - Statistics::default() - } } #[derive(Debug)] @@ -274,30 +342,17 @@ impl RecordBatchStream for UnboundedStream { } /// This function creates an unbounded sorted file for testing purposes. -pub async fn register_unbounded_file_with_ordering( +pub fn register_unbounded_file_with_ordering( ctx: &SessionContext, schema: SchemaRef, file_path: &Path, table_name: &str, file_sort_order: Vec>, - with_unbounded_execution: bool, ) -> Result<()> { - // Mark infinite and provide schema: - let fifo_options = CsvReadOptions::new() - .schema(schema.as_ref()) - .mark_infinite(with_unbounded_execution); - // Get listing options: - let options_sort = fifo_options - .to_listing_options(&ctx.copied_config()) - .with_file_sort_order(file_sort_order); + let config = + StreamConfig::new_file(schema, file_path.into()).with_order(file_sort_order); + // Register table: - ctx.register_listing_table( - table_name, - file_path.as_os_str().to_str().unwrap(), - options_sort, - Some(schema), - None, - ) - .await?; + ctx.register_table(table_name, Arc::new(StreamTable::new(Arc::new(config))))?; Ok(()) } diff --git a/datafusion/core/src/test_util/parquet.rs b/datafusion/core/src/test_util/parquet.rs index d3a1f9c1ef7c..336a6804637a 100644 --- a/datafusion/core/src/test_util/parquet.rs +++ b/datafusion/core/src/test_util/parquet.rs @@ -35,6 +35,9 @@ use crate::physical_plan::filter::FilterExec; use crate::physical_plan::metrics::MetricsSet; use crate::physical_plan::ExecutionPlan; use crate::prelude::{Expr, SessionConfig}; + +use datafusion_common::Statistics; + use object_store::path::Path; use object_store::ObjectMeta; use parquet::arrow::ArrowWriter; @@ -110,6 +113,7 @@ impl TestParquetFile { last_modified: Default::default(), size, e_tag: None, + version: None, }; Ok(Self { @@ -147,12 +151,11 @@ impl TestParquetFile { range: None, extensions: None, }]], - statistics: Default::default(), + statistics: Statistics::new_unknown(&self.schema), projection: None, limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }; let df_schema = self.schema.clone().to_dfschema_ref()?; diff --git a/datafusion/core/tests/custom_sources.rs b/datafusion/core/tests/custom_sources.rs index 771da80aa6e7..a9ea5cc2a35c 100644 --- a/datafusion/core/tests/custom_sources.rs +++ b/datafusion/core/tests/custom_sources.rs @@ -15,41 +15,39 @@ // specific language governing permissions and limitations // under the License. +use std::any::Any; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + use arrow::array::{Int32Array, Int64Array}; use arrow::compute::kernels::aggregate; use arrow::datatypes::{DataType, Field, Int32Type, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; +use datafusion::datasource::{TableProvider, TableType}; +use datafusion::error::Result; use datafusion::execution::context::{SessionContext, SessionState, TaskContext}; use datafusion::logical_expr::{ col, Expr, LogicalPlan, LogicalPlanBuilder, TableScan, UNNAMED_TABLE, }; -use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::expressions::PhysicalSortExpr; use datafusion::physical_plan::{ - ColumnStatistics, DisplayAs, ExecutionPlan, Partitioning, RecordBatchStream, - SendableRecordBatchStream, Statistics, + collect, ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, + RecordBatchStream, SendableRecordBatchStream, Statistics, }; use datafusion::scalar::ScalarValue; -use datafusion::{ - datasource::{TableProvider, TableType}, - physical_plan::collect, -}; -use datafusion::{error::Result, physical_plan::DisplayFormatType}; - use datafusion_common::cast::as_primitive_array; use datafusion_common::project_schema; +use datafusion_common::stats::Precision; + +use async_trait::async_trait; +use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; use futures::stream::Stream; -use std::any::Any; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; /// Also run all tests that are found in the `custom_sources_cases` directory mod custom_sources_cases; -use async_trait::async_trait; - -//// Custom source dataframe tests //// +//--- Custom source dataframe tests ---// struct CustomTableProvider; #[derive(Debug, Clone)] @@ -153,30 +151,28 @@ impl ExecutionPlan for CustomExecutionPlan { Ok(Box::pin(TestCustomRecordBatchStream { nb_batch: 1 })) } - fn statistics(&self) -> Statistics { + fn statistics(&self) -> Result { let batch = TEST_CUSTOM_RECORD_BATCH!().unwrap(); - Statistics { - is_exact: true, - num_rows: Some(batch.num_rows()), - total_byte_size: None, - column_statistics: Some( - self.projection - .clone() - .unwrap_or_else(|| (0..batch.columns().len()).collect()) - .iter() - .map(|i| ColumnStatistics { - null_count: Some(batch.column(*i).null_count()), - min_value: Some(ScalarValue::Int32(aggregate::min( - as_primitive_array::(batch.column(*i)).unwrap(), - ))), - max_value: Some(ScalarValue::Int32(aggregate::max( - as_primitive_array::(batch.column(*i)).unwrap(), - ))), - ..Default::default() - }) - .collect(), - ), - } + Ok(Statistics { + num_rows: Precision::Exact(batch.num_rows()), + total_byte_size: Precision::Absent, + column_statistics: self + .projection + .clone() + .unwrap_or_else(|| (0..batch.columns().len()).collect()) + .iter() + .map(|i| ColumnStatistics { + null_count: Precision::Exact(batch.column(*i).null_count()), + min_value: Precision::Exact(ScalarValue::Int32(aggregate::min( + as_primitive_array::(batch.column(*i)).unwrap(), + ))), + max_value: Precision::Exact(ScalarValue::Int32(aggregate::max( + as_primitive_array::(batch.column(*i)).unwrap(), + ))), + ..Default::default() + }) + .collect(), + }) } } @@ -260,15 +256,15 @@ async fn optimizers_catch_all_statistics() { let physical_plan = df.create_physical_plan().await.unwrap(); - // when the optimization kicks in, the source is replaced by an EmptyExec + // when the optimization kicks in, the source is replaced by an PlaceholderRowExec assert!( - contains_empty_exec(Arc::clone(&physical_plan)), + contains_place_holder_exec(Arc::clone(&physical_plan)), "Expected aggregate_statistics optimizations missing: {physical_plan:?}" ); let expected = RecordBatch::try_new( Arc::new(Schema::new(vec![ - Field::new("COUNT(UInt8(1))", DataType::Int64, false), + Field::new("COUNT(*)", DataType::Int64, false), Field::new("MIN(test.c1)", DataType::Int32, false), Field::new("MAX(test.c1)", DataType::Int32, false), ])), @@ -287,12 +283,12 @@ async fn optimizers_catch_all_statistics() { assert_eq!(format!("{:?}", actual[0]), format!("{expected:?}")); } -fn contains_empty_exec(plan: Arc) -> bool { - if plan.as_any().is::() { +fn contains_place_holder_exec(plan: Arc) -> bool { + if plan.as_any().is::() { true } else if plan.children().len() != 1 { false } else { - contains_empty_exec(Arc::clone(&plan.children()[0])) + contains_place_holder_exec(Arc::clone(&plan.children()[0])) } } diff --git a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs index 79214092fa57..e374abd6e891 100644 --- a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs +++ b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs @@ -15,10 +15,12 @@ // specific language governing permissions and limitations // under the License. +use std::ops::Deref; +use std::sync::Arc; + use arrow::array::{Int32Builder, Int64Array}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; -use async_trait::async_trait; use datafusion::datasource::provider::{TableProvider, TableType}; use datafusion::error::Result; use datafusion::execution::context::{SessionContext, SessionState, TaskContext}; @@ -32,10 +34,10 @@ use datafusion::physical_plan::{ use datafusion::prelude::*; use datafusion::scalar::ScalarValue; use datafusion_common::cast::as_primitive_array; -use datafusion_common::{not_impl_err, DataFusionError}; +use datafusion_common::{internal_err, not_impl_err, DataFusionError}; use datafusion_expr::expr::{BinaryExpr, Cast}; -use std::ops::Deref; -use std::sync::Arc; + +use async_trait::async_trait; fn create_batch(value: i32, num_rows: usize) -> Result { let mut builder = Int32Builder::with_capacity(num_rows); @@ -96,9 +98,14 @@ impl ExecutionPlan for CustomPlan { fn with_new_children( self: Arc, - _: Vec>, + children: Vec>, ) -> Result> { - unreachable!() + // CustomPlan has no children + if children.is_empty() { + Ok(self) + } else { + internal_err!("Children cannot be replaced in {self:?}") + } } fn execute( @@ -112,10 +119,10 @@ impl ExecutionPlan for CustomPlan { ))) } - fn statistics(&self) -> Statistics { + fn statistics(&self) -> Result { // here we could provide more accurate statistics // but we want to test the filter pushdown not the CBOs - Statistics::default() + Ok(Statistics::new_unknown(&self.schema())) } } @@ -142,10 +149,12 @@ impl TableProvider for CustomProvider { async fn scan( &self, _state: &SessionState, - _: Option<&Vec>, + projection: Option<&Vec>, filters: &[Expr], _: Option, ) -> Result> { + let empty = Vec::new(); + let projection = projection.unwrap_or(&empty); match &filters[0] { Expr::BinaryExpr(BinaryExpr { right, .. }) => { let int_value = match &**right { @@ -175,7 +184,10 @@ impl TableProvider for CustomProvider { }; Ok(Arc::new(CustomPlan { - schema: self.zero_batch.schema(), + schema: match projection.is_empty() { + true => Arc::new(Schema::empty()), + false => self.zero_batch.schema(), + }, batches: match int_value { 0 => vec![self.zero_batch.clone()], 1 => vec![self.one_batch.clone()], @@ -184,7 +196,10 @@ impl TableProvider for CustomProvider { })) } _ => Ok(Arc::new(CustomPlan { - schema: self.zero_batch.schema(), + schema: match projection.is_empty() { + true => Arc::new(Schema::empty()), + false => self.zero_batch.schema(), + }, batches: vec![], })), } diff --git a/datafusion/core/tests/custom_sources_cases/statistics.rs b/datafusion/core/tests/custom_sources_cases/statistics.rs index 43e6c8851ec4..f0985f554654 100644 --- a/datafusion/core/tests/custom_sources_cases/statistics.rs +++ b/datafusion/core/tests/custom_sources_cases/statistics.rs @@ -34,7 +34,7 @@ use datafusion::{ use async_trait::async_trait; use datafusion::execution::context::{SessionState, TaskContext}; -use datafusion_common::project_schema; +use datafusion_common::{project_schema, stats::Precision}; /// This is a testing structure for statistics /// It will act both as a table provider and execution plan @@ -46,13 +46,10 @@ struct StatisticsValidation { impl StatisticsValidation { fn new(stats: Statistics, schema: SchemaRef) -> Self { - assert!( - stats - .column_statistics - .as_ref() - .map(|cols| cols.len() == schema.fields().len()) - .unwrap_or(true), - "if defined, the column statistics vector length should be the number of fields" + assert_eq!( + stats.column_statistics.len(), + schema.fields().len(), + "the column statistics vector length should be the number of fields" ); Self { stats, schema } } @@ -94,17 +91,16 @@ impl TableProvider for StatisticsValidation { let current_stat = self.stats.clone(); - let proj_col_stats = current_stat - .column_statistics - .map(|col_stat| projection.iter().map(|i| col_stat[*i].clone()).collect()); - + let proj_col_stats = projection + .iter() + .map(|i| current_stat.column_statistics[*i].clone()) + .collect(); Ok(Arc::new(Self::new( Statistics { - is_exact: current_stat.is_exact, num_rows: current_stat.num_rows, column_statistics: proj_col_stats, // TODO stats: knowing the type of the new columns we can guess the output size - total_byte_size: None, + total_byte_size: Precision::Absent, }, projected_schema, ))) @@ -166,8 +162,8 @@ impl ExecutionPlan for StatisticsValidation { unimplemented!("This plan only serves for testing statistics") } - fn statistics(&self) -> Statistics { - self.stats.clone() + fn statistics(&self) -> Result { + Ok(self.stats.clone()) } } @@ -182,23 +178,22 @@ fn init_ctx(stats: Statistics, schema: Schema) -> Result { fn fully_defined() -> (Statistics, Schema) { ( Statistics { - num_rows: Some(13), - is_exact: true, - total_byte_size: None, // ignore byte size for now - column_statistics: Some(vec![ + num_rows: Precision::Exact(13), + total_byte_size: Precision::Absent, // ignore byte size for now + column_statistics: vec![ ColumnStatistics { - distinct_count: Some(2), - max_value: Some(ScalarValue::Int32(Some(1023))), - min_value: Some(ScalarValue::Int32(Some(-24))), - null_count: Some(0), + distinct_count: Precision::Exact(2), + max_value: Precision::Exact(ScalarValue::Int32(Some(1023))), + min_value: Precision::Exact(ScalarValue::Int32(Some(-24))), + null_count: Precision::Exact(0), }, ColumnStatistics { - distinct_count: Some(13), - max_value: Some(ScalarValue::Int64(Some(5486))), - min_value: Some(ScalarValue::Int64(Some(-6783))), - null_count: Some(5), + distinct_count: Precision::Exact(13), + max_value: Precision::Exact(ScalarValue::Int64(Some(5486))), + min_value: Precision::Exact(ScalarValue::Int64(Some(-6783))), + null_count: Precision::Exact(5), }, - ]), + ], }, Schema::new(vec![ Field::new("c1", DataType::Int32, false), @@ -216,7 +211,7 @@ async fn sql_basic() -> Result<()> { let physical_plan = df.create_physical_plan().await.unwrap(); // the statistics should be those of the source - assert_eq!(stats, physical_plan.statistics()); + assert_eq!(stats, physical_plan.statistics()?); Ok(()) } @@ -232,10 +227,8 @@ async fn sql_filter() -> Result<()> { .unwrap(); let physical_plan = df.create_physical_plan().await.unwrap(); - - let stats = physical_plan.statistics(); - assert!(!stats.is_exact); - assert_eq!(stats.num_rows, Some(1)); + let stats = physical_plan.statistics()?; + assert_eq!(stats.num_rows, Precision::Inexact(1)); Ok(()) } @@ -243,6 +236,7 @@ async fn sql_filter() -> Result<()> { #[tokio::test] async fn sql_limit() -> Result<()> { let (stats, schema) = fully_defined(); + let col_stats = Statistics::unknown_column(&schema); let ctx = init_ctx(stats.clone(), schema)?; let df = ctx.sql("SELECT * FROM stats_table LIMIT 5").await.unwrap(); @@ -251,11 +245,11 @@ async fn sql_limit() -> Result<()> { // we loose all statistics except the for number of rows which becomes the limit assert_eq!( Statistics { - num_rows: Some(5), - is_exact: true, - ..Default::default() + num_rows: Precision::Exact(5), + column_statistics: col_stats, + total_byte_size: Precision::Absent }, - physical_plan.statistics() + physical_plan.statistics()? ); let df = ctx @@ -264,7 +258,7 @@ async fn sql_limit() -> Result<()> { .unwrap(); let physical_plan = df.create_physical_plan().await.unwrap(); // when the limit is larger than the original number of lines, statistics remain unchanged - assert_eq!(stats, physical_plan.statistics()); + assert_eq!(stats, physical_plan.statistics()?); Ok(()) } @@ -281,13 +275,12 @@ async fn sql_window() -> Result<()> { let physical_plan = df.create_physical_plan().await.unwrap(); - let result = physical_plan.statistics(); + let result = physical_plan.statistics()?; assert_eq!(stats.num_rows, result.num_rows); - assert!(result.column_statistics.is_some()); - let col_stats = result.column_statistics.unwrap(); + let col_stats = result.column_statistics; assert_eq!(2, col_stats.len()); - assert_eq!(stats.column_statistics.unwrap()[1], col_stats[0]); + assert_eq!(stats.column_statistics[1], col_stats[0]); Ok(()) } diff --git a/datafusion/core/tests/data/aggregate_agg_multi_order.csv b/datafusion/core/tests/data/aggregate_agg_multi_order.csv new file mode 100644 index 000000000000..e9a65ceee4aa --- /dev/null +++ b/datafusion/core/tests/data/aggregate_agg_multi_order.csv @@ -0,0 +1,11 @@ +c1,c2,c3 +1,20,0 +2,20,1 +3,10,2 +4,10,3 +5,30,4 +6,30,5 +7,30,6 +8,30,7 +9,30,8 +10,10,9 \ No newline at end of file diff --git a/datafusion/core/tests/data/empty.json b/datafusion/core/tests/data/empty.json new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/datafusion/core/tests/data/escape.csv b/datafusion/core/tests/data/escape.csv new file mode 100644 index 000000000000..331a1e697329 --- /dev/null +++ b/datafusion/core/tests/data/escape.csv @@ -0,0 +1,11 @@ +c1,c2 +"id0","value\"0" +"id1","value\"1" +"id2","value\"2" +"id3","value\"3" +"id4","value\"4" +"id5","value\"5" +"id6","value\"6" +"id7","value\"7" +"id8","value\"8" +"id9","value\"9" diff --git a/datafusion/core/tests/data/parquet_map.parquet b/datafusion/core/tests/data/parquet_map.parquet new file mode 100644 index 000000000000..e7ffb5115c44 Binary files /dev/null and b/datafusion/core/tests/data/parquet_map.parquet differ diff --git a/datafusion/core/tests/data/quote.csv b/datafusion/core/tests/data/quote.csv new file mode 100644 index 000000000000..d81488436409 --- /dev/null +++ b/datafusion/core/tests/data/quote.csv @@ -0,0 +1,11 @@ +c1,c2 +~id0~,~value0~ +~id1~,~value1~ +~id2~,~value2~ +~id3~,~value3~ +~id4~,~value4~ +~id5~,~value5~ +~id6~,~value6~ +~id7~,~value7~ +~id8~,~value8~ +~id9~,~value9~ diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index 9677003ec226..fe56fc22ea8c 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -31,6 +31,7 @@ use datafusion::prelude::*; use datafusion::execution::context::SessionContext; use datafusion::assert_batches_eq; +use datafusion_expr::expr::Alias; use datafusion_expr::{approx_median, cast}; async fn create_test_table() -> Result { @@ -186,6 +187,25 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { assert_batches_eq!(expected, &batches); + // the arg2 parameter is a complex expr, but it can be evaluated to the literal value + let alias_expr = Expr::Alias(Alias::new( + cast(lit(0.5), DataType::Float32), + None::<&str>, + "arg_2".to_string(), + )); + let expr = approx_percentile_cont(col("b"), alias_expr); + let df = create_test_table().await?; + let expected = [ + "+--------------------------------------+", + "| APPROX_PERCENTILE_CONT(test.b,arg_2) |", + "+--------------------------------------+", + "| 10 |", + "+--------------------------------------+", + ]; + let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?; + + assert_batches_eq!(expected, &batches); + Ok(()) } diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index e1982761f04c..cca23ac6847c 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -39,14 +39,13 @@ use datafusion::prelude::JoinType; use datafusion::prelude::{CsvReadOptions, ParquetReadOptions}; use datafusion::test_util::parquet_test_data; use datafusion::{assert_batches_eq, assert_batches_sorted_eq}; -use datafusion_common::{DataFusionError, ScalarValue, UnnestOptions}; +use datafusion_common::{assert_contains, DataFusionError, ScalarValue, UnnestOptions}; use datafusion_execution::config::SessionConfig; use datafusion_expr::expr::{GroupingSet, Sort}; -use datafusion_expr::Expr::Wildcard; use datafusion_expr::{ array_agg, avg, col, count, exists, expr, in_subquery, lit, max, out_ref_col, - scalar_subquery, sum, AggregateFunction, Expr, ExprSchemable, WindowFrame, - WindowFrameBound, WindowFrameUnits, WindowFunction, + scalar_subquery, sum, wildcard, AggregateFunction, Expr, ExprSchemable, WindowFrame, + WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; use datafusion_physical_expr::var_provider::{VarProvider, VarType}; @@ -64,8 +63,8 @@ async fn test_count_wildcard_on_sort() -> Result<()> { let df_results = ctx .table("t1") .await? - .aggregate(vec![col("b")], vec![count(Wildcard)])? - .sort(vec![count(Wildcard).sort(true, false)])? + .aggregate(vec![col("b")], vec![count(wildcard())])? + .sort(vec![count(wildcard()).sort(true, false)])? .explain(false, false)? .collect() .await?; @@ -99,8 +98,8 @@ async fn test_count_wildcard_on_where_in() -> Result<()> { Arc::new( ctx.table("t2") .await? - .aggregate(vec![], vec![count(Expr::Wildcard)])? - .select(vec![count(Expr::Wildcard)])? + .aggregate(vec![], vec![count(wildcard())])? + .select(vec![count(wildcard())])? .into_unoptimized_plan(), // Usually, into_optimized_plan() should be used here, but due to // https://github.com/apache/arrow-datafusion/issues/5771, @@ -136,8 +135,8 @@ async fn test_count_wildcard_on_where_exist() -> Result<()> { .filter(exists(Arc::new( ctx.table("t2") .await? - .aggregate(vec![], vec![count(Expr::Wildcard)])? - .select(vec![count(Expr::Wildcard)])? + .aggregate(vec![], vec![count(wildcard())])? + .select(vec![count(wildcard())])? .into_unoptimized_plan(), // Usually, into_optimized_plan() should be used here, but due to // https://github.com/apache/arrow-datafusion/issues/5771, @@ -171,8 +170,8 @@ async fn test_count_wildcard_on_window() -> Result<()> { .table("t1") .await? .select(vec![Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Count), - vec![Expr::Wildcard], + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), + vec![wildcard()], vec![], vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], WindowFrame { @@ -202,17 +201,17 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> { let sql_results = ctx .sql("select count(*) from t1") .await? - .select(vec![count(Expr::Wildcard)])? + .select(vec![count(wildcard())])? .explain(false, false)? .collect() .await?; - // add `.select(vec![count(Expr::Wildcard)])?` to make sure we can analyze all node instead of just top node. + // add `.select(vec![count(wildcard())])?` to make sure we can analyze all node instead of just top node. let df_results = ctx .table("t1") .await? - .aggregate(vec![], vec![count(Expr::Wildcard)])? - .select(vec![count(Expr::Wildcard)])? + .aggregate(vec![], vec![count(wildcard())])? + .select(vec![count(wildcard())])? .explain(false, false)? .collect() .await?; @@ -248,8 +247,8 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> { ctx.table("t2") .await? .filter(out_ref_col(DataType::UInt32, "t1.a").eq(col("t2.a")))? - .aggregate(vec![], vec![count(Wildcard)])? - .select(vec![col(count(Wildcard).to_string())])? + .aggregate(vec![], vec![count(wildcard())])? + .select(vec![col(count(wildcard()).to_string())])? .into_unoptimized_plan(), )) .gt(lit(ScalarValue::UInt8(Some(0)))), @@ -1324,6 +1323,113 @@ async fn unnest_array_agg() -> Result<()> { Ok(()) } +#[tokio::test] +async fn unnest_with_redundant_columns() -> Result<()> { + let mut shape_id_builder = UInt32Builder::new(); + let mut tag_id_builder = UInt32Builder::new(); + + for shape_id in 1..=3 { + for tag_id in 1..=3 { + shape_id_builder.append_value(shape_id as u32); + tag_id_builder.append_value((shape_id * 10 + tag_id) as u32); + } + } + + let batch = RecordBatch::try_from_iter(vec![ + ("shape_id", Arc::new(shape_id_builder.finish()) as ArrayRef), + ("tag_id", Arc::new(tag_id_builder.finish()) as ArrayRef), + ])?; + + let ctx = SessionContext::new(); + ctx.register_batch("shapes", batch)?; + let df = ctx.table("shapes").await?; + + let results = df.clone().collect().await?; + let expected = vec![ + "+----------+--------+", + "| shape_id | tag_id |", + "+----------+--------+", + "| 1 | 11 |", + "| 1 | 12 |", + "| 1 | 13 |", + "| 2 | 21 |", + "| 2 | 22 |", + "| 2 | 23 |", + "| 3 | 31 |", + "| 3 | 32 |", + "| 3 | 33 |", + "+----------+--------+", + ]; + assert_batches_sorted_eq!(expected, &results); + + // Doing an `array_agg` by `shape_id` produces: + let df = df + .clone() + .aggregate( + vec![col("shape_id")], + vec![array_agg(col("shape_id")).alias("shape_id2")], + )? + .unnest_column("shape_id2")? + .select(vec![col("shape_id")])?; + + let optimized_plan = df.clone().into_optimized_plan()?; + let expected = vec![ + "Projection: shapes.shape_id [shape_id:UInt32]", + " Unnest: shape_id2 [shape_id:UInt32, shape_id2:UInt32;N]", + " Aggregate: groupBy=[[shapes.shape_id]], aggr=[[ARRAY_AGG(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N]", + " TableScan: shapes projection=[shape_id] [shape_id:UInt32]", + ]; + + let formatted = optimized_plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let results = df.collect().await?; + let expected = [ + "+----------+", + "| shape_id |", + "+----------+", + "| 1 |", + "| 1 |", + "| 1 |", + "| 2 |", + "| 2 |", + "| 2 |", + "| 3 |", + "| 3 |", + "| 3 |", + "+----------+", + ]; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn unnest_analyze_metrics() -> Result<()> { + const NUM_ROWS: usize = 5; + + let df = table_with_nested_types(NUM_ROWS).await?; + let results = df + .unnest_column("tags")? + .explain(false, true)? + .collect() + .await?; + let formatted = arrow::util::pretty::pretty_format_batches(&results) + .unwrap() + .to_string(); + assert_contains!(&formatted, "elapsed_compute="); + assert_contains!(&formatted, "input_batches=1"); + assert_contains!(&formatted, "input_rows=5"); + assert_contains!(&formatted, "output_rows=10"); + assert_contains!(&formatted, "output_batches=1"); + + Ok(()) +} + async fn create_test_table(name: &str) -> Result { let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Utf8, false), @@ -1567,7 +1673,7 @@ async fn use_var_provider() -> Result<()> { let config = SessionConfig::new() .with_target_partitions(4) .set_bool("datafusion.optimizer.skip_failed_rules", false); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); ctx.register_table("csv_table", mem_table)?; ctx.register_variable(VarType::UserDefined, Arc::new(HardcodedIntProvider {})); diff --git a/datafusion/core/tests/fifo.rs b/datafusion/core/tests/fifo.rs index 754389a61433..93c7f7368065 100644 --- a/datafusion/core/tests/fifo.rs +++ b/datafusion/core/tests/fifo.rs @@ -17,42 +17,48 @@ //! This test demonstrates the DataFusion FIFO capabilities. //! -#[cfg(not(target_os = "windows"))] +#[cfg(target_family = "unix")] #[cfg(test)] mod unix_test { - use arrow::array::Array; - use arrow::csv::ReaderBuilder; - use arrow::datatypes::{DataType, Field, Schema}; - use datafusion::test_util::register_unbounded_file_with_ordering; - use datafusion::{ - prelude::{CsvReadOptions, SessionConfig, SessionContext}, - test_util::{aggr_test_schema, arrow_test_data}, - }; - use datafusion_common::{exec_err, DataFusionError, Result}; - use futures::StreamExt; - use itertools::enumerate; - use nix::sys::stat; - use nix::unistd; - use rstest::*; use std::fs::{File, OpenOptions}; use std::io::Write; use std::path::PathBuf; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::thread; - use std::thread::JoinHandle; use std::time::{Duration, Instant}; + + use arrow::array::Array; + use arrow::csv::ReaderBuilder; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_schema::SchemaRef; + use futures::StreamExt; + use nix::sys::stat; + use nix::unistd; use tempfile::TempDir; + use tokio::task::{spawn_blocking, JoinHandle}; - // ! For the sake of the test, do not alter the numbers. ! - // Session batch size - const TEST_BATCH_SIZE: usize = 20; - // Number of lines written to FIFO - const TEST_DATA_SIZE: usize = 20_000; - // Number of lines what can be joined. Each joinable key produced 20 lines with - // aggregate_test_100 dataset. We will use these joinable keys for understanding - // incremental execution. - const TEST_JOIN_RATIO: f64 = 0.01; + use datafusion::datasource::stream::{StreamConfig, StreamTable}; + use datafusion::datasource::TableProvider; + use datafusion::{ + prelude::{CsvReadOptions, SessionConfig, SessionContext}, + test_util::{aggr_test_schema, arrow_test_data}, + }; + use datafusion_common::{exec_err, DataFusionError, Result}; + use datafusion_expr::Expr; + + /// Makes a TableProvider for a fifo file + fn fifo_table( + schema: SchemaRef, + path: impl Into, + sort: Vec>, + ) -> Arc { + let config = StreamConfig::new_file(schema, path.into()) + .with_order(sort) + .with_batch_size(TEST_BATCH_SIZE) + .with_header(true); + Arc::new(StreamTable::new(Arc::new(config))) + } fn create_fifo_file(tmp_dir: &TempDir, file_name: &str) -> Result { let file_path = tmp_dir.path().join(file_name); @@ -86,26 +92,57 @@ mod unix_test { Ok(()) } + fn create_writing_thread( + file_path: PathBuf, + header: String, + lines: Vec, + waiting_lock: Arc, + wait_until: usize, + ) -> JoinHandle<()> { + // Timeout for a long period of BrokenPipe error + let broken_pipe_timeout = Duration::from_secs(10); + let sa = file_path.clone(); + // Spawn a new thread to write to the FIFO file + spawn_blocking(move || { + let file = OpenOptions::new().write(true).open(sa).unwrap(); + // Reference time to use when deciding to fail the test + let execution_start = Instant::now(); + write_to_fifo(&file, &header, execution_start, broken_pipe_timeout).unwrap(); + for (cnt, line) in lines.iter().enumerate() { + while waiting_lock.load(Ordering::SeqCst) && cnt > wait_until { + thread::sleep(Duration::from_millis(50)); + } + write_to_fifo(&file, line, execution_start, broken_pipe_timeout).unwrap(); + } + drop(file); + }) + } + + // ! For the sake of the test, do not alter the numbers. ! + // Session batch size + const TEST_BATCH_SIZE: usize = 20; + // Number of lines written to FIFO + const TEST_DATA_SIZE: usize = 20_000; + // Number of lines what can be joined. Each joinable key produced 20 lines with + // aggregate_test_100 dataset. We will use these joinable keys for understanding + // incremental execution. + const TEST_JOIN_RATIO: f64 = 0.01; + // This test provides a relatively realistic end-to-end scenario where // we swap join sides to accommodate a FIFO source. - #[rstest] - #[timeout(std::time::Duration::from_secs(30))] #[tokio::test(flavor = "multi_thread", worker_threads = 8)] - async fn unbounded_file_with_swapped_join( - #[values(true, false)] unbounded_file: bool, - ) -> Result<()> { + async fn unbounded_file_with_swapped_join() -> Result<()> { // Create session context let config = SessionConfig::new() .with_batch_size(TEST_BATCH_SIZE) .with_collect_statistics(false) .with_target_partitions(1); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); // To make unbounded deterministic - let waiting = Arc::new(AtomicBool::new(unbounded_file)); + let waiting = Arc::new(AtomicBool::new(true)); // Create a new temporary FIFO file let tmp_dir = TempDir::new()?; - let fifo_path = - create_fifo_file(&tmp_dir, &format!("fifo_{unbounded_file:?}.csv"))?; + let fifo_path = create_fifo_file(&tmp_dir, "fifo_unbounded.csv")?; // Execution can calculated at least one RecordBatch after the number of // "joinable_lines_length" lines are read. let joinable_lines_length = @@ -129,7 +166,7 @@ mod unix_test { "a1,a2\n".to_owned(), lines, waiting.clone(), - joinable_lines_length, + joinable_lines_length * 2, ); // Data Schema @@ -137,15 +174,10 @@ mod unix_test { Field::new("a1", DataType::Utf8, false), Field::new("a2", DataType::UInt32, false), ])); - // Create a file with bounded or unbounded flag. - ctx.register_csv( - "left", - fifo_path.as_os_str().to_str().unwrap(), - CsvReadOptions::new() - .schema(schema.as_ref()) - .mark_infinite(unbounded_file), - ) - .await?; + + let provider = fifo_table(schema, fifo_path, vec![]); + ctx.register_table("left", provider).unwrap(); + // Register right table let schema = aggr_test_schema(); let test_data = arrow_test_data(); @@ -161,7 +193,7 @@ mod unix_test { while (stream.next().await).is_some() { waiting.store(false, Ordering::SeqCst); } - task.join().unwrap(); + task.await.unwrap(); Ok(()) } @@ -172,46 +204,17 @@ mod unix_test { Equal, } - fn create_writing_thread( - file_path: PathBuf, - header: String, - lines: Vec, - waiting_lock: Arc, - wait_until: usize, - ) -> JoinHandle<()> { - // Timeout for a long period of BrokenPipe error - let broken_pipe_timeout = Duration::from_secs(10); - // Spawn a new thread to write to the FIFO file - thread::spawn(move || { - let file = OpenOptions::new().write(true).open(file_path).unwrap(); - // Reference time to use when deciding to fail the test - let execution_start = Instant::now(); - write_to_fifo(&file, &header, execution_start, broken_pipe_timeout).unwrap(); - for (cnt, line) in enumerate(lines) { - while waiting_lock.load(Ordering::SeqCst) && cnt > wait_until { - thread::sleep(Duration::from_millis(50)); - } - write_to_fifo(&file, &line, execution_start, broken_pipe_timeout) - .unwrap(); - } - drop(file); - }) - } - // This test provides a relatively realistic end-to-end scenario where // we change the join into a [SymmetricHashJoin] to accommodate two // unbounded (FIFO) sources. - #[rstest] - #[timeout(std::time::Duration::from_secs(30))] - #[tokio::test(flavor = "multi_thread")] - #[ignore] + #[tokio::test] async fn unbounded_file_with_symmetric_join() -> Result<()> { // Create session context let config = SessionConfig::new() .with_batch_size(TEST_BATCH_SIZE) .set_bool("datafusion.execution.coalesce_batches", false) .with_target_partitions(1); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); // Tasks let mut tasks: Vec> = vec![]; @@ -254,47 +257,30 @@ mod unix_test { Field::new("a1", DataType::UInt32, false), Field::new("a2", DataType::UInt32, false), ])); + // Specify the ordering: - let file_sort_order = vec![[datafusion_expr::col("a1")] - .into_iter() - .map(|e| { - let ascending = true; - let nulls_first = false; - e.sort(ascending, nulls_first) - }) - .collect::>()]; + let order = vec![vec![datafusion_expr::col("a1").sort(true, false)]]; + // Set unbounded sorted files read configuration - register_unbounded_file_with_ordering( - &ctx, - schema.clone(), - &left_fifo, - "left", - file_sort_order.clone(), - true, - ) - .await?; - register_unbounded_file_with_ordering( - &ctx, - schema, - &right_fifo, - "right", - file_sort_order, - true, - ) - .await?; + let provider = fifo_table(schema.clone(), left_fifo, order.clone()); + ctx.register_table("left", provider)?; + + let provider = fifo_table(schema.clone(), right_fifo, order); + ctx.register_table("right", provider)?; + // Execute the query, with no matching rows. (since key is modulus 10) let df = ctx .sql( "SELECT - t1.a1, - t1.a2, - t2.a1, - t2.a2 - FROM - left as t1 FULL - JOIN right as t2 ON t1.a2 = t2.a2 - AND t1.a1 > t2.a1 + 4 - AND t1.a1 < t2.a1 + 9", + t1.a1, + t1.a2, + t2.a1, + t2.a2 + FROM + left as t1 FULL + JOIN right as t2 ON t1.a2 = t2.a2 + AND t1.a1 > t2.a1 + 4 + AND t1.a1 < t2.a1 + 9", ) .await?; let mut stream = df.execute_stream().await?; @@ -313,7 +299,8 @@ mod unix_test { }; operations.push(op); } - tasks.into_iter().for_each(|jh| jh.join().unwrap()); + futures::future::try_join_all(tasks).await.unwrap(); + // The SymmetricHashJoin executor produces FULL join results at every // pruning, which happens before it reaches the end of input and more // than once. In this test, we feed partially joinable data to both @@ -342,7 +329,7 @@ mod unix_test { let waiting_thread = waiting.clone(); // create local execution context let config = SessionConfig::new().with_batch_size(TEST_BATCH_SIZE); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); // Create a new temporary FIFO file let tmp_dir = TempDir::new()?; let source_fifo_path = create_fifo_file(&tmp_dir, "source.csv")?; @@ -368,8 +355,9 @@ mod unix_test { // Prevent move let (sink_fifo_path_thread, sink_display_fifo_path) = (sink_fifo_path.clone(), sink_fifo_path.display()); + // Spawn a new thread to read sink EXTERNAL TABLE. - tasks.push(thread::spawn(move || { + tasks.push(spawn_blocking(move || { let file = File::open(sink_fifo_path_thread).unwrap(); let schema = Arc::new(Schema::new(vec![ Field::new("a1", DataType::Utf8, false), @@ -377,7 +365,6 @@ mod unix_test { ])); let mut reader = ReaderBuilder::new(schema) - .has_header(true) .with_batch_size(TEST_BATCH_SIZE) .build(file) .map_err(|e| DataFusionError::Internal(e.to_string())) @@ -389,38 +376,35 @@ mod unix_test { })); // register second csv file with the SQL (create an empty file if not found) ctx.sql(&format!( - "CREATE EXTERNAL TABLE source_table ( + "CREATE UNBOUNDED EXTERNAL TABLE source_table ( a1 VARCHAR NOT NULL, a2 INT NOT NULL ) STORED AS CSV WITH HEADER ROW - OPTIONS ('UNBOUNDED' 'TRUE') LOCATION '{source_display_fifo_path}'" )) .await?; // register csv file with the SQL ctx.sql(&format!( - "CREATE EXTERNAL TABLE sink_table ( + "CREATE UNBOUNDED EXTERNAL TABLE sink_table ( a1 VARCHAR NOT NULL, a2 INT NOT NULL ) STORED AS CSV WITH HEADER ROW - OPTIONS ('UNBOUNDED' 'TRUE') LOCATION '{sink_display_fifo_path}'" )) .await?; let df = ctx - .sql( - "INSERT INTO sink_table - SELECT a1, a2 FROM source_table", - ) + .sql("INSERT INTO sink_table SELECT a1, a2 FROM source_table") .await?; + + // Start execution df.collect().await?; - tasks.into_iter().for_each(|jh| jh.join().unwrap()); + futures::future::try_join_all(tasks).await.unwrap(); Ok(()) } } diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index a0e9a50a22ae..9069dbbd5850 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -35,38 +35,33 @@ use datafusion_physical_expr::expressions::{col, Sum}; use datafusion_physical_expr::{AggregateExpr, PhysicalSortExpr}; use test_utils::add_empty_batches; -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test(flavor = "multi_thread", worker_threads = 8)] - async fn aggregate_test() { - let test_cases = vec![ - vec!["a"], - vec!["b", "a"], - vec!["c", "a"], - vec!["c", "b", "a"], - vec!["d", "a"], - vec!["d", "b", "a"], - vec!["d", "c", "a"], - vec!["d", "c", "b", "a"], - ]; - let n = 300; - let distincts = vec![10, 20]; - for distinct in distincts { - let mut handles = Vec::new(); - for i in 0..n { - let test_idx = i % test_cases.len(); - let group_by_columns = test_cases[test_idx].clone(); - let job = tokio::spawn(run_aggregate_test( - make_staggered_batches::(1000, distinct, i as u64), - group_by_columns, - )); - handles.push(job); - } - for job in handles { - job.await.unwrap(); - } +#[tokio::test(flavor = "multi_thread", worker_threads = 8)] +async fn aggregate_test() { + let test_cases = vec![ + vec!["a"], + vec!["b", "a"], + vec!["c", "a"], + vec!["c", "b", "a"], + vec!["d", "a"], + vec!["d", "b", "a"], + vec!["d", "c", "a"], + vec!["d", "c", "b", "a"], + ]; + let n = 300; + let distincts = vec![10, 20]; + for distinct in distincts { + let mut handles = Vec::new(); + for i in 0..n { + let test_idx = i % test_cases.len(); + let group_by_columns = test_cases[test_idx].clone(); + let job = tokio::spawn(run_aggregate_test( + make_staggered_batches::(1000, distinct, i as u64), + group_by_columns, + )); + handles.push(job); + } + for job in handles { + job.await.unwrap(); } } } @@ -77,7 +72,7 @@ mod tests { async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str>) { let schema = input1[0].schema(); let session_config = SessionConfig::new().with_batch_size(50); - let ctx = SessionContext::with_config(session_config); + let ctx = SessionContext::new_with_config(session_config); let mut sort_keys = vec![]; for ordering_col in ["a", "b", "c"] { sort_keys.push(PhysicalSortExpr { @@ -114,7 +109,6 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str group_by.clone(), aggregate_expr.clone(), vec![None], - vec![None], running_source, schema.clone(), ) @@ -127,7 +121,6 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str group_by.clone(), aggregate_expr.clone(), vec![None], - vec![None], usual_source, schema.clone(), ) diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index 9b741440ff13..ac86364f4255 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -102,7 +102,7 @@ async fn run_join_test( let batch_sizes = [1, 2, 7, 49, 50, 51, 100]; for batch_size in batch_sizes { let session_config = SessionConfig::new().with_batch_size(batch_size); - let ctx = SessionContext::with_config(session_config); + let ctx = SessionContext::new_with_config(session_config); let task_ctx = ctx.task_ctx(); let schema1 = input1[0].schema(); diff --git a/datafusion/core/tests/fuzz_cases/limit_fuzz.rs b/datafusion/core/tests/fuzz_cases/limit_fuzz.rs new file mode 100644 index 000000000000..9889ce2ae562 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/limit_fuzz.rs @@ -0,0 +1,349 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Fuzz Test for Sort + Fetch/Limit (TopK!) + +use arrow::compute::concat_batches; +use arrow::util::pretty::pretty_format_batches; +use arrow::{array::Int32Array, record_batch::RecordBatch}; +use arrow_array::{Float64Array, Int64Array, StringArray}; +use arrow_schema::SchemaRef; +use datafusion::datasource::MemTable; +use datafusion::prelude::SessionContext; +use datafusion_common::assert_contains; +use rand::{thread_rng, Rng}; +use std::sync::Arc; +use test_utils::stagger_batch; + +#[tokio::test] +async fn test_sort_topk_i32() { + run_limit_fuzz_test(SortedData::new_i32).await +} + +#[tokio::test] +async fn test_sort_topk_f64() { + run_limit_fuzz_test(SortedData::new_f64).await +} + +#[tokio::test] +async fn test_sort_topk_str() { + run_limit_fuzz_test(SortedData::new_str).await +} + +#[tokio::test] +async fn test_sort_topk_i64str() { + run_limit_fuzz_test(SortedData::new_i64str).await +} + +/// Run TopK fuzz tests the specified input data with different +/// different test functions so they can run in parallel) +async fn run_limit_fuzz_test(make_data: F) +where + F: Fn(usize) -> SortedData, +{ + let mut rng = thread_rng(); + for size in [10, 1_0000, 10_000, 100_000] { + let data = make_data(size); + // test various limits including some random ones + for limit in [1, 3, 7, 17, 10000, rng.gen_range(1..size * 2)] { + // limit can be larger than the number of rows in the input + run_limit_test(limit, &data).await; + } + } +} + +/// The data column(s) to use for the TopK test +/// +/// Each variants stores the input batches and the expected sorted values +/// compute the expected output for a given fetch (limit) value. +#[derive(Debug)] +enum SortedData { + // single Int32 column + I32 { + batches: Vec, + sorted: Vec>, + }, + /// Single Float64 column + F64 { + batches: Vec, + sorted: Vec>, + }, + /// Single sorted String column + Str { + batches: Vec, + sorted: Vec>, + }, + /// (i64, string) columns + I64Str { + batches: Vec, + sorted: Vec<(Option, Option)>, + }, +} + +impl SortedData { + /// Create an i32 column of random values, with the specified number of + /// rows, sorted the default + fn new_i32(size: usize) -> Self { + let mut rng = thread_rng(); + // have some repeats (approximately 1/3 of the values are the same) + let max = size as i32 / 3; + let data: Vec> = (0..size) + .map(|_| { + // no nulls for now + Some(rng.gen_range(0..max)) + }) + .collect(); + + let batches = stagger_batch(int32_batch(data.iter().cloned())); + + let mut sorted = data; + sorted.sort_unstable(); + + Self::I32 { batches, sorted } + } + + /// Create an f64 column of random values, with the specified number of + /// rows, sorted the default + fn new_f64(size: usize) -> Self { + let mut rng = thread_rng(); + let mut data: Vec> = (0..size / 3) + .map(|_| { + // no nulls for now + Some(rng.gen_range(0.0..1.0f64)) + }) + .collect(); + + // have some repeats (approximately 1/3 of the values are the same) + while data.len() < size { + data.push(data[rng.gen_range(0..data.len())]); + } + + let batches = stagger_batch(f64_batch(data.iter().cloned())); + + let mut sorted = data; + sorted.sort_by(|a, b| a.partial_cmp(b).unwrap()); + + Self::F64 { batches, sorted } + } + + /// Create an string column of random values, with the specified number of + /// rows, sorted the default + fn new_str(size: usize) -> Self { + let mut rng = thread_rng(); + let mut data: Vec> = (0..size / 3) + .map(|_| { + // no nulls for now + Some(get_random_string(16)) + }) + .collect(); + + // have some repeats (approximately 1/3 of the values are the same) + while data.len() < size { + data.push(data[rng.gen_range(0..data.len())].clone()); + } + + let batches = stagger_batch(string_batch(data.iter())); + + let mut sorted = data; + sorted.sort_unstable(); + + Self::Str { batches, sorted } + } + + /// Create two columns of random values (int64, string), with the specified number of + /// rows, sorted the default + fn new_i64str(size: usize) -> Self { + let mut rng = thread_rng(); + + // 100 distinct values + let strings: Vec> = (0..100) + .map(|_| { + // no nulls for now + Some(get_random_string(16)) + }) + .collect(); + + // form inputs, with only 10 distinct integer values , to force collision checks + let data = (0..size) + .map(|_| { + ( + Some(rng.gen_range(0..10)), + strings[rng.gen_range(0..strings.len())].clone(), + ) + }) + .collect::>(); + + let batches = stagger_batch(i64string_batch(data.iter())); + + let mut sorted = data; + sorted.sort_unstable(); + + Self::I64Str { batches, sorted } + } + + /// Return top top `limit` values as a RecordBatch + fn topk_values(&self, limit: usize) -> RecordBatch { + match self { + Self::I32 { sorted, .. } => int32_batch(sorted.iter().take(limit).cloned()), + Self::F64 { sorted, .. } => f64_batch(sorted.iter().take(limit).cloned()), + Self::Str { sorted, .. } => string_batch(sorted.iter().take(limit)), + Self::I64Str { sorted, .. } => i64string_batch(sorted.iter().take(limit)), + } + } + + /// Return the input data to sort + fn batches(&self) -> Vec { + match self { + Self::I32 { batches, .. } => batches.clone(), + Self::F64 { batches, .. } => batches.clone(), + Self::Str { batches, .. } => batches.clone(), + Self::I64Str { batches, .. } => batches.clone(), + } + } + + /// Return the schema of the input data + fn schema(&self) -> SchemaRef { + match self { + Self::I32 { batches, .. } => batches[0].schema(), + Self::F64 { batches, .. } => batches[0].schema(), + Self::Str { batches, .. } => batches[0].schema(), + Self::I64Str { batches, .. } => batches[0].schema(), + } + } + + /// Return the sort expression to use for this data, depending on the type + fn sort_expr(&self) -> Vec { + match self { + Self::I32 { .. } | Self::F64 { .. } | Self::Str { .. } => { + vec![datafusion_expr::col("x").sort(true, true)] + } + Self::I64Str { .. } => { + vec![ + datafusion_expr::col("x").sort(true, true), + datafusion_expr::col("y").sort(true, true), + ] + } + } + } +} + +/// Create a record batch with a single column of type `Int32` named "x" +fn int32_batch(values: impl IntoIterator>) -> RecordBatch { + RecordBatch::try_from_iter(vec![( + "x", + Arc::new(Int32Array::from_iter(values.into_iter())) as _, + )]) + .unwrap() +} + +/// Create a record batch with a single column of type `Float64` named "x" +fn f64_batch(values: impl IntoIterator>) -> RecordBatch { + RecordBatch::try_from_iter(vec![( + "x", + Arc::new(Float64Array::from_iter(values.into_iter())) as _, + )]) + .unwrap() +} + +/// Create a record batch with a single column of type `StringArray` named "x" +fn string_batch<'a>(values: impl IntoIterator>) -> RecordBatch { + RecordBatch::try_from_iter(vec![( + "x", + Arc::new(StringArray::from_iter(values.into_iter())) as _, + )]) + .unwrap() +} + +/// Create a record batch with i64 column "x" and utf8 column "y" +fn i64string_batch<'a>( + values: impl IntoIterator, Option)> + Clone, +) -> RecordBatch { + let ints = values.clone().into_iter().map(|(i, _)| *i); + let strings = values.into_iter().map(|(_, s)| s); + RecordBatch::try_from_iter(vec![ + ("x", Arc::new(Int64Array::from_iter(ints)) as _), + ("y", Arc::new(StringArray::from_iter(strings)) as _), + ]) + .unwrap() +} + +/// Run the TopK test, sorting the input batches with the specified ftch +/// (limit) and compares the results to the expected values. +async fn run_limit_test(fetch: usize, data: &SortedData) { + let input = data.batches(); + let schema = data.schema(); + + let table = MemTable::try_new(schema, vec![input]).unwrap(); + + let ctx = SessionContext::new(); + let df = ctx + .read_table(Arc::new(table)) + .unwrap() + .sort(data.sort_expr()) + .unwrap() + .limit(0, Some(fetch)) + .unwrap(); + + // Verify the plan contains a TopK node + { + let explain = df + .clone() + .explain(false, false) + .unwrap() + .collect() + .await + .unwrap(); + let plan_text = pretty_format_batches(&explain).unwrap().to_string(); + let expected = format!("TopK(fetch={fetch})"); + assert_contains!(plan_text, expected); + } + + let results = df.collect().await.unwrap(); + let expected = data.topk_values(fetch); + + // Verify that all output batches conform to the specified batch size + let max_batch_size = ctx.copied_config().batch_size(); + for batch in &results { + assert!(batch.num_rows() <= max_batch_size); + } + + let results = concat_batches(&results[0].schema(), &results).unwrap(); + + let results = [results]; + let expected = [expected]; + + assert_eq!( + &expected, + &results, + "TopK mismatch fetch {fetch} \n\ + expected rows {}, actual rows {}.\ + \n\nExpected:\n{}\n\nActual:\n{}", + expected[0].num_rows(), + results[0].num_rows(), + pretty_format_batches(&expected).unwrap(), + pretty_format_batches(&results).unwrap(), + ); +} + +/// Return random ASCII String with len +fn get_random_string(len: usize) -> String { + rand::thread_rng() + .sample_iter(rand::distributions::Alphanumeric) + .take(len) + .map(char::from) + .collect() +} diff --git a/datafusion/core/tests/fuzz_cases/merge_fuzz.rs b/datafusion/core/tests/fuzz_cases/merge_fuzz.rs index 6411f31be0ce..c38ff41f5783 100644 --- a/datafusion/core/tests/fuzz_cases/merge_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/merge_fuzz.rs @@ -118,7 +118,7 @@ async fn run_merge_test(input: Vec>) { let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); let session_config = SessionConfig::new().with_batch_size(batch_size); - let ctx = SessionContext::with_config(session_config); + let ctx = SessionContext::new_with_config(session_config); let task_ctx = ctx.task_ctx(); let collected = collect(merge, task_ctx).await.unwrap(); diff --git a/datafusion/core/tests/fuzz_cases/mod.rs b/datafusion/core/tests/fuzz_cases/mod.rs index 140cf7e5c75b..83ec928ae229 100644 --- a/datafusion/core/tests/fuzz_cases/mod.rs +++ b/datafusion/core/tests/fuzz_cases/mod.rs @@ -19,5 +19,7 @@ mod aggregate_fuzz; mod join_fuzz; mod merge_fuzz; mod sort_fuzz; + +mod limit_fuzz; mod sort_preserving_repartition_fuzz; mod window_fuzz; diff --git a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs index 6c427c7fb7b3..f4b4f16aa160 100644 --- a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs @@ -23,90 +23,55 @@ use arrow::{ record_batch::RecordBatch, }; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; -use datafusion::physical_plan::expressions::{col, PhysicalSortExpr}; +use datafusion::physical_plan::expressions::PhysicalSortExpr; use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::{collect, ExecutionPlan}; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_execution::memory_pool::GreedyMemoryPool; +use datafusion_physical_expr::expressions::col; use rand::Rng; use std::sync::Arc; use test_utils::{batches_to_vec, partitions_to_sorted_vec}; +const KB: usize = 1 << 10; #[tokio::test] #[cfg_attr(tarpaulin, ignore)] async fn test_sort_1k_mem() { - SortTest::new() - .with_int32_batches(5) - .with_pool_size(10240) - .with_should_spill(false) - .run() - .await; - - SortTest::new() - .with_int32_batches(20000) - .with_pool_size(10240) - .with_should_spill(true) - .run() - .await; - - SortTest::new() - .with_int32_batches(1000000) - .with_pool_size(10240) - .with_should_spill(true) - .run() - .await; + for (batch_size, should_spill) in [(5, false), (20000, true), (1000000, true)] { + SortTest::new() + .with_int32_batches(batch_size) + .with_pool_size(10 * KB) + .with_should_spill(should_spill) + .run() + .await; + } } #[tokio::test] #[cfg_attr(tarpaulin, ignore)] async fn test_sort_100k_mem() { - SortTest::new() - .with_int32_batches(5) - .with_pool_size(102400) - .with_should_spill(false) - .run() - .await; - - SortTest::new() - .with_int32_batches(20000) - .with_pool_size(102400) - .with_should_spill(false) - .run() - .await; - - SortTest::new() - .with_int32_batches(1000000) - .with_pool_size(102400) - .with_should_spill(true) - .run() - .await; + for (batch_size, should_spill) in [(5, false), (20000, false), (1000000, true)] { + SortTest::new() + .with_int32_batches(batch_size) + .with_pool_size(100 * KB) + .with_should_spill(should_spill) + .run() + .await; + } } #[tokio::test] async fn test_sort_unlimited_mem() { - SortTest::new() - .with_int32_batches(5) - .with_pool_size(usize::MAX) - .with_should_spill(false) - .run() - .await; - - SortTest::new() - .with_int32_batches(20000) - .with_pool_size(usize::MAX) - .with_should_spill(false) - .run() - .await; - - SortTest::new() - .with_int32_batches(1000000) - .with_pool_size(usize::MAX) - .with_should_spill(false) - .run() - .await; + for (batch_size, should_spill) in [(5, false), (20000, false), (1000000, false)] { + SortTest::new() + .with_int32_batches(batch_size) + .with_pool_size(usize::MAX) + .with_should_spill(should_spill) + .run() + .await; + } } - #[derive(Debug, Default)] struct SortTest { input: Vec>, @@ -174,9 +139,9 @@ impl SortTest { let runtime_config = RuntimeConfig::new() .with_memory_pool(Arc::new(GreedyMemoryPool::new(pool_size))); let runtime = Arc::new(RuntimeEnv::new(runtime_config).unwrap()); - SessionContext::with_config_rt(session_config, runtime) + SessionContext::new_with_config_rt(session_config, runtime) } else { - SessionContext::with_config(session_config) + SessionContext::new_with_config(session_config) }; let task_ctx = session_ctx.task_ctx(); diff --git a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs index 6304e01c6389..df6499e9b1e4 100644 --- a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs @@ -17,22 +17,273 @@ #[cfg(test)] mod sp_repartition_fuzz_tests { - use arrow::compute::concat_batches; - use arrow_array::{ArrayRef, Int64Array, RecordBatch}; - use arrow_schema::SortOptions; - use datafusion::physical_plan::memory::MemoryExec; - use datafusion::physical_plan::repartition::RepartitionExec; - use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; - use datafusion::physical_plan::{collect, ExecutionPlan, Partitioning}; - use datafusion::prelude::SessionContext; - use datafusion_execution::config::SessionConfig; - use datafusion_physical_expr::expressions::col; - use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; - use rand::rngs::StdRng; - use rand::{Rng, SeedableRng}; use std::sync::Arc; + + use arrow::compute::{concat_batches, lexsort, SortColumn}; + use arrow_array::{ArrayRef, Int64Array, RecordBatch, UInt64Array}; + use arrow_schema::{DataType, Field, Schema, SchemaRef, SortOptions}; + + use datafusion::physical_plan::{ + collect, + memory::MemoryExec, + metrics::{BaselineMetrics, ExecutionPlanMetricsSet}, + repartition::RepartitionExec, + sorts::sort_preserving_merge::SortPreservingMergeExec, + sorts::streaming_merge::streaming_merge, + stream::RecordBatchStreamAdapter, + ExecutionPlan, Partitioning, + }; + use datafusion::prelude::SessionContext; + use datafusion_common::Result; + use datafusion_execution::{ + config::SessionConfig, memory_pool::MemoryConsumer, SendableRecordBatchStream, + }; + use datafusion_physical_expr::{ + expressions::{col, Column}, + EquivalenceProperties, PhysicalExpr, PhysicalSortExpr, + }; use test_utils::add_empty_batches; + use datafusion_physical_expr::equivalence::EquivalenceClass; + use itertools::izip; + use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng}; + + // Generate a schema which consists of 6 columns (a, b, c, d, e, f) + fn create_test_schema() -> Result { + let a = Field::new("a", DataType::Int32, true); + let b = Field::new("b", DataType::Int32, true); + let c = Field::new("c", DataType::Int32, true); + let d = Field::new("d", DataType::Int32, true); + let e = Field::new("e", DataType::Int32, true); + let f = Field::new("f", DataType::Int32, true); + let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f])); + + Ok(schema) + } + + /// Construct a schema with random ordering + /// among column a, b, c, d + /// where + /// Column [a=f] (e.g they are aliases). + /// Column e is constant. + fn create_random_schema(seed: u64) -> Result<(SchemaRef, EquivalenceProperties)> { + let test_schema = create_test_schema()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let col_exprs = [col_a, col_b, col_c, col_d, col_e, col_f]; + + let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); + // Define a and f are aliases + eq_properties.add_equal_conditions(col_a, col_f); + // Column e has constant value. + eq_properties = eq_properties.add_constants([col_e.clone()]); + + // Randomly order columns for sorting + let mut rng = StdRng::seed_from_u64(seed); + let mut remaining_exprs = col_exprs[0..4].to_vec(); // only a, b, c, d are sorted + + let options_asc = SortOptions { + descending: false, + nulls_first: false, + }; + + while !remaining_exprs.is_empty() { + let n_sort_expr = rng.gen_range(0..remaining_exprs.len() + 1); + remaining_exprs.shuffle(&mut rng); + + let ordering = remaining_exprs + .drain(0..n_sort_expr) + .map(|expr| PhysicalSortExpr { + expr: expr.clone(), + options: options_asc, + }) + .collect(); + + eq_properties.add_new_orderings([ordering]); + } + + Ok((test_schema, eq_properties)) + } + + // If we already generated a random result for one of the + // expressions in the equivalence classes. For other expressions in the same + // equivalence class use same result. This util gets already calculated result, when available. + fn get_representative_arr( + eq_group: &EquivalenceClass, + existing_vec: &[Option], + schema: SchemaRef, + ) -> Option { + for expr in eq_group.iter() { + let col = expr.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + if let Some(res) = &existing_vec[idx] { + return Some(res.clone()); + } + } + None + } + + // Generate a table that satisfies the given equivalence properties; i.e. + // equivalences, ordering equivalences, and constants. + fn generate_table_for_eq_properties( + eq_properties: &EquivalenceProperties, + n_elem: usize, + n_distinct: usize, + ) -> Result { + let mut rng = StdRng::seed_from_u64(23); + + let schema = eq_properties.schema(); + let mut schema_vec = vec![None; schema.fields.len()]; + + // Utility closure to generate random array + let mut generate_random_array = |num_elems: usize, max_val: usize| -> ArrayRef { + let values: Vec = (0..num_elems) + .map(|_| rng.gen_range(0..max_val) as u64) + .collect(); + Arc::new(UInt64Array::from_iter_values(values)) + }; + + // Fill constant columns + for constant in eq_properties.constants() { + let col = constant.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + let arr = + Arc::new(UInt64Array::from_iter_values(vec![0; n_elem])) as ArrayRef; + schema_vec[idx] = Some(arr); + } + + // Fill columns based on ordering equivalences + for ordering in eq_properties.oeq_class().iter() { + let (sort_columns, indices): (Vec<_>, Vec<_>) = ordering + .iter() + .map(|PhysicalSortExpr { expr, options }| { + let col = expr.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + let arr = generate_random_array(n_elem, n_distinct); + ( + SortColumn { + values: arr, + options: Some(*options), + }, + idx, + ) + }) + .unzip(); + + let sort_arrs = arrow::compute::lexsort(&sort_columns, None)?; + for (idx, arr) in izip!(indices, sort_arrs) { + schema_vec[idx] = Some(arr); + } + } + + // Fill columns based on equivalence groups + for eq_group in eq_properties.eq_group().iter() { + let representative_array = + get_representative_arr(eq_group, &schema_vec, schema.clone()) + .unwrap_or_else(|| generate_random_array(n_elem, n_distinct)); + + for expr in eq_group.iter() { + let col = expr.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + schema_vec[idx] = Some(representative_array.clone()); + } + } + + let res: Vec<_> = schema_vec + .into_iter() + .zip(schema.fields.iter()) + .map(|(elem, field)| { + ( + field.name(), + // Generate random values for columns that do not occur in any of the groups (equivalence, ordering equivalence, constants) + elem.unwrap_or_else(|| generate_random_array(n_elem, n_distinct)), + ) + }) + .collect(); + + Ok(RecordBatch::try_from_iter(res)?) + } + + // This test checks for whether during sort preserving merge we can preserve all of the valid orderings + // successfully. If at the input we have orderings [a ASC, b ASC], [c ASC, d ASC] + // After sort preserving merge orderings [a ASC, b ASC], [c ASC, d ASC] should still be valid. + #[tokio::test] + async fn stream_merge_multi_order_preserve() -> Result<()> { + const N_PARTITION: usize = 8; + const N_ELEM: usize = 25; + const N_DISTINCT: usize = 5; + const N_DIFF_SCHEMA: usize = 20; + + use datafusion::physical_plan::common::collect; + for seed in 0..N_DIFF_SCHEMA { + // Create a schema with random equivalence properties + let (_test_schema, eq_properties) = create_random_schema(seed as u64)?; + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEM, N_DISTINCT)?; + let schema = table_data_with_properties.schema(); + let streams: Vec = (0..N_PARTITION) + .map(|_idx| { + let batch = table_data_with_properties.clone(); + Box::pin(RecordBatchStreamAdapter::new( + schema.clone(), + futures::stream::once(async { Ok(batch) }), + )) as SendableRecordBatchStream + }) + .collect::>(); + + // Returns concatenated version of the all available orderings + let exprs = eq_properties + .oeq_class() + .output_ordering() + .unwrap_or_default(); + + let context = SessionContext::new().task_ctx(); + let mem_reservation = + MemoryConsumer::new("test".to_string()).register(context.memory_pool()); + + // Internally SortPreservingMergeExec uses this function for merging. + let res = streaming_merge( + streams, + schema, + &exprs, + BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0), + 1, + None, + mem_reservation, + )?; + let res = collect(res).await?; + // Contains the merged result. + let res = concat_batches(&res[0].schema(), &res)?; + + for ordering in eq_properties.oeq_class().iter() { + let err_msg = format!("error in eq properties: {:?}", eq_properties); + let sort_solumns = ordering + .iter() + .map(|sort_expr| sort_expr.evaluate_to_sort_column(&res)) + .collect::>>()?; + let orig_columns = sort_solumns + .iter() + .map(|sort_column| sort_column.values.clone()) + .collect::>(); + let sorted_columns = lexsort(&sort_solumns, None)?; + + // Make sure after merging ordering is still valid. + assert_eq!(orig_columns.len(), sorted_columns.len(), "{}", err_msg); + assert!( + izip!(orig_columns.into_iter(), sorted_columns.into_iter()) + .all(|(lhs, rhs)| { lhs == rhs }), + "{}", + err_msg + ) + } + } + Ok(()) + } + #[tokio::test(flavor = "multi_thread", worker_threads = 8)] async fn sort_preserving_repartition_test() { let seed_start = 0; @@ -93,7 +344,7 @@ mod sp_repartition_fuzz_tests { ) { let schema = input1[0].schema(); let session_config = SessionConfig::new().with_batch_size(50); - let ctx = SessionContext::with_config(session_config); + let ctx = SessionContext::new_with_config(session_config); let mut sort_keys = vec![]; for ordering_col in ["a", "b", "c"] { sort_keys.push(PhysicalSortExpr { @@ -140,7 +391,7 @@ mod sp_repartition_fuzz_tests { Arc::new( RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(2)) .unwrap() - .with_preserve_order(true), + .with_preserve_order(), ) } @@ -159,7 +410,7 @@ mod sp_repartition_fuzz_tests { Arc::new( RepartitionExec::try_new(input, Partitioning::Hash(hash_expr, 2)) .unwrap() - .with_preserve_order(true), + .with_preserve_order(), ) } diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 3d103ee70ee8..3037b4857a3b 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -22,135 +22,128 @@ use arrow::compute::{concat_batches, SortOptions}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; -use hashbrown::HashMap; -use rand::rngs::StdRng; -use rand::{Rng, SeedableRng}; - use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::windows::{ - create_window_expr, BoundedWindowAggExec, PartitionSearchMode, WindowAggExec, -}; -use datafusion::physical_plan::{collect, ExecutionPlan}; -use datafusion_expr::{ - AggregateFunction, BuiltInWindowFunction, WindowFrame, WindowFrameBound, - WindowFrameUnits, WindowFunction, + create_window_expr, BoundedWindowAggExec, WindowAggExec, }; - +use datafusion::physical_plan::{collect, ExecutionPlan, InputOrderMode}; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::type_coercion::aggregates::coerce_types; +use datafusion_expr::{ + AggregateFunction, BuiltInWindowFunction, WindowFrame, WindowFrameBound, + WindowFrameUnits, WindowFunctionDefinition, +}; use datafusion_physical_expr::expressions::{cast, col, lit}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; use test_utils::add_empty_batches; -#[cfg(test)] -mod tests { - use super::*; - use datafusion::physical_plan::windows::PartitionSearchMode::{ - Linear, PartiallySorted, Sorted, - }; +use hashbrown::HashMap; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; - #[tokio::test(flavor = "multi_thread", worker_threads = 16)] - async fn window_bounded_window_random_comparison() -> Result<()> { - // make_staggered_batches gives result sorted according to a, b, c - // In the test cases first entry represents partition by columns - // Second entry represents order by columns. - // Third entry represents search mode. - // In sorted mode physical plans are in the form for WindowAggExec - //``` - // WindowAggExec - // MemoryExec] - // ``` - // and in the form for BoundedWindowAggExec - // ``` - // BoundedWindowAggExec - // MemoryExec - // ``` - // In Linear and PartiallySorted mode physical plans are in the form for WindowAggExec - //``` - // WindowAggExec - // SortExec(required by window function) - // MemoryExec] - // ``` - // and in the form for BoundedWindowAggExec - // ``` - // BoundedWindowAggExec - // MemoryExec - // ``` - let test_cases = vec![ - (vec!["a"], vec!["a"], Sorted), - (vec!["a"], vec!["b"], Sorted), - (vec!["a"], vec!["a", "b"], Sorted), - (vec!["a"], vec!["b", "c"], Sorted), - (vec!["a"], vec!["a", "b", "c"], Sorted), - (vec!["b"], vec!["a"], Linear), - (vec!["b"], vec!["a", "b"], Linear), - (vec!["b"], vec!["a", "c"], Linear), - (vec!["b"], vec!["a", "b", "c"], Linear), - (vec!["c"], vec!["a"], Linear), - (vec!["c"], vec!["a", "b"], Linear), - (vec!["c"], vec!["a", "c"], Linear), - (vec!["c"], vec!["a", "b", "c"], Linear), - (vec!["b", "a"], vec!["a"], Sorted), - (vec!["b", "a"], vec!["b"], Sorted), - (vec!["b", "a"], vec!["c"], Sorted), - (vec!["b", "a"], vec!["a", "b"], Sorted), - (vec!["b", "a"], vec!["b", "c"], Sorted), - (vec!["b", "a"], vec!["a", "c"], Sorted), - (vec!["b", "a"], vec!["a", "b", "c"], Sorted), - (vec!["c", "b"], vec!["a"], Linear), - (vec!["c", "b"], vec!["a", "b"], Linear), - (vec!["c", "b"], vec!["a", "c"], Linear), - (vec!["c", "b"], vec!["a", "b", "c"], Linear), - (vec!["c", "a"], vec!["a"], PartiallySorted(vec![1])), - (vec!["c", "a"], vec!["b"], PartiallySorted(vec![1])), - (vec!["c", "a"], vec!["c"], PartiallySorted(vec![1])), - (vec!["c", "a"], vec!["a", "b"], PartiallySorted(vec![1])), - (vec!["c", "a"], vec!["b", "c"], PartiallySorted(vec![1])), - (vec!["c", "a"], vec!["a", "c"], PartiallySorted(vec![1])), - ( - vec!["c", "a"], - vec!["a", "b", "c"], - PartiallySorted(vec![1]), - ), - (vec!["c", "b", "a"], vec!["a"], Sorted), - (vec!["c", "b", "a"], vec!["b"], Sorted), - (vec!["c", "b", "a"], vec!["c"], Sorted), - (vec!["c", "b", "a"], vec!["a", "b"], Sorted), - (vec!["c", "b", "a"], vec!["b", "c"], Sorted), - (vec!["c", "b", "a"], vec!["a", "c"], Sorted), - (vec!["c", "b", "a"], vec!["a", "b", "c"], Sorted), - ]; - let n = 300; - let n_distincts = vec![10, 20]; - for n_distinct in n_distincts { - let mut handles = Vec::new(); - for i in 0..n { - let idx = i % test_cases.len(); - let (pb_cols, ob_cols, search_mode) = test_cases[idx].clone(); - let job = tokio::spawn(run_window_test( - make_staggered_batches::(1000, n_distinct, i as u64), - i as u64, - pb_cols, - ob_cols, - search_mode, - )); - handles.push(job); - } - for job in handles { - job.await.unwrap()?; - } +use datafusion_physical_plan::InputOrderMode::{Linear, PartiallySorted, Sorted}; + +#[tokio::test(flavor = "multi_thread", worker_threads = 16)] +async fn window_bounded_window_random_comparison() -> Result<()> { + // make_staggered_batches gives result sorted according to a, b, c + // In the test cases first entry represents partition by columns + // Second entry represents order by columns. + // Third entry represents search mode. + // In sorted mode physical plans are in the form for WindowAggExec + //``` + // WindowAggExec + // MemoryExec] + // ``` + // and in the form for BoundedWindowAggExec + // ``` + // BoundedWindowAggExec + // MemoryExec + // ``` + // In Linear and PartiallySorted mode physical plans are in the form for WindowAggExec + //``` + // WindowAggExec + // SortExec(required by window function) + // MemoryExec] + // ``` + // and in the form for BoundedWindowAggExec + // ``` + // BoundedWindowAggExec + // MemoryExec + // ``` + let test_cases = vec![ + (vec!["a"], vec!["a"], Sorted), + (vec!["a"], vec!["b"], Sorted), + (vec!["a"], vec!["a", "b"], Sorted), + (vec!["a"], vec!["b", "c"], Sorted), + (vec!["a"], vec!["a", "b", "c"], Sorted), + (vec!["b"], vec!["a"], Linear), + (vec!["b"], vec!["a", "b"], Linear), + (vec!["b"], vec!["a", "c"], Linear), + (vec!["b"], vec!["a", "b", "c"], Linear), + (vec!["c"], vec!["a"], Linear), + (vec!["c"], vec!["a", "b"], Linear), + (vec!["c"], vec!["a", "c"], Linear), + (vec!["c"], vec!["a", "b", "c"], Linear), + (vec!["b", "a"], vec!["a"], Sorted), + (vec!["b", "a"], vec!["b"], Sorted), + (vec!["b", "a"], vec!["c"], Sorted), + (vec!["b", "a"], vec!["a", "b"], Sorted), + (vec!["b", "a"], vec!["b", "c"], Sorted), + (vec!["b", "a"], vec!["a", "c"], Sorted), + (vec!["b", "a"], vec!["a", "b", "c"], Sorted), + (vec!["c", "b"], vec!["a"], Linear), + (vec!["c", "b"], vec!["a", "b"], Linear), + (vec!["c", "b"], vec!["a", "c"], Linear), + (vec!["c", "b"], vec!["a", "b", "c"], Linear), + (vec!["c", "a"], vec!["a"], PartiallySorted(vec![1])), + (vec!["c", "a"], vec!["b"], PartiallySorted(vec![1])), + (vec!["c", "a"], vec!["c"], PartiallySorted(vec![1])), + (vec!["c", "a"], vec!["a", "b"], PartiallySorted(vec![1])), + (vec!["c", "a"], vec!["b", "c"], PartiallySorted(vec![1])), + (vec!["c", "a"], vec!["a", "c"], PartiallySorted(vec![1])), + ( + vec!["c", "a"], + vec!["a", "b", "c"], + PartiallySorted(vec![1]), + ), + (vec!["c", "b", "a"], vec!["a"], Sorted), + (vec!["c", "b", "a"], vec!["b"], Sorted), + (vec!["c", "b", "a"], vec!["c"], Sorted), + (vec!["c", "b", "a"], vec!["a", "b"], Sorted), + (vec!["c", "b", "a"], vec!["b", "c"], Sorted), + (vec!["c", "b", "a"], vec!["a", "c"], Sorted), + (vec!["c", "b", "a"], vec!["a", "b", "c"], Sorted), + ]; + let n = 300; + let n_distincts = vec![10, 20]; + for n_distinct in n_distincts { + let mut handles = Vec::new(); + for i in 0..n { + let idx = i % test_cases.len(); + let (pb_cols, ob_cols, search_mode) = test_cases[idx].clone(); + let job = tokio::spawn(run_window_test( + make_staggered_batches::(1000, n_distinct, i as u64), + i as u64, + pb_cols, + ob_cols, + search_mode, + )); + handles.push(job); + } + for job in handles { + job.await.unwrap()?; } - Ok(()) } + Ok(()) } fn get_random_function( schema: &SchemaRef, rng: &mut StdRng, is_linear: bool, -) -> (WindowFunction, Vec>, String) { +) -> (WindowFunctionDefinition, Vec>, String) { let mut args = if is_linear { // In linear test for the test version with WindowAggExec we use insert SortExecs to the plan to be able to generate // same result with BoundedWindowAggExec which doesn't use any SortExec. To make result @@ -166,28 +159,28 @@ fn get_random_function( window_fn_map.insert( "sum", ( - WindowFunction::AggregateFunction(AggregateFunction::Sum), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum), vec![], ), ); window_fn_map.insert( "count", ( - WindowFunction::AggregateFunction(AggregateFunction::Count), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), vec![], ), ); window_fn_map.insert( "min", ( - WindowFunction::AggregateFunction(AggregateFunction::Min), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), vec![], ), ); window_fn_map.insert( "max", ( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![], ), ); @@ -198,28 +191,36 @@ fn get_random_function( window_fn_map.insert( "row_number", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::RowNumber), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::RowNumber, + ), vec![], ), ); window_fn_map.insert( "rank", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Rank), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::Rank, + ), vec![], ), ); window_fn_map.insert( "dense_rank", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::DenseRank), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::DenseRank, + ), vec![], ), ); window_fn_map.insert( "lead", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lead), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::Lead, + ), vec![ lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))), lit(ScalarValue::Int64(Some(rng.gen_range(1..1000)))), @@ -229,7 +230,9 @@ fn get_random_function( window_fn_map.insert( "lag", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lag), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::Lag, + ), vec![ lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))), lit(ScalarValue::Int64(Some(rng.gen_range(1..1000)))), @@ -240,21 +243,27 @@ fn get_random_function( window_fn_map.insert( "first_value", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::FirstValue), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::FirstValue, + ), vec![], ), ); window_fn_map.insert( "last_value", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::LastValue), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::LastValue, + ), vec![], ), ); window_fn_map.insert( "nth_value", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::NthValue), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::NthValue, + ), vec![lit(ScalarValue::Int64(Some(rng.gen_range(1..10))))], ), ); @@ -262,7 +271,7 @@ fn get_random_function( let rand_fn_idx = rng.gen_range(0..window_fn_map.len()); let fn_name = window_fn_map.keys().collect::>()[rand_fn_idx]; let (window_fn, new_args) = window_fn_map.values().collect::>()[rand_fn_idx]; - if let WindowFunction::AggregateFunction(f) = window_fn { + if let WindowFunctionDefinition::AggregateFunction(f) = window_fn { let a = args[0].clone(); let dt = a.data_type(schema.as_ref()).unwrap(); let sig = f.signature(); @@ -390,13 +399,13 @@ async fn run_window_test( random_seed: u64, partition_by_columns: Vec<&str>, orderby_columns: Vec<&str>, - search_mode: PartitionSearchMode, + search_mode: InputOrderMode, ) -> Result<()> { - let is_linear = !matches!(search_mode, PartitionSearchMode::Sorted); + let is_linear = !matches!(search_mode, InputOrderMode::Sorted); let mut rng = StdRng::seed_from_u64(random_seed); let schema = input1[0].schema(); let session_config = SessionConfig::new().with_batch_size(50); - let ctx = SessionContext::with_config(session_config); + let ctx = SessionContext::new_with_config(session_config); let (window_fn, args, fn_name) = get_random_function(&schema, &mut rng, is_linear); let window_frame = get_random_window_frame(&mut rng, is_linear); @@ -461,7 +470,6 @@ async fn run_window_test( ) .unwrap()], exec1, - schema.clone(), vec![], ) .unwrap(), @@ -484,7 +492,6 @@ async fn run_window_test( ) .unwrap()], exec2, - schema.clone(), vec![], search_mode, ) diff --git a/datafusion/core/tests/memory_limit.rs b/datafusion/core/tests/memory_limit.rs index 1041888b95d9..a98d097856fb 100644 --- a/datafusion/core/tests/memory_limit.rs +++ b/datafusion/core/tests/memory_limit.rs @@ -412,13 +412,13 @@ impl TestCase { let runtime = RuntimeEnv::new(rt_config).unwrap(); // Configure execution - let state = SessionState::with_config_rt(config, Arc::new(runtime)); + let state = SessionState::new_with_config_rt(config, Arc::new(runtime)); let state = match scenario.rules() { Some(rules) => state.with_physical_optimizer_rules(rules), None => state, }; - let ctx = SessionContext::with_state(state); + let ctx = SessionContext::new_with_state(state); ctx.register_table("t", table).expect("registering table"); let query = query.expect("Test error: query not specified"); diff --git a/datafusion/core/tests/parquet/custom_reader.rs b/datafusion/core/tests/parquet/custom_reader.rs index 75ff56a26508..e76b201e0222 100644 --- a/datafusion/core/tests/parquet/custom_reader.rs +++ b/datafusion/core/tests/parquet/custom_reader.rs @@ -15,6 +15,11 @@ // specific language governing permissions and limitations // under the License. +use std::io::Cursor; +use std::ops::Range; +use std::sync::Arc; +use std::time::SystemTime; + use arrow::array::{ArrayRef, Int64Array, Int8Array, StringArray}; use arrow::datatypes::{Field, Schema, SchemaBuilder}; use arrow::record_batch::RecordBatch; @@ -30,6 +35,7 @@ use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet; use datafusion::physical_plan::{collect, Statistics}; use datafusion::prelude::SessionContext; use datafusion_common::Result; + use futures::future::BoxFuture; use futures::{FutureExt, TryFutureExt}; use object_store::memory::InMemory; @@ -39,10 +45,6 @@ use parquet::arrow::async_reader::AsyncFileReader; use parquet::arrow::ArrowWriter; use parquet::errors::ParquetError; use parquet::file::metadata::ParquetMetaData; -use std::io::Cursor; -use std::ops::Range; -use std::sync::Arc; -use std::time::SystemTime; const EXPECTED_USER_DEFINED_METADATA: &str = "some-user-defined-metadata"; @@ -77,13 +79,12 @@ async fn route_data_access_ops_to_parquet_file_reader_factory() { // just any url that doesn't point to in memory object store object_store_url: ObjectStoreUrl::local_filesystem(), file_groups: vec![file_groups], + statistics: Statistics::new_unknown(&file_schema), file_schema, - statistics: Statistics::default(), projection: None, limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, None, None, @@ -186,6 +187,7 @@ async fn store_parquet_in_memory( last_modified: chrono::DateTime::from(SystemTime::now()), size: buf.len(), e_tag: None, + version: None, }; (meta, Bytes::from(buf)) diff --git a/datafusion/core/tests/parquet/file_statistics.rs b/datafusion/core/tests/parquet/file_statistics.rs index 90abbe9e2128..9f94a59a3e59 100644 --- a/datafusion/core/tests/parquet/file_statistics.rs +++ b/datafusion/core/tests/parquet/file_statistics.rs @@ -15,19 +15,27 @@ // specific language governing permissions and limitations // under the License. +use std::fs; +use std::sync::Arc; + use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::{ ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, }; +use datafusion::datasource::physical_plan::ParquetExec; use datafusion::datasource::TableProvider; use datafusion::execution::context::SessionState; use datafusion::prelude::SessionContext; +use datafusion_common::stats::Precision; use datafusion_execution::cache::cache_manager::CacheManagerConfig; use datafusion_execution::cache::cache_unit; -use datafusion_execution::cache::cache_unit::DefaultFileStatisticsCache; +use datafusion_execution::cache::cache_unit::{ + DefaultFileStatisticsCache, DefaultListFilesCache, +}; use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; -use std::sync::Arc; + +use tempfile::tempdir; #[tokio::test] async fn load_table_stats_with_session_level_cache() { @@ -35,78 +43,171 @@ async fn load_table_stats_with_session_level_cache() { let filename = format!("{}/{}", testdata, "alltypes_plain.parquet"); let table_path = ListingTableUrl::parse(filename).unwrap(); - let (cache1, state1) = get_cache_runtime_state(); + let (cache1, _, state1) = get_cache_runtime_state(); // Create a separate DefaultFileStatisticsCache - let (cache2, state2) = get_cache_runtime_state(); + let (cache2, _, state2) = get_cache_runtime_state(); let opt = ListingOptions::new(Arc::new(ParquetFormat::default())); - let table1 = get_listing_with_cache(&table_path, cache1, &state1, &opt).await; - let table2 = get_listing_with_cache(&table_path, cache2, &state2, &opt).await; + let table1 = get_listing_table(&table_path, Some(cache1), &opt).await; + let table2 = get_listing_table(&table_path, Some(cache2), &opt).await; //Session 1 first time list files - assert_eq!(get_cache_size(&state1), 0); + assert_eq!(get_static_cache_size(&state1), 0); let exec1 = table1.scan(&state1, None, &[], None).await.unwrap(); - assert_eq!(exec1.statistics().num_rows, Some(8)); - assert_eq!(exec1.statistics().total_byte_size, Some(671)); - assert_eq!(get_cache_size(&state1), 1); + assert_eq!(exec1.statistics().unwrap().num_rows, Precision::Exact(8)); + assert_eq!( + exec1.statistics().unwrap().total_byte_size, + Precision::Exact(671) + ); + assert_eq!(get_static_cache_size(&state1), 1); //Session 2 first time list files //check session 1 cache result not show in session 2 + assert_eq!(get_static_cache_size(&state2), 0); + let exec2 = table2.scan(&state2, None, &[], None).await.unwrap(); + assert_eq!(exec2.statistics().unwrap().num_rows, Precision::Exact(8)); assert_eq!( - state2 - .runtime_env() - .cache_manager - .get_file_statistic_cache() - .unwrap() - .len(), - 0 + exec2.statistics().unwrap().total_byte_size, + Precision::Exact(671) ); + assert_eq!(get_static_cache_size(&state2), 1); + + //Session 1 second time list files + //check session 1 cache result not show in session 2 + assert_eq!(get_static_cache_size(&state1), 1); + let exec3 = table1.scan(&state1, None, &[], None).await.unwrap(); + assert_eq!(exec3.statistics().unwrap().num_rows, Precision::Exact(8)); + assert_eq!( + exec3.statistics().unwrap().total_byte_size, + Precision::Exact(671) + ); + // List same file no increase + assert_eq!(get_static_cache_size(&state1), 1); +} + +#[tokio::test] +async fn list_files_with_session_level_cache() { + let p_name = "alltypes_plain.parquet"; + let testdata = datafusion::test_util::parquet_test_data(); + let filename = format!("{}/{}", testdata, p_name); + + let temp_path1 = tempdir() + .unwrap() + .into_path() + .into_os_string() + .into_string() + .unwrap(); + let temp_filename1 = format!("{}/{}", temp_path1, p_name); + + let temp_path2 = tempdir() + .unwrap() + .into_path() + .into_os_string() + .into_string() + .unwrap(); + let temp_filename2 = format!("{}/{}", temp_path2, p_name); + + fs::copy(filename.clone(), temp_filename1).expect("panic"); + fs::copy(filename, temp_filename2).expect("panic"); + + let table_path = ListingTableUrl::parse(temp_path1).unwrap(); + + let (_, _, state1) = get_cache_runtime_state(); + + // Create a separate DefaultFileStatisticsCache + let (_, _, state2) = get_cache_runtime_state(); + + let opt = ListingOptions::new(Arc::new(ParquetFormat::default())); + + let table1 = get_listing_table(&table_path, None, &opt).await; + let table2 = get_listing_table(&table_path, None, &opt).await; + + //Session 1 first time list files + assert_eq!(get_list_file_cache_size(&state1), 0); + let exec1 = table1.scan(&state1, None, &[], None).await.unwrap(); + let parquet1 = exec1.as_any().downcast_ref::().unwrap(); + + assert_eq!(get_list_file_cache_size(&state1), 1); + let fg = &parquet1.base_config().file_groups; + assert_eq!(fg.len(), 1); + assert_eq!(fg.first().unwrap().len(), 1); + + //Session 2 first time list files + //check session 1 cache result not show in session 2 + assert_eq!(get_list_file_cache_size(&state2), 0); let exec2 = table2.scan(&state2, None, &[], None).await.unwrap(); - assert_eq!(exec2.statistics().num_rows, Some(8)); - assert_eq!(exec2.statistics().total_byte_size, Some(671)); - assert_eq!(get_cache_size(&state2), 1); + let parquet2 = exec2.as_any().downcast_ref::().unwrap(); + + assert_eq!(get_list_file_cache_size(&state2), 1); + let fg2 = &parquet2.base_config().file_groups; + assert_eq!(fg2.len(), 1); + assert_eq!(fg2.first().unwrap().len(), 1); //Session 1 second time list files //check session 1 cache result not show in session 2 - assert_eq!(get_cache_size(&state1), 1); + assert_eq!(get_list_file_cache_size(&state1), 1); let exec3 = table1.scan(&state1, None, &[], None).await.unwrap(); - assert_eq!(exec3.statistics().num_rows, Some(8)); - assert_eq!(exec3.statistics().total_byte_size, Some(671)); + let parquet3 = exec3.as_any().downcast_ref::().unwrap(); + + assert_eq!(get_list_file_cache_size(&state1), 1); + let fg = &parquet3.base_config().file_groups; + assert_eq!(fg.len(), 1); + assert_eq!(fg.first().unwrap().len(), 1); // List same file no increase - assert_eq!(get_cache_size(&state1), 1); + assert_eq!(get_list_file_cache_size(&state1), 1); } -async fn get_listing_with_cache( +async fn get_listing_table( table_path: &ListingTableUrl, - cache1: Arc, - state1: &SessionState, + static_cache: Option>, opt: &ListingOptions, ) -> ListingTable { - let schema = opt.infer_schema(state1, table_path).await.unwrap(); + let schema = opt + .infer_schema( + &SessionState::new_with_config_rt( + SessionConfig::default(), + Arc::new(RuntimeEnv::default()), + ), + table_path, + ) + .await + .unwrap(); let config1 = ListingTableConfig::new(table_path.clone()) .with_listing_options(opt.clone()) .with_schema(schema); - ListingTable::try_new(config1) - .unwrap() - .with_cache(Some(cache1)) + let table = ListingTable::try_new(config1).unwrap(); + if let Some(c) = static_cache { + table.with_cache(Some(c)) + } else { + table + } } -fn get_cache_runtime_state() -> (Arc, SessionState) { +fn get_cache_runtime_state() -> ( + Arc, + Arc, + SessionState, +) { let cache_config = CacheManagerConfig::default(); - let cache1 = Arc::new(cache_unit::DefaultFileStatisticsCache::default()); - let cache_config = cache_config.with_files_statistics_cache(Some(cache1.clone())); + let file_static_cache = Arc::new(cache_unit::DefaultFileStatisticsCache::default()); + let list_file_cache = Arc::new(cache_unit::DefaultListFilesCache::default()); + + let cache_config = cache_config + .with_files_statistics_cache(Some(file_static_cache.clone())) + .with_list_files_cache(Some(list_file_cache.clone())); + let rt = Arc::new( RuntimeEnv::new(RuntimeConfig::new().with_cache_manager(cache_config)).unwrap(), ); - let state = SessionContext::with_config_rt(SessionConfig::default(), rt).state(); + let state = SessionContext::new_with_config_rt(SessionConfig::default(), rt).state(); - (cache1, state) + (file_static_cache, list_file_cache, state) } -fn get_cache_size(state1: &SessionState) -> usize { +fn get_static_cache_size(state1: &SessionState) -> usize { state1 .runtime_env() .cache_manager @@ -114,3 +215,12 @@ fn get_cache_size(state1: &SessionState) -> usize { .unwrap() .len() } + +fn get_list_file_cache_size(state1: &SessionState) -> usize { + state1 + .runtime_env() + .cache_manager + .get_list_files_cache() + .unwrap() + .len() +} diff --git a/datafusion/core/tests/parquet/filter_pushdown.rs b/datafusion/core/tests/parquet/filter_pushdown.rs index 885834f93979..f214e8903a4f 100644 --- a/datafusion/core/tests/parquet/filter_pushdown.rs +++ b/datafusion/core/tests/parquet/filter_pushdown.rs @@ -34,7 +34,7 @@ use datafusion::physical_plan::collect; use datafusion::physical_plan::metrics::MetricsSet; use datafusion::prelude::{col, lit, lit_timestamp_nano, Expr, SessionContext}; use datafusion::test_util::parquet::{ParquetScanOptions, TestParquetFile}; -use datafusion_optimizer::utils::{conjunction, disjunction, split_conjunction}; +use datafusion_expr::utils::{conjunction, disjunction, split_conjunction}; use itertools::Itertools; use parquet::file::properties::WriterProperties; use tempfile::TempDir; @@ -507,7 +507,7 @@ impl<'a> TestCase<'a> { ) -> RecordBatch { println!(" scan options: {scan_options:?}"); println!(" reading with filter {filter:?}"); - let ctx = SessionContext::with_config(scan_options.config()); + let ctx = SessionContext::new_with_config(scan_options.config()); let exec = self .test_parquet_file .create_scan(Some(filter.clone())) diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index 33a78660ab9d..943f7fdbf4ac 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -44,6 +44,7 @@ mod file_statistics; mod filter_pushdown; mod page_pruning; mod row_group_pruning; +mod schema; mod schema_coercion; #[cfg(test)] @@ -154,7 +155,7 @@ impl ContextWithParquet { let parquet_path = file.path().to_string_lossy(); // now, setup a the file as a data source and run a query against it - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); ctx.register_parquet("t", &parquet_path, ParquetReadOptions::default()) .await diff --git a/datafusion/core/tests/parquet/page_pruning.rs b/datafusion/core/tests/parquet/page_pruning.rs index 4337259c1e62..23a56bc821d4 100644 --- a/datafusion/core/tests/parquet/page_pruning.rs +++ b/datafusion/core/tests/parquet/page_pruning.rs @@ -17,6 +17,7 @@ use crate::parquet::Unit::Page; use crate::parquet::{ContextWithParquet, Scenario}; + use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::file_format::FileFormat; use datafusion::datasource::listing::PartitionedFile; @@ -30,6 +31,7 @@ use datafusion_common::{ScalarValue, Statistics, ToDFSchema}; use datafusion_expr::{col, lit, Expr}; use datafusion_physical_expr::create_physical_expr; use datafusion_physical_expr::execution_props::ExecutionProps; + use futures::StreamExt; use object_store::path::Path; use object_store::ObjectMeta; @@ -48,6 +50,7 @@ async fn get_parquet_exec(state: &SessionState, filter: Expr) -> ParquetExec { last_modified: metadata.modified().map(chrono::DateTime::from).unwrap(), size: metadata.len() as usize, e_tag: None, + version: None, }; let schema = ParquetFormat::default() @@ -71,14 +74,13 @@ async fn get_parquet_exec(state: &SessionState, filter: Expr) -> ParquetExec { FileScanConfig { object_store_url, file_groups: vec![vec![partitioned_file]], - file_schema: schema, - statistics: Statistics::default(), + file_schema: schema.clone(), + statistics: Statistics::new_unknown(&schema), // file has 10 cols so index 12 should be month projection: None, limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, Some(predicate), None, @@ -240,10 +242,11 @@ async fn test_prune( expected_row_pages_pruned: Option, expected_results: usize, ) { - let output = ContextWithParquet::new(case_data_type, Page) - .await - .query(sql) - .await; + let output: crate::parquet::TestOutput = + ContextWithParquet::new(case_data_type, Page) + .await + .query(sql) + .await; println!("{}", output.description()); assert_eq!(output.predicate_evaluation_errors(), expected_errors); diff --git a/datafusion/core/tests/sql/parquet_schema.rs b/datafusion/core/tests/parquet/schema.rs similarity index 95% rename from datafusion/core/tests/sql/parquet_schema.rs rename to datafusion/core/tests/parquet/schema.rs index bc1578da2c58..30d4e1193022 100644 --- a/datafusion/core/tests/sql/parquet_schema.rs +++ b/datafusion/core/tests/parquet/schema.rs @@ -22,6 +22,7 @@ use ::parquet::arrow::ArrowWriter; use tempfile::TempDir; use super::*; +use datafusion_common::assert_batches_sorted_eq; #[tokio::test] async fn schema_merge_ignores_metadata_by_default() { @@ -90,7 +91,13 @@ async fn schema_merge_ignores_metadata_by_default() { .await .unwrap(); - let actual = execute_to_batches(&ctx, "SELECT * from t").await; + let actual = ctx + .sql("SELECT * from t") + .await + .unwrap() + .collect() + .await + .unwrap(); assert_batches_sorted_eq!(expected, &actual); assert_no_metadata(&actual); } @@ -151,7 +158,13 @@ async fn schema_merge_can_preserve_metadata() { .await .unwrap(); - let actual = execute_to_batches(&ctx, "SELECT * from t").await; + let actual = ctx + .sql("SELECT * from t") + .await + .unwrap() + .collect() + .await + .unwrap(); assert_batches_sorted_eq!(expected, &actual); assert_metadata(&actual, &expected_metadata); } diff --git a/datafusion/core/tests/parquet/schema_coercion.rs b/datafusion/core/tests/parquet/schema_coercion.rs index f7dace993091..00f3eada496e 100644 --- a/datafusion/core/tests/parquet/schema_coercion.rs +++ b/datafusion/core/tests/parquet/schema_coercion.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; + use arrow::datatypes::{Field, Schema}; use arrow::record_batch::RecordBatch; use arrow_array::types::Int32Type; @@ -24,14 +26,13 @@ use datafusion::assert_batches_sorted_eq; use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec}; use datafusion::physical_plan::collect; use datafusion::prelude::SessionContext; -use datafusion_common::Result; -use datafusion_common::Statistics; +use datafusion_common::{Result, Statistics}; use datafusion_execution::object_store::ObjectStoreUrl; + use object_store::path::Path; use object_store::ObjectMeta; use parquet::arrow::ArrowWriter; use parquet::file::properties::WriterProperties; -use std::sync::Arc; use tempfile::NamedTempFile; /// Test for reading data from multiple parquet files with different schemas and coercing them into a single schema. @@ -62,13 +63,12 @@ async fn multi_parquet_coercion() { FileScanConfig { object_store_url: ObjectStoreUrl::local_filesystem(), file_groups: vec![file_groups], + statistics: Statistics::new_unknown(&file_schema), file_schema, - statistics: Statistics::default(), projection: None, limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, None, None, @@ -126,13 +126,12 @@ async fn multi_parquet_coercion_projection() { FileScanConfig { object_store_url: ObjectStoreUrl::local_filesystem(), file_groups: vec![file_groups], + statistics: Statistics::new_unknown(&file_schema), file_schema, - statistics: Statistics::default(), projection: Some(vec![1, 0, 2]), limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, None, None, @@ -193,5 +192,6 @@ pub fn local_unpartitioned_file(path: impl AsRef) -> ObjectMeta last_modified: metadata.modified().map(chrono::DateTime::from).unwrap(), size: metadata.len() as usize, e_tag: None, + version: None, } } diff --git a/datafusion/core/tests/path_partition.rs b/datafusion/core/tests/path_partition.rs index 1c08733e3df6..dd8eb52f67c7 100644 --- a/datafusion/core/tests/path_partition.rs +++ b/datafusion/core/tests/path_partition.rs @@ -17,16 +17,13 @@ //! Test queries on partitioned datasets -use arrow::datatypes::DataType; use std::collections::BTreeSet; use std::fs::File; use std::io::{Read, Seek, SeekFrom}; use std::ops::Range; use std::sync::Arc; -use async_trait::async_trait; -use bytes::Bytes; -use chrono::{TimeZone, Utc}; +use arrow::datatypes::DataType; use datafusion::datasource::listing::ListingTableUrl; use datafusion::{ assert_batches_sorted_eq, @@ -39,12 +36,17 @@ use datafusion::{ prelude::SessionContext, test_util::{self, arrow_test_data, parquet_test_data}, }; +use datafusion_common::stats::Precision; use datafusion_common::ScalarValue; + +use async_trait::async_trait; +use bytes::Bytes; +use chrono::{TimeZone, Utc}; use futures::stream; use futures::stream::BoxStream; use object_store::{ path::Path, GetOptions, GetResult, GetResultPayload, ListResult, MultipartId, - ObjectMeta, ObjectStore, + ObjectMeta, ObjectStore, PutOptions, PutResult, }; use tokio::io::AsyncWrite; use url::Url; @@ -458,34 +460,30 @@ async fn parquet_statistics() -> Result<()> { //// NO PROJECTION //// let dataframe = ctx.sql("SELECT * FROM t").await?; let physical_plan = dataframe.create_physical_plan().await?; - assert_eq!(physical_plan.schema().fields().len(), 4); + let schema = physical_plan.schema(); + assert_eq!(schema.fields().len(), 4); - let stat_cols = physical_plan - .statistics() - .column_statistics - .expect("col stats should be defined"); + let stat_cols = physical_plan.statistics()?.column_statistics; assert_eq!(stat_cols.len(), 4); // stats for the first col are read from the parquet file - assert_eq!(stat_cols[0].null_count, Some(3)); + assert_eq!(stat_cols[0].null_count, Precision::Exact(3)); // TODO assert partition column (1,2,3) stats once implemented (#1186) - assert_eq!(stat_cols[1], ColumnStatistics::default()); - assert_eq!(stat_cols[2], ColumnStatistics::default()); - assert_eq!(stat_cols[3], ColumnStatistics::default()); + assert_eq!(stat_cols[1], ColumnStatistics::new_unknown(),); + assert_eq!(stat_cols[2], ColumnStatistics::new_unknown(),); + assert_eq!(stat_cols[3], ColumnStatistics::new_unknown(),); //// WITH PROJECTION //// let dataframe = ctx.sql("SELECT mycol, day FROM t WHERE day='28'").await?; let physical_plan = dataframe.create_physical_plan().await?; - assert_eq!(physical_plan.schema().fields().len(), 2); + let schema = physical_plan.schema(); + assert_eq!(schema.fields().len(), 2); - let stat_cols = physical_plan - .statistics() - .column_statistics - .expect("col stats should be defined"); + let stat_cols = physical_plan.statistics()?.column_statistics; assert_eq!(stat_cols.len(), 2); // stats for the first col are read from the parquet file - assert_eq!(stat_cols[0].null_count, Some(1)); + assert_eq!(stat_cols[0].null_count, Precision::Exact(1)); // TODO assert partition column stats once implemented (#1186) - assert_eq!(stat_cols[1], ColumnStatistics::default()); + assert_eq!(stat_cols[1], ColumnStatistics::new_unknown(),); Ok(()) } @@ -622,7 +620,12 @@ impl MirroringObjectStore { #[async_trait] impl ObjectStore for MirroringObjectStore { - async fn put(&self, _location: &Path, _bytes: Bytes) -> object_store::Result<()> { + async fn put_opts( + &self, + _location: &Path, + _bytes: Bytes, + _opts: PutOptions, + ) -> object_store::Result { unimplemented!() } @@ -655,6 +658,7 @@ impl ObjectStore for MirroringObjectStore { last_modified: metadata.modified().map(chrono::DateTime::from).unwrap(), size: metadata.len() as usize, e_tag: None, + version: None, }; Ok(GetResult { @@ -682,26 +686,16 @@ impl ObjectStore for MirroringObjectStore { Ok(data.into()) } - async fn head(&self, location: &Path) -> object_store::Result { - self.files.iter().find(|x| *x == location).unwrap(); - Ok(ObjectMeta { - location: location.clone(), - last_modified: Utc.timestamp_nanos(0), - size: self.file_size as usize, - e_tag: None, - }) - } - async fn delete(&self, _location: &Path) -> object_store::Result<()> { unimplemented!() } - async fn list( + fn list( &self, prefix: Option<&Path>, - ) -> object_store::Result>> { + ) -> BoxStream<'_, object_store::Result> { let prefix = prefix.cloned().unwrap_or_default(); - Ok(Box::pin(stream::iter(self.files.iter().filter_map( + Box::pin(stream::iter(self.files.iter().filter_map( move |location| { // Don't return for exact prefix match let filter = location @@ -715,10 +709,11 @@ impl ObjectStore for MirroringObjectStore { last_modified: Utc.timestamp_nanos(0), size: self.file_size as usize, e_tag: None, + version: None, }) }) }, - )))) + ))) } async fn list_with_delimiter( @@ -752,6 +747,7 @@ impl ObjectStore for MirroringObjectStore { last_modified: Utc.timestamp_nanos(0), size: self.file_size as usize, e_tag: None, + version: None, }; objects.push(object); } diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs index 5d42936232b5..af6d0d5f4e24 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ b/datafusion/core/tests/sql/aggregates.rs @@ -17,8 +17,6 @@ use super::*; use datafusion::scalar::ScalarValue; -use datafusion::test_util::scan_empty; -use datafusion_common::cast::as_float64_array; #[tokio::test] async fn csv_query_array_agg_distinct() -> Result<()> { @@ -47,344 +45,24 @@ async fn csv_query_array_agg_distinct() -> Result<()> { let column = actual[0].column(0); assert_eq!(column.len(), 1); - if let ScalarValue::List(Some(mut v), _) = ScalarValue::try_from_array(column, 0)? { - // workaround lack of Ord of ScalarValue - let cmp = |a: &ScalarValue, b: &ScalarValue| { - a.partial_cmp(b).expect("Can compare ScalarValues") - }; - v.sort_by(cmp); - assert_eq!( - *v, - vec![ - ScalarValue::UInt32(Some(1)), - ScalarValue::UInt32(Some(2)), - ScalarValue::UInt32(Some(3)), - ScalarValue::UInt32(Some(4)), - ScalarValue::UInt32(Some(5)) - ] - ); - } else { - unreachable!(); - } - - Ok(()) -} - -#[tokio::test] -async fn aggregate() -> Result<()> { - let results = execute_with_partition("SELECT SUM(c1), SUM(c2) FROM test", 4).await?; - assert_eq!(results.len(), 1); - - let expected = [ - "+--------------+--------------+", - "| SUM(test.c1) | SUM(test.c2) |", - "+--------------+--------------+", - "| 60 | 220 |", - "+--------------+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_empty() -> Result<()> { - // The predicate on this query purposely generates no results - let results = - execute_with_partition("SELECT SUM(c1), SUM(c2) FROM test where c1 > 100000", 4) - .await - .unwrap(); - - assert_eq!(results.len(), 1); - - let expected = [ - "+--------------+--------------+", - "| SUM(test.c1) | SUM(test.c2) |", - "+--------------+--------------+", - "| | |", - "+--------------+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_avg() -> Result<()> { - let results = execute_with_partition("SELECT AVG(c1), AVG(c2) FROM test", 4).await?; - assert_eq!(results.len(), 1); - - let expected = [ - "+--------------+--------------+", - "| AVG(test.c1) | AVG(test.c2) |", - "+--------------+--------------+", - "| 1.5 | 5.5 |", - "+--------------+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_max() -> Result<()> { - let results = execute_with_partition("SELECT MAX(c1), MAX(c2) FROM test", 4).await?; - assert_eq!(results.len(), 1); - - let expected = [ - "+--------------+--------------+", - "| MAX(test.c1) | MAX(test.c2) |", - "+--------------+--------------+", - "| 3 | 10 |", - "+--------------+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_min() -> Result<()> { - let results = execute_with_partition("SELECT MIN(c1), MIN(c2) FROM test", 4).await?; - assert_eq!(results.len(), 1); - - let expected = [ - "+--------------+--------------+", - "| MIN(test.c1) | MIN(test.c2) |", - "+--------------+--------------+", - "| 0 | 1 |", - "+--------------+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_grouped() -> Result<()> { - let results = - execute_with_partition("SELECT c1, SUM(c2) FROM test GROUP BY c1", 4).await?; - - let expected = [ - "+----+--------------+", - "| c1 | SUM(test.c2) |", - "+----+--------------+", - "| 0 | 55 |", - "| 1 | 55 |", - "| 2 | 55 |", - "| 3 | 55 |", - "+----+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_grouped_avg() -> Result<()> { - let results = - execute_with_partition("SELECT c1, AVG(c2) FROM test GROUP BY c1", 4).await?; - - let expected = [ - "+----+--------------+", - "| c1 | AVG(test.c2) |", - "+----+--------------+", - "| 0 | 5.5 |", - "| 1 | 5.5 |", - "| 2 | 5.5 |", - "| 3 | 5.5 |", - "+----+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_grouped_empty() -> Result<()> { - let results = execute_with_partition( - "SELECT c1, AVG(c2) FROM test WHERE c1 = 123 GROUP BY c1", - 4, - ) - .await?; - - let expected = [ - "+----+--------------+", - "| c1 | AVG(test.c2) |", - "+----+--------------+", - "+----+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_grouped_max() -> Result<()> { - let results = - execute_with_partition("SELECT c1, MAX(c2) FROM test GROUP BY c1", 4).await?; - - let expected = [ - "+----+--------------+", - "| c1 | MAX(test.c2) |", - "+----+--------------+", - "| 0 | 10 |", - "| 1 | 10 |", - "| 2 | 10 |", - "| 3 | 10 |", - "+----+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_grouped_min() -> Result<()> { - let results = - execute_with_partition("SELECT c1, MIN(c2) FROM test GROUP BY c1", 4).await?; - - let expected = [ - "+----+--------------+", - "| c1 | MIN(test.c2) |", - "+----+--------------+", - "| 0 | 1 |", - "| 1 | 1 |", - "| 2 | 1 |", - "| 3 | 1 |", - "+----+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_min_max_w_custom_window_frames() -> Result<()> { - let ctx = SessionContext::new(); - register_aggregate_csv(&ctx).await?; - let sql = - "SELECT - MIN(c12) OVER (ORDER BY C12 RANGE BETWEEN 0.3 PRECEDING AND 0.2 FOLLOWING) as min1, - MAX(c12) OVER (ORDER BY C11 RANGE BETWEEN 0.1 PRECEDING AND 0.2 FOLLOWING) as max1 - FROM aggregate_test_100 - ORDER BY C9 - LIMIT 5"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = [ - "+---------------------+--------------------+", - "| min1 | max1 |", - "+---------------------+--------------------+", - "| 0.01479305307777301 | 0.9965400387585364 |", - "| 0.01479305307777301 | 0.9800193410444061 |", - "| 0.01479305307777301 | 0.9706712283358269 |", - "| 0.2667177795079635 | 0.9965400387585364 |", - "| 0.3600766362333053 | 0.9706712283358269 |", - "+---------------------+--------------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn aggregate_min_max_w_custom_window_frames_unbounded_start() -> Result<()> { - let ctx = SessionContext::new(); - register_aggregate_csv(&ctx).await?; - let sql = - "SELECT - MIN(c12) OVER (ORDER BY C12 RANGE BETWEEN UNBOUNDED PRECEDING AND 0.2 FOLLOWING) as min1, - MAX(c12) OVER (ORDER BY C11 RANGE BETWEEN UNBOUNDED PRECEDING AND 0.2 FOLLOWING) as max1 - FROM aggregate_test_100 - ORDER BY C9 - LIMIT 5"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = [ - "+---------------------+--------------------+", - "| min1 | max1 |", - "+---------------------+--------------------+", - "| 0.01479305307777301 | 0.9965400387585364 |", - "| 0.01479305307777301 | 0.9800193410444061 |", - "| 0.01479305307777301 | 0.9800193410444061 |", - "| 0.01479305307777301 | 0.9965400387585364 |", - "| 0.01479305307777301 | 0.9800193410444061 |", - "+---------------------+--------------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn aggregate_avg_add() -> Result<()> { - let results = execute_with_partition( - "SELECT AVG(c1), AVG(c1) + 1, AVG(c1) + 2, 1 + AVG(c1) FROM test", - 4, - ) - .await?; - assert_eq!(results.len(), 1); - - let expected = ["+--------------+-------------------------+-------------------------+-------------------------+", - "| AVG(test.c1) | AVG(test.c1) + Int64(1) | AVG(test.c1) + Int64(2) | Int64(1) + AVG(test.c1) |", - "+--------------+-------------------------+-------------------------+-------------------------+", - "| 1.5 | 2.5 | 3.5 | 2.5 |", - "+--------------+-------------------------+-------------------------+-------------------------+"]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn case_sensitive_identifiers_aggregates() { - let ctx = SessionContext::new(); - ctx.register_table("t", table_with_sequence(1, 1).unwrap()) - .unwrap(); - - let expected = [ - "+----------+", - "| MAX(t.i) |", - "+----------+", - "| 1 |", - "+----------+", - ]; - - let results = plan_and_collect(&ctx, "SELECT max(i) FROM t") - .await - .unwrap(); - - assert_batches_sorted_eq!(expected, &results); - - let results = plan_and_collect(&ctx, "SELECT MAX(i) FROM t") - .await - .unwrap(); - assert_batches_sorted_eq!(expected, &results); - - // Using double quotes allows specifying the function name with capitalization - let err = plan_and_collect(&ctx, "SELECT \"MAX\"(i) FROM t") - .await - .unwrap_err(); - assert!(err - .to_string() - .contains("Error during planning: Invalid function 'MAX'")); - - let results = plan_and_collect(&ctx, "SELECT \"max\"(i) FROM t") - .await - .unwrap(); - assert_batches_sorted_eq!(expected, &results); -} - -#[tokio::test] -async fn count_basic() -> Result<()> { - let results = - execute_with_partition("SELECT COUNT(c1), COUNT(c2) FROM test", 1).await?; - assert_eq!(results.len(), 1); + let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&column)?; + let mut scalars = scalar_vec[0].clone(); + // workaround lack of Ord of ScalarValue + let cmp = |a: &ScalarValue, b: &ScalarValue| { + a.partial_cmp(b).expect("Can compare ScalarValues") + }; + scalars.sort_by(cmp); + assert_eq!( + scalars, + vec![ + ScalarValue::UInt32(Some(1)), + ScalarValue::UInt32(Some(2)), + ScalarValue::UInt32(Some(3)), + ScalarValue::UInt32(Some(4)), + ScalarValue::UInt32(Some(5)) + ] + ); - let expected = [ - "+----------------+----------------+", - "| COUNT(test.c1) | COUNT(test.c2) |", - "+----------------+----------------+", - "| 10 | 10 |", - "+----------------+----------------+", - ]; - assert_batches_sorted_eq!(expected, &results); Ok(()) } @@ -497,162 +175,6 @@ async fn count_aggregated_cube() -> Result<()> { Ok(()) } -#[tokio::test] -async fn count_multi_expr() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::Int32, true), - Field::new("c2", DataType::Int32, true), - ])); - - let data = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from(vec![ - Some(0), - None, - Some(1), - Some(2), - None, - ])), - Arc::new(Int32Array::from(vec![ - Some(1), - Some(1), - Some(0), - None, - None, - ])), - ], - )?; - - let ctx = SessionContext::new(); - ctx.register_batch("test", data)?; - let sql = "SELECT count(c1, c2) FROM test"; - let actual = execute_to_batches(&ctx, sql).await; - - let expected = [ - "+------------------------+", - "| COUNT(test.c1,test.c2) |", - "+------------------------+", - "| 2 |", - "+------------------------+", - ]; - assert_batches_sorted_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn count_multi_expr_group_by() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::Int32, true), - Field::new("c2", DataType::Int32, true), - Field::new("c3", DataType::Int32, true), - ])); - - let data = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from(vec![ - Some(0), - None, - Some(1), - Some(2), - None, - ])), - Arc::new(Int32Array::from(vec![ - Some(1), - Some(1), - Some(0), - None, - None, - ])), - Arc::new(Int32Array::from(vec![ - Some(10), - Some(10), - Some(10), - Some(10), - Some(10), - ])), - ], - )?; - - let ctx = SessionContext::new(); - ctx.register_batch("test", data)?; - let sql = "SELECT c3, count(c1, c2) FROM test group by c3"; - let actual = execute_to_batches(&ctx, sql).await; - - let expected = [ - "+----+------------------------+", - "| c3 | COUNT(test.c1,test.c2) |", - "+----+------------------------+", - "| 10 | 2 |", - "+----+------------------------+", - ]; - assert_batches_sorted_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn simple_avg() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - - let batch1 = RecordBatch::try_new( - Arc::new(schema.clone()), - vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], - )?; - let batch2 = RecordBatch::try_new( - Arc::new(schema.clone()), - vec![Arc::new(Int32Array::from(vec![4, 5]))], - )?; - - let ctx = SessionContext::new(); - - let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?; - ctx.register_table("t", Arc::new(provider))?; - - let result = plan_and_collect(&ctx, "SELECT AVG(a) FROM t").await?; - - let batch = &result[0]; - assert_eq!(1, batch.num_columns()); - assert_eq!(1, batch.num_rows()); - - let values = as_float64_array(batch.column(0)).expect("failed to cast version"); - assert_eq!(values.len(), 1); - // avg(1,2,3,4,5) = 3.0 - assert_eq!(values.value(0), 3.0_f64); - Ok(()) -} - -#[tokio::test] -async fn simple_mean() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - - let batch1 = RecordBatch::try_new( - Arc::new(schema.clone()), - vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], - )?; - let batch2 = RecordBatch::try_new( - Arc::new(schema.clone()), - vec![Arc::new(Int32Array::from(vec![4, 5]))], - )?; - - let ctx = SessionContext::new(); - - let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?; - ctx.register_table("t", Arc::new(provider))?; - - let result = plan_and_collect(&ctx, "SELECT MEAN(a) FROM t").await?; - - let batch = &result[0]; - assert_eq!(1, batch.num_columns()); - assert_eq!(1, batch.num_rows()); - - let values = as_float64_array(batch.column(0)).expect("failed to cast version"); - assert_eq!(values.len(), 1); - // mean(1,2,3,4,5) = 3.0 - assert_eq!(values.value(0), 3.0_f64); - Ok(()) -} - async fn run_count_distinct_integers_aggregated_scenario( partitions: Vec>, ) -> Result> { @@ -773,35 +295,10 @@ async fn count_distinct_integers_aggregated_multiple_partitions() -> Result<()> Ok(()) } -#[tokio::test] -async fn aggregate_with_alias() -> Result<()> { - let ctx = SessionContext::new(); - let state = ctx.state(); - - let schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::Utf8, false), - Field::new("c2", DataType::UInt32, false), - ])); - - let plan = scan_empty(None, schema.as_ref(), None)? - .aggregate(vec![col("c1")], vec![sum(col("c2"))])? - .project(vec![col("c1"), sum(col("c2")).alias("total_salary")])? - .build()?; - - let plan = state.optimize(&plan)?; - let physical_plan = state.create_physical_plan(&Arc::new(plan)).await?; - assert_eq!("c1", physical_plan.schema().field(0).name().as_str()); - assert_eq!( - "total_salary", - physical_plan.schema().field(1).name().as_str() - ); - Ok(()) -} - #[tokio::test] async fn test_accumulator_row_accumulator() -> Result<()> { let config = SessionConfig::new(); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); register_aggregate_csv(&ctx).await?; let sql = "SELECT c1, c2, MIN(c13) as min1, MIN(c9) as min2, MAX(c13) as max1, MAX(c9) as max2, AVG(c9) as avg1, MIN(c13) as min3, COUNT(C9) as cnt1, 0.5*SUM(c9-c8) as sum1 diff --git a/datafusion/core/tests/sql/arrow_files.rs b/datafusion/core/tests/sql/arrow_files.rs deleted file mode 100644 index fc90fe3c3464..000000000000 --- a/datafusion/core/tests/sql/arrow_files.rs +++ /dev/null @@ -1,70 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. -use datafusion::execution::options::ArrowReadOptions; - -use super::*; - -async fn register_arrow(ctx: &mut SessionContext) { - ctx.register_arrow( - "arrow_simple", - "tests/data/example.arrow", - ArrowReadOptions::default(), - ) - .await - .unwrap(); -} - -#[tokio::test] -async fn arrow_query() { - let mut ctx = SessionContext::new(); - register_arrow(&mut ctx).await; - let sql = "SELECT * FROM arrow_simple"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = [ - "+----+-----+-------+", - "| f0 | f1 | f2 |", - "+----+-----+-------+", - "| 1 | foo | true |", - "| 2 | bar | |", - "| 3 | baz | false |", - "| 4 | | true |", - "+----+-----+-------+", - ]; - - assert_batches_eq!(expected, &actual); -} - -#[tokio::test] -async fn arrow_explain() { - let mut ctx = SessionContext::new(); - register_arrow(&mut ctx).await; - let sql = "EXPLAIN SELECT * FROM arrow_simple"; - let actual = execute(&ctx, sql).await; - let actual = normalize_vec_for_explain(actual); - let expected = vec![ - vec![ - "logical_plan", - "TableScan: arrow_simple projection=[f0, f1, f2]", - ], - vec![ - "physical_plan", - "ArrowExec: file_groups={1 group: [[WORKING_DIR/tests/data/example.arrow]]}, projection=[f0, f1, f2]\n", - ], - ]; - - assert_eq!(expected, actual); -} diff --git a/datafusion/core/tests/sql/create_drop.rs b/datafusion/core/tests/sql/create_drop.rs index aa34552044d4..b1434dddee50 100644 --- a/datafusion/core/tests/sql/create_drop.rs +++ b/datafusion/core/tests/sql/create_drop.rs @@ -26,11 +26,11 @@ async fn create_custom_table() -> Result<()> { let cfg = RuntimeConfig::new(); let env = RuntimeEnv::new(cfg).unwrap(); let ses = SessionConfig::new(); - let mut state = SessionState::with_config_rt(ses, Arc::new(env)); + let mut state = SessionState::new_with_config_rt(ses, Arc::new(env)); state .table_factories_mut() .insert("DELTATABLE".to_string(), Arc::new(TestTableFactory {})); - let ctx = SessionContext::with_state(state); + let ctx = SessionContext::new_with_state(state); let sql = "CREATE EXTERNAL TABLE dt STORED AS DELTATABLE LOCATION 's3://bucket/schema/table';"; ctx.sql(sql).await.unwrap(); @@ -48,11 +48,11 @@ async fn create_external_table_with_ddl() -> Result<()> { let cfg = RuntimeConfig::new(); let env = RuntimeEnv::new(cfg).unwrap(); let ses = SessionConfig::new(); - let mut state = SessionState::with_config_rt(ses, Arc::new(env)); + let mut state = SessionState::new_with_config_rt(ses, Arc::new(env)); state .table_factories_mut() .insert("MOCKTABLE".to_string(), Arc::new(TestTableFactory {})); - let ctx = SessionContext::with_state(state); + let ctx = SessionContext::new_with_state(state); let sql = "CREATE EXTERNAL TABLE dt (a_id integer, a_str string, a_bool boolean) STORED AS MOCKTABLE LOCATION 'mockprotocol://path/to/table';"; ctx.sql(sql).await.unwrap(); diff --git a/datafusion/core/tests/sql/describe.rs b/datafusion/core/tests/sql/describe.rs deleted file mode 100644 index cd8e79b2c93b..000000000000 --- a/datafusion/core/tests/sql/describe.rs +++ /dev/null @@ -1,72 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use datafusion::assert_batches_eq; -use datafusion::prelude::*; -use datafusion_common::test_util::parquet_test_data; - -#[tokio::test] -async fn describe_plan() { - let ctx = parquet_context().await; - - let query = "describe alltypes_tiny_pages"; - let results = ctx.sql(query).await.unwrap().collect().await.unwrap(); - - let expected = vec![ - "+-----------------+-----------------------------+-------------+", - "| column_name | data_type | is_nullable |", - "+-----------------+-----------------------------+-------------+", - "| id | Int32 | YES |", - "| bool_col | Boolean | YES |", - "| tinyint_col | Int8 | YES |", - "| smallint_col | Int16 | YES |", - "| int_col | Int32 | YES |", - "| bigint_col | Int64 | YES |", - "| float_col | Float32 | YES |", - "| double_col | Float64 | YES |", - "| date_string_col | Utf8 | YES |", - "| string_col | Utf8 | YES |", - "| timestamp_col | Timestamp(Nanosecond, None) | YES |", - "| year | Int32 | YES |", - "| month | Int32 | YES |", - "+-----------------+-----------------------------+-------------+", - ]; - - assert_batches_eq!(expected, &results); - - // also ensure we plan Describe via SessionState - let state = ctx.state(); - let plan = state.create_logical_plan(query).await.unwrap(); - let df = DataFrame::new(state, plan); - let results = df.collect().await.unwrap(); - - assert_batches_eq!(expected, &results); -} - -/// Return a SessionContext with parquet file registered -async fn parquet_context() -> SessionContext { - let ctx = SessionContext::new(); - let testdata = parquet_test_data(); - ctx.register_parquet( - "alltypes_tiny_pages", - &format!("{testdata}/alltypes_tiny_pages.parquet"), - ParquetReadOptions::default(), - ) - .await - .unwrap(); - ctx -} diff --git a/datafusion/core/tests/sql/displayable.rs b/datafusion/core/tests/sql/displayable.rs deleted file mode 100644 index b736820009cc..000000000000 --- a/datafusion/core/tests/sql/displayable.rs +++ /dev/null @@ -1,57 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use object_store::path::Path; - -use datafusion::prelude::*; -use datafusion_physical_plan::displayable; - -#[tokio::test] -async fn teset_displayable() { - // Hard code target_partitions as it appears in the RepartitionExec output - let config = SessionConfig::new().with_target_partitions(3); - let ctx = SessionContext::with_config(config); - - // register the a table - ctx.register_csv("example", "tests/data/example.csv", CsvReadOptions::new()) - .await - .unwrap(); - - // create a plan to run a SQL query - let dataframe = ctx.sql("SELECT a FROM example WHERE a < 5").await.unwrap(); - let physical_plan = dataframe.create_physical_plan().await.unwrap(); - - // Format using display string in verbose mode - let displayable_plan = displayable(physical_plan.as_ref()); - let plan_string = format!("{}", displayable_plan.indent(true)); - - let working_directory = std::env::current_dir().unwrap(); - let normalized = Path::from_filesystem_path(working_directory).unwrap(); - let plan_string = plan_string.replace(normalized.as_ref(), "WORKING_DIR"); - - assert_eq!("CoalesceBatchesExec: target_batch_size=8192\ - \n FilterExec: a@0 < 5\ - \n RepartitionExec: partitioning=RoundRobinBatch(3), input_partitions=1\ - \n CsvExec: file_groups={1 group: [[WORKING_DIR/tests/data/example.csv]]}, projection=[a], has_header=true", - plan_string.trim()); - - let one_line = format!("{}", displayable_plan.one_line()); - assert_eq!( - "CoalesceBatchesExec: target_batch_size=8192", - one_line.trim() - ); -} diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index 06120c01ce86..37f8cefc9080 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -16,6 +16,7 @@ // under the License. use super::*; + use datafusion::config::ConfigOptions; use datafusion::physical_plan::display::DisplayableExecutionPlan; use datafusion::physical_plan::metrics::Timestamp; @@ -27,7 +28,7 @@ async fn explain_analyze_baseline_metrics() { let config = SessionConfig::new() .with_target_partitions(3) .with_batch_size(4096); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); register_aggregate_csv_by_sql(&ctx).await; // a query with as many operators as we have metrics for let sql = "EXPLAIN ANALYZE \ @@ -559,7 +560,7 @@ async fn csv_explain_verbose_plans() { // Since the plan contains path that are environmentally // dependant(e.g. full path of the test file), only verify // important content - assert_contains!(&actual, "logical_plan after push_down_projection"); + assert_contains!(&actual, "logical_plan after optimize_projections"); assert_contains!(&actual, "physical_plan"); assert_contains!(&actual, "FilterExec: c2@1 > 10"); assert_contains!(actual, "ProjectionExec: expr=[c1@0 as c1]"); @@ -574,7 +575,7 @@ async fn explain_analyze_runs_optimizers() { // This happens as an optimization pass where count(*) can be // answered using statistics only. - let expected = "EmptyExec: produce_one_row=true"; + let expected = "PlaceholderRowExec"; let sql = "EXPLAIN SELECT count(*) from alltypes_plain"; let actual = execute_to_batches(&ctx, sql).await; @@ -598,7 +599,7 @@ async fn test_physical_plan_display_indent() { let config = SessionConfig::new() .with_target_partitions(9000) .with_batch_size(4096); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); register_aggregate_csv(&ctx).await.unwrap(); let sql = "SELECT c1, MAX(c12), MIN(c12) as the_min \ FROM aggregate_test_100 \ @@ -611,7 +612,7 @@ async fn test_physical_plan_display_indent() { let expected = vec![ "GlobalLimitExec: skip=0, fetch=10", " SortPreservingMergeExec: [the_min@2 DESC], fetch=10", - " SortExec: fetch=10, expr=[the_min@2 DESC]", + " SortExec: TopK(fetch=10), expr=[the_min@2 DESC]", " ProjectionExec: expr=[c1@0 as c1, MAX(aggregate_test_100.c12)@1 as MAX(aggregate_test_100.c12), MIN(aggregate_test_100.c12)@2 as the_min]", " AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[MAX(aggregate_test_100.c12), MIN(aggregate_test_100.c12)]", " CoalesceBatchesExec: target_batch_size=4096", @@ -642,7 +643,7 @@ async fn test_physical_plan_display_indent_multi_children() { let config = SessionConfig::new() .with_target_partitions(9000) .with_batch_size(4096); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); // ensure indenting works for nodes with multiple children register_aggregate_csv(&ctx).await.unwrap(); let sql = "SELECT c1 \ @@ -777,7 +778,7 @@ async fn csv_explain_analyze_verbose() { async fn explain_logical_plan_only() { let mut config = ConfigOptions::new(); config.explain.logical_plan_only = true; - let ctx = SessionContext::with_config(config.into()); + let ctx = SessionContext::new_with_config(config.into()); let sql = "EXPLAIN select count(*) from (values ('a', 1, 100), ('a', 2, 150)) as t (c1,c2,c3)"; let actual = execute(&ctx, sql).await; let actual = normalize_vec_for_explain(actual); @@ -787,7 +788,7 @@ async fn explain_logical_plan_only() { "logical_plan", "Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]]\ \n SubqueryAlias: t\ - \n Projection: column1\ + \n Projection: \ \n Values: (Utf8(\"a\"), Int64(1), Int64(100)), (Utf8(\"a\"), Int64(2), Int64(150))" ]]; assert_eq!(expected, actual); @@ -797,15 +798,15 @@ async fn explain_logical_plan_only() { async fn explain_physical_plan_only() { let mut config = ConfigOptions::new(); config.explain.physical_plan_only = true; - let ctx = SessionContext::with_config(config.into()); + let ctx = SessionContext::new_with_config(config.into()); let sql = "EXPLAIN select count(*) from (values ('a', 1, 100), ('a', 2, 150)) as t (c1,c2,c3)"; let actual = execute(&ctx, sql).await; let actual = normalize_vec_for_explain(actual); let expected = vec![vec![ "physical_plan", - "ProjectionExec: expr=[2 as COUNT(UInt8(1))]\ - \n EmptyExec: produce_one_row=true\ + "ProjectionExec: expr=[2 as COUNT(*)]\ + \n PlaceholderRowExec\ \n", ]]; assert_eq!(expected, actual); @@ -816,7 +817,7 @@ async fn csv_explain_analyze_with_statistics() { let mut config = ConfigOptions::new(); config.explain.physical_plan_only = true; config.explain.show_statistics = true; - let ctx = SessionContext::with_config(config.into()); + let ctx = SessionContext::new_with_config(config.into()); register_aggregate_csv_by_sql(&ctx).await; let sql = "EXPLAIN ANALYZE SELECT c1 FROM aggregate_test_100"; @@ -826,5 +827,8 @@ async fn csv_explain_analyze_with_statistics() { .to_string(); // should contain scan statistics - assert_contains!(&formatted, ", statistics=[]"); + assert_contains!( + &formatted, + ", statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:)]]" + ); } diff --git a/datafusion/core/tests/sql/expr.rs b/datafusion/core/tests/sql/expr.rs index 044b3b57ea90..e8a3d27c089a 100644 --- a/datafusion/core/tests/sql/expr.rs +++ b/datafusion/core/tests/sql/expr.rs @@ -19,55 +19,6 @@ use datafusion::datasource::empty::EmptyTable; use super::*; -#[tokio::test] -async fn test_boolean_expressions() -> Result<()> { - test_expression!("true", "true"); - test_expression!("false", "false"); - test_expression!("false = false", "true"); - test_expression!("true = false", "false"); - Ok(()) -} - -#[tokio::test] -async fn test_mathematical_expressions_with_null() -> Result<()> { - test_expression!("sqrt(NULL)", "NULL"); - test_expression!("cbrt(NULL)", "NULL"); - test_expression!("sin(NULL)", "NULL"); - test_expression!("cos(NULL)", "NULL"); - test_expression!("tan(NULL)", "NULL"); - test_expression!("asin(NULL)", "NULL"); - test_expression!("acos(NULL)", "NULL"); - test_expression!("atan(NULL)", "NULL"); - test_expression!("sinh(NULL)", "NULL"); - test_expression!("cosh(NULL)", "NULL"); - test_expression!("tanh(NULL)", "NULL"); - test_expression!("asinh(NULL)", "NULL"); - test_expression!("acosh(NULL)", "NULL"); - test_expression!("atanh(NULL)", "NULL"); - test_expression!("floor(NULL)", "NULL"); - test_expression!("ceil(NULL)", "NULL"); - test_expression!("round(NULL)", "NULL"); - test_expression!("trunc(NULL)", "NULL"); - test_expression!("abs(NULL)", "NULL"); - test_expression!("signum(NULL)", "NULL"); - test_expression!("exp(NULL)", "NULL"); - test_expression!("ln(NULL)", "NULL"); - test_expression!("log2(NULL)", "NULL"); - test_expression!("log10(NULL)", "NULL"); - test_expression!("power(NULL, 2)", "NULL"); - test_expression!("power(NULL, NULL)", "NULL"); - test_expression!("power(2, NULL)", "NULL"); - test_expression!("atan2(NULL, NULL)", "NULL"); - test_expression!("atan2(1, NULL)", "NULL"); - test_expression!("atan2(NULL, 1)", "NULL"); - test_expression!("nanvl(NULL, NULL)", "NULL"); - test_expression!("nanvl(1, NULL)", "NULL"); - test_expression!("nanvl(NULL, 1)", "NULL"); - test_expression!("isnan(NULL)", "NULL"); - test_expression!("iszero(NULL)", "NULL"); - Ok(()) -} - #[tokio::test] #[cfg_attr(not(feature = "crypto_expressions"), ignore)] async fn test_encoding_expressions() -> Result<()> { @@ -128,14 +79,6 @@ async fn test_encoding_expressions() -> Result<()> { Ok(()) } -#[should_panic(expected = "Invalid timezone \\\"Foo\\\": 'Foo' is not a valid timezone")] -#[tokio::test] -async fn test_array_cast_invalid_timezone_will_panic() { - let ctx = SessionContext::new(); - let sql = "SELECT arrow_cast('2021-01-02T03:04:00', 'Timestamp(Nanosecond, Some(\"Foo\"))')"; - execute(&ctx, sql).await; -} - #[tokio::test] #[cfg_attr(not(feature = "crypto_expressions"), ignore)] async fn test_crypto_expressions() -> Result<()> { @@ -212,242 +155,6 @@ async fn test_crypto_expressions() -> Result<()> { Ok(()) } -#[tokio::test] -async fn test_array_index() -> Result<()> { - // By default PostgreSQL uses a one-based numbering convention for arrays, that is, an array of n elements starts with array[1] and ends with array[n] - test_expression!("([5,4,3,2,1])[1]", "5"); - test_expression!("([5,4,3,2,1])[2]", "4"); - test_expression!("([5,4,3,2,1])[5]", "1"); - test_expression!("([[1, 2], [2, 3], [3,4]])[1]", "[1, 2]"); - test_expression!("([[1, 2], [2, 3], [3,4]])[3]", "[3, 4]"); - test_expression!("([[1, 2], [2, 3], [3,4]])[1][1]", "1"); - test_expression!("([[1, 2], [2, 3], [3,4]])[2][2]", "3"); - test_expression!("([[1, 2], [2, 3], [3,4]])[3][2]", "4"); - // out of bounds - test_expression!("([5,4,3,2,1])[0]", "NULL"); - test_expression!("([5,4,3,2,1])[6]", "NULL"); - // test_expression!("([5,4,3,2,1])[-1]", "NULL"); - test_expression!("([5,4,3,2,1])[100]", "NULL"); - - Ok(()) -} - -#[tokio::test] -async fn test_array_literals() -> Result<()> { - // Named, just another syntax - test_expression!("ARRAY[1,2,3,4,5]", "[1, 2, 3, 4, 5]"); - // Unnamed variant - test_expression!("[1,2,3,4,5]", "[1, 2, 3, 4, 5]"); - test_expression!("[true, false]", "[true, false]"); - test_expression!("['str1', 'str2']", "[str1, str2]"); - test_expression!("[[1,2], [3,4]]", "[[1, 2], [3, 4]]"); - - // TODO: Not supported in parser, uncomment when it will be available - // test_expression!( - // "[]", - // "[]" - // ); - - Ok(()) -} - -#[tokio::test] -async fn test_struct_literals() -> Result<()> { - test_expression!("STRUCT(1,2,3,4,5)", "{c0: 1, c1: 2, c2: 3, c3: 4, c4: 5}"); - test_expression!("STRUCT(Null)", "{c0: }"); - test_expression!("STRUCT(2)", "{c0: 2}"); - test_expression!("STRUCT('1',Null)", "{c0: 1, c1: }"); - test_expression!("STRUCT(true, false)", "{c0: true, c1: false}"); - test_expression!("STRUCT('str1', 'str2')", "{c0: str1, c1: str2}"); - - Ok(()) -} - -#[tokio::test] -async fn binary_bitwise_shift() -> Result<()> { - test_expression!("2 << 10", "2048"); - test_expression!("2048 >> 10", "2"); - test_expression!("2048 << NULL", "NULL"); - test_expression!("2048 >> NULL", "NULL"); - - Ok(()) -} - -#[tokio::test] -async fn test_interval_expressions() -> Result<()> { - // day nano intervals - test_expression!( - "interval '1'", - "0 years 0 mons 0 days 0 hours 0 mins 1.000000000 secs" - ); - test_expression!( - "interval '1 second'", - "0 years 0 mons 0 days 0 hours 0 mins 1.000000000 secs" - ); - test_expression!( - "interval '500 milliseconds'", - "0 years 0 mons 0 days 0 hours 0 mins 0.500000000 secs" - ); - test_expression!( - "interval '5 second'", - "0 years 0 mons 0 days 0 hours 0 mins 5.000000000 secs" - ); - test_expression!( - "interval '0.5 minute'", - "0 years 0 mons 0 days 0 hours 0 mins 30.000000000 secs" - ); - // https://github.com/apache/arrow-rs/issues/4424 - // test_expression!( - // "interval '.5 minute'", - // "0 years 0 mons 0 days 0 hours 0 mins 30.000000000 secs" - // ); - test_expression!( - "interval '5 minute'", - "0 years 0 mons 0 days 0 hours 5 mins 0.000000000 secs" - ); - test_expression!( - "interval '5 minute 1 second'", - "0 years 0 mons 0 days 0 hours 5 mins 1.000000000 secs" - ); - test_expression!( - "interval '1 hour'", - "0 years 0 mons 0 days 1 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '5 hour'", - "0 years 0 mons 0 days 5 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '1 day'", - "0 years 0 mons 1 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '1 week'", - "0 years 0 mons 7 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '2 weeks'", - "0 years 0 mons 14 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '1 day 1'", - "0 years 0 mons 1 days 0 hours 0 mins 1.000000000 secs" - ); - test_expression!( - "interval '0.5'", - "0 years 0 mons 0 days 0 hours 0 mins 0.500000000 secs" - ); - test_expression!( - "interval '0.5 day 1'", - "0 years 0 mons 0 days 12 hours 0 mins 1.000000000 secs" - ); - test_expression!( - "interval '0.49 day'", - "0 years 0 mons 0 days 11 hours 45 mins 36.000000000 secs" - ); - test_expression!( - "interval '0.499 day'", - "0 years 0 mons 0 days 11 hours 58 mins 33.600000000 secs" - ); - test_expression!( - "interval '0.4999 day'", - "0 years 0 mons 0 days 11 hours 59 mins 51.360000000 secs" - ); - test_expression!( - "interval '0.49999 day'", - "0 years 0 mons 0 days 11 hours 59 mins 59.136000000 secs" - ); - test_expression!( - "interval '0.49999999999 day'", - "0 years 0 mons 0 days 11 hours 59 mins 59.999999136 secs" - ); - test_expression!( - "interval '5 day'", - "0 years 0 mons 5 days 0 hours 0 mins 0.000000000 secs" - ); - // Hour is ignored, this matches PostgreSQL - test_expression!( - "interval '5 day' hour", - "0 years 0 mons 5 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '5 day 4 hours 3 minutes 2 seconds 100 milliseconds'", - "0 years 0 mons 5 days 4 hours 3 mins 2.100000000 secs" - ); - // month intervals - test_expression!( - "interval '0.5 month'", - "0 years 0 mons 15 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '0.5' month", - "0 years 0 mons 15 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '1 month'", - "0 years 1 mons 0 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '1' MONTH", - "0 years 1 mons 0 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '5 month'", - "0 years 5 mons 0 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '13 month'", - "0 years 13 mons 0 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '0.5 year'", - "0 years 6 mons 0 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '1 year'", - "0 years 12 mons 0 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '1 decade'", - "0 years 120 mons 0 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '2 decades'", - "0 years 240 mons 0 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '1 century'", - "0 years 1200 mons 0 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '2 year'", - "0 years 24 mons 0 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '2' year", - "0 years 24 mons 0 days 0 hours 0 mins 0.000000000 secs" - ); - // complex - test_expression!( - "interval '1 year 1 day'", - "0 years 12 mons 1 days 0 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '1 year 1 day 1 hour'", - "0 years 12 mons 1 days 1 hours 0 mins 0.000000000 secs" - ); - test_expression!( - "interval '1 year 1 day 1 hour 1 minute'", - "0 years 12 mons 1 days 1 hours 1 mins 0.000000000 secs" - ); - test_expression!( - "interval '1 year 1 day 1 hour 1 minute 1 second'", - "0 years 12 mons 1 days 1 hours 1 mins 1.000000000 secs" - ); - - Ok(()) -} - #[cfg(feature = "unicode_expressions")] #[tokio::test] async fn test_substring_expr() -> Result<()> { @@ -458,108 +165,6 @@ async fn test_substring_expr() -> Result<()> { Ok(()) } -/// Test string expressions test split into two batches -/// to prevent stack overflow error -#[tokio::test] -async fn test_string_expressions_batch1() -> Result<()> { - test_expression!("ascii('')", "0"); - test_expression!("ascii('x')", "120"); - test_expression!("ascii(NULL)", "NULL"); - test_expression!("bit_length('')", "0"); - test_expression!("bit_length('chars')", "40"); - test_expression!("bit_length('josé')", "40"); - test_expression!("bit_length(NULL)", "NULL"); - test_expression!("btrim(' xyxtrimyyx ', NULL)", "NULL"); - test_expression!("btrim(' xyxtrimyyx ')", "xyxtrimyyx"); - test_expression!("btrim('\n xyxtrimyyx \n')", "\n xyxtrimyyx \n"); - test_expression!("btrim('xyxtrimyyx', 'xyz')", "trim"); - test_expression!("btrim('\nxyxtrimyyx\n', 'xyz\n')", "trim"); - test_expression!("btrim(NULL, 'xyz')", "NULL"); - test_expression!("chr(CAST(120 AS int))", "x"); - test_expression!("chr(CAST(128175 AS int))", "💯"); - test_expression!("chr(CAST(NULL AS int))", "NULL"); - test_expression!("concat('a','b','c')", "abc"); - test_expression!("concat('abcde', 2, NULL, 22)", "abcde222"); - test_expression!("concat(NULL)", ""); - test_expression!("concat_ws(',', 'abcde', 2, NULL, 22)", "abcde,2,22"); - test_expression!("concat_ws('|','a','b','c')", "a|b|c"); - test_expression!("concat_ws('|',NULL)", ""); - test_expression!("concat_ws(NULL,'a',NULL,'b','c')", "NULL"); - test_expression!("concat_ws('|','a',NULL)", "a"); - test_expression!("concat_ws('|','a',NULL,NULL)", "a"); - test_expression!("initcap('')", ""); - test_expression!("initcap('hi THOMAS')", "Hi Thomas"); - test_expression!("initcap(NULL)", "NULL"); - test_expression!("lower('')", ""); - test_expression!("lower('TOM')", "tom"); - test_expression!("lower(NULL)", "NULL"); - test_expression!("ltrim(' zzzytest ', NULL)", "NULL"); - test_expression!("ltrim(' zzzytest ')", "zzzytest "); - test_expression!("ltrim('zzzytest', 'xyz')", "test"); - test_expression!("ltrim(NULL, 'xyz')", "NULL"); - test_expression!("octet_length('')", "0"); - test_expression!("octet_length('chars')", "5"); - test_expression!("octet_length('josé')", "5"); - test_expression!("octet_length(NULL)", "NULL"); - test_expression!("repeat('Pg', 4)", "PgPgPgPg"); - test_expression!("repeat('Pg', CAST(NULL AS INT))", "NULL"); - test_expression!("repeat(NULL, 4)", "NULL"); - test_expression!("replace('abcdefabcdef', 'cd', 'XX')", "abXXefabXXef"); - test_expression!("replace('abcdefabcdef', 'cd', NULL)", "NULL"); - test_expression!("replace('abcdefabcdef', 'notmatch', 'XX')", "abcdefabcdef"); - test_expression!("replace('abcdefabcdef', NULL, 'XX')", "NULL"); - test_expression!("replace(NULL, 'cd', 'XX')", "NULL"); - test_expression!("rtrim(' testxxzx ')", " testxxzx"); - test_expression!("rtrim(' zzzytest ', NULL)", "NULL"); - test_expression!("rtrim('testxxzx', 'xyz')", "test"); - test_expression!("rtrim(NULL, 'xyz')", "NULL"); - Ok(()) -} - -/// Test string expressions test split into two batches -/// to prevent stack overflow error -#[tokio::test] -async fn test_string_expressions_batch2() -> Result<()> { - test_expression!("split_part('abc~@~def~@~ghi', '~@~', 2)", "def"); - test_expression!("split_part('abc~@~def~@~ghi', '~@~', 20)", ""); - test_expression!("split_part(NULL, '~@~', 20)", "NULL"); - test_expression!("split_part('abc~@~def~@~ghi', NULL, 20)", "NULL"); - test_expression!( - "split_part('abc~@~def~@~ghi', '~@~', CAST(NULL AS INT))", - "NULL" - ); - test_expression!("starts_with('alphabet', 'alph')", "true"); - test_expression!("starts_with('alphabet', 'blph')", "false"); - test_expression!("starts_with(NULL, 'blph')", "NULL"); - test_expression!("starts_with('alphabet', NULL)", "NULL"); - test_expression!("to_hex(2147483647)", "7fffffff"); - test_expression!("to_hex(9223372036854775807)", "7fffffffffffffff"); - test_expression!("to_hex(CAST(NULL AS int))", "NULL"); - test_expression!("trim(' tom ')", "tom"); - test_expression!("trim(LEADING ' tom ')", "tom "); - test_expression!("trim(TRAILING ' tom ')", " tom"); - test_expression!("trim(BOTH ' tom ')", "tom"); - test_expression!("trim(LEADING ' ' FROM ' tom ')", "tom "); - test_expression!("trim(TRAILING ' ' FROM ' tom ')", " tom"); - test_expression!("trim(BOTH ' ' FROM ' tom ')", "tom"); - test_expression!("trim(' ' FROM ' tom ')", "tom"); - test_expression!("trim(LEADING 'x' FROM 'xxxtomxxx')", "tomxxx"); - test_expression!("trim(TRAILING 'x' FROM 'xxxtomxxx')", "xxxtom"); - test_expression!("trim(BOTH 'x' FROM 'xxxtomxx')", "tom"); - test_expression!("trim('x' FROM 'xxxtomxx')", "tom"); - test_expression!("trim(LEADING 'xy' FROM 'xyxabcxyzdefxyx')", "abcxyzdefxyx"); - test_expression!("trim(TRAILING 'xy' FROM 'xyxabcxyzdefxyx')", "xyxabcxyzdef"); - test_expression!("trim(BOTH 'xy' FROM 'xyxabcxyzdefxyx')", "abcxyzdef"); - test_expression!("trim('xy' FROM 'xyxabcxyzdefxyx')", "abcxyzdef"); - test_expression!("trim(' tom')", "tom"); - test_expression!("trim('')", ""); - test_expression!("trim('tom ')", "tom"); - test_expression!("upper('')", ""); - test_expression!("upper('tom')", "TOM"); - test_expression!("upper(NULL)", "NULL"); - Ok(()) -} - #[tokio::test] #[cfg_attr(not(feature = "regex_expressions"), ignore)] async fn test_regex_expressions() -> Result<()> { @@ -593,276 +198,6 @@ async fn test_regex_expressions() -> Result<()> { Ok(()) } -#[tokio::test] -async fn test_cast_expressions() -> Result<()> { - test_expression!("CAST('0' AS INT)", "0"); - test_expression!("CAST(NULL AS INT)", "NULL"); - test_expression!("TRY_CAST('0' AS INT)", "0"); - test_expression!("TRY_CAST('x' AS INT)", "NULL"); - Ok(()) -} - -#[tokio::test] -#[ignore] -// issue: https://github.com/apache/arrow-datafusion/issues/6596 -async fn test_array_cast_expressions() -> Result<()> { - test_expression!("CAST([1,2,3,4] AS INT[])", "[1, 2, 3, 4]"); - test_expression!( - "CAST([1,2,3,4] AS NUMERIC(10,4)[])", - "[1.0000, 2.0000, 3.0000, 4.0000]" - ); - Ok(()) -} - -#[tokio::test] -async fn test_random_expression() -> Result<()> { - let ctx = create_ctx(); - let sql = "SELECT random() r1"; - let actual = execute(&ctx, sql).await; - let r1 = actual[0][0].parse::().unwrap(); - assert!(0.0 <= r1); - assert!(r1 < 1.0); - Ok(()) -} - -#[tokio::test] -async fn test_uuid_expression() -> Result<()> { - let ctx = create_ctx(); - let sql = "SELECT uuid()"; - let actual = execute(&ctx, sql).await; - let uuid = actual[0][0].parse::().unwrap(); - assert_eq!(uuid.get_version_num(), 4); - Ok(()) -} - -#[tokio::test] -async fn test_extract_date_part() -> Result<()> { - test_expression!("date_part('YEAR', CAST('2000-01-01' AS DATE))", "2000.0"); - test_expression!( - "EXTRACT(year FROM to_timestamp('2020-09-08T12:00:00+00:00'))", - "2020.0" - ); - test_expression!("date_part('QUARTER', CAST('2000-01-01' AS DATE))", "1.0"); - test_expression!( - "EXTRACT(quarter FROM to_timestamp('2020-09-08T12:00:00+00:00'))", - "3.0" - ); - test_expression!("date_part('MONTH', CAST('2000-01-01' AS DATE))", "1.0"); - test_expression!( - "EXTRACT(month FROM to_timestamp('2020-09-08T12:00:00+00:00'))", - "9.0" - ); - test_expression!("date_part('WEEK', CAST('2003-01-01' AS DATE))", "1.0"); - test_expression!( - "EXTRACT(WEEK FROM to_timestamp('2020-09-08T12:00:00+00:00'))", - "37.0" - ); - test_expression!("date_part('DAY', CAST('2000-01-01' AS DATE))", "1.0"); - test_expression!( - "EXTRACT(day FROM to_timestamp('2020-09-08T12:00:00+00:00'))", - "8.0" - ); - test_expression!("date_part('DOY', CAST('2000-01-01' AS DATE))", "1.0"); - test_expression!( - "EXTRACT(doy FROM to_timestamp('2020-09-08T12:00:00+00:00'))", - "252.0" - ); - test_expression!("date_part('DOW', CAST('2000-01-01' AS DATE))", "6.0"); - test_expression!( - "EXTRACT(dow FROM to_timestamp('2020-09-08T12:00:00+00:00'))", - "2.0" - ); - test_expression!("date_part('HOUR', CAST('2000-01-01' AS DATE))", "0.0"); - test_expression!( - "EXTRACT(hour FROM to_timestamp('2020-09-08T12:03:03+00:00'))", - "12.0" - ); - test_expression!( - "EXTRACT(minute FROM to_timestamp('2020-09-08T12:12:00+00:00'))", - "12.0" - ); - test_expression!( - "date_part('minute', to_timestamp('2020-09-08T12:12:00+00:00'))", - "12.0" - ); - test_expression!( - "EXTRACT(second FROM to_timestamp('2020-09-08T12:00:12.12345678+00:00'))", - "12.12345678" - ); - test_expression!( - "EXTRACT(millisecond FROM to_timestamp('2020-09-08T12:00:12.12345678+00:00'))", - "12123.45678" - ); - test_expression!( - "EXTRACT(microsecond FROM to_timestamp('2020-09-08T12:00:12.12345678+00:00'))", - "12123456.78" - ); - test_expression!( - "EXTRACT(nanosecond FROM to_timestamp('2020-09-08T12:00:12.12345678+00:00'))", - "1.212345678e10" - ); - test_expression!( - "date_part('second', to_timestamp('2020-09-08T12:00:12.12345678+00:00'))", - "12.12345678" - ); - test_expression!( - "date_part('millisecond', to_timestamp('2020-09-08T12:00:12.12345678+00:00'))", - "12123.45678" - ); - test_expression!( - "date_part('microsecond', to_timestamp('2020-09-08T12:00:12.12345678+00:00'))", - "12123456.78" - ); - test_expression!( - "date_part('nanosecond', to_timestamp('2020-09-08T12:00:12.12345678+00:00'))", - "1.212345678e10" - ); - Ok(()) -} - -#[tokio::test] -async fn test_extract_epoch() -> Result<()> { - test_expression!( - "extract(epoch from '1870-01-01T07:29:10.256'::timestamp)", - "-3155646649.744" - ); - test_expression!( - "extract(epoch from '2000-01-01T00:00:00.000'::timestamp)", - "946684800.0" - ); - test_expression!( - "extract(epoch from to_timestamp('2000-01-01T00:00:00+00:00'))", - "946684800.0" - ); - test_expression!("extract(epoch from NULL::timestamp)", "NULL"); - Ok(()) -} - -#[tokio::test] -async fn test_extract_date_part_func() -> Result<()> { - test_expression!( - format!( - "(date_part('{0}', now()) = EXTRACT({0} FROM now()))", - "year" - ), - "true" - ); - test_expression!( - format!( - "(date_part('{0}', now()) = EXTRACT({0} FROM now()))", - "quarter" - ), - "true" - ); - test_expression!( - format!( - "(date_part('{0}', now()) = EXTRACT({0} FROM now()))", - "month" - ), - "true" - ); - test_expression!( - format!( - "(date_part('{0}', now()) = EXTRACT({0} FROM now()))", - "week" - ), - "true" - ); - test_expression!( - format!("(date_part('{0}', now()) = EXTRACT({0} FROM now()))", "day"), - "true" - ); - test_expression!( - format!( - "(date_part('{0}', now()) = EXTRACT({0} FROM now()))", - "hour" - ), - "true" - ); - test_expression!( - format!( - "(date_part('{0}', now()) = EXTRACT({0} FROM now()))", - "minute" - ), - "true" - ); - test_expression!( - format!( - "(date_part('{0}', now()) = EXTRACT({0} FROM now()))", - "second" - ), - "true" - ); - test_expression!( - format!( - "(date_part('{0}', now()) = EXTRACT({0} FROM now()))", - "millisecond" - ), - "true" - ); - test_expression!( - format!( - "(date_part('{0}', now()) = EXTRACT({0} FROM now()))", - "microsecond" - ), - "true" - ); - test_expression!( - format!( - "(date_part('{0}', now()) = EXTRACT({0} FROM now()))", - "nanosecond" - ), - "true" - ); - - Ok(()) -} - -#[tokio::test] -async fn test_in_list_scalar() -> Result<()> { - test_expression!("'a' IN ('a','b')", "true"); - test_expression!("'c' IN ('a','b')", "false"); - test_expression!("'c' NOT IN ('a','b')", "true"); - test_expression!("'a' NOT IN ('a','b')", "false"); - test_expression!("NULL IN ('a','b')", "NULL"); - test_expression!("NULL NOT IN ('a','b')", "NULL"); - test_expression!("'a' IN ('a','b',NULL)", "true"); - test_expression!("'c' IN ('a','b',NULL)", "NULL"); - test_expression!("'a' NOT IN ('a','b',NULL)", "false"); - test_expression!("'c' NOT IN ('a','b',NULL)", "NULL"); - test_expression!("0 IN (0,1,2)", "true"); - test_expression!("3 IN (0,1,2)", "false"); - test_expression!("3 NOT IN (0,1,2)", "true"); - test_expression!("0 NOT IN (0,1,2)", "false"); - test_expression!("NULL IN (0,1,2)", "NULL"); - test_expression!("NULL NOT IN (0,1,2)", "NULL"); - test_expression!("0 IN (0,1,2,NULL)", "true"); - test_expression!("3 IN (0,1,2,NULL)", "NULL"); - test_expression!("0 NOT IN (0,1,2,NULL)", "false"); - test_expression!("3 NOT IN (0,1,2,NULL)", "NULL"); - test_expression!("0.0 IN (0.0,0.1,0.2)", "true"); - test_expression!("0.3 IN (0.0,0.1,0.2)", "false"); - test_expression!("0.3 NOT IN (0.0,0.1,0.2)", "true"); - test_expression!("0.0 NOT IN (0.0,0.1,0.2)", "false"); - test_expression!("NULL IN (0.0,0.1,0.2)", "NULL"); - test_expression!("NULL NOT IN (0.0,0.1,0.2)", "NULL"); - test_expression!("0.0 IN (0.0,0.1,0.2,NULL)", "true"); - test_expression!("0.3 IN (0.0,0.1,0.2,NULL)", "NULL"); - test_expression!("0.0 NOT IN (0.0,0.1,0.2,NULL)", "false"); - test_expression!("0.3 NOT IN (0.0,0.1,0.2,NULL)", "NULL"); - test_expression!("'1' IN ('a','b',1)", "true"); - test_expression!("'2' IN ('a','b',1)", "false"); - test_expression!("'2' NOT IN ('a','b',1)", "true"); - test_expression!("'1' NOT IN ('a','b',1)", "false"); - test_expression!("NULL IN ('a','b',1)", "NULL"); - test_expression!("NULL NOT IN ('a','b',1)", "NULL"); - test_expression!("'1' IN ('a','b',NULL,1)", "true"); - test_expression!("'2' IN ('a','b',NULL,1)", "NULL"); - test_expression!("'1' NOT IN ('a','b',NULL,1)", "false"); - test_expression!("'2' NOT IN ('a','b',NULL,1)", "NULL"); - Ok(()) -} - #[tokio::test] async fn csv_query_nullif_divide_by_0() -> Result<()> { let ctx = SessionContext::new(); @@ -886,18 +221,6 @@ async fn csv_query_nullif_divide_by_0() -> Result<()> { Ok(()) } -#[tokio::test] -async fn csv_query_avg_sqrt() -> Result<()> { - let ctx = create_ctx(); - register_aggregate_csv(&ctx).await?; - let sql = "SELECT avg(custom_sqrt(c12)) FROM aggregate_test_100"; - let mut actual = execute(&ctx, sql).await; - actual.sort(); - let expected = vec![vec!["0.6706002946036462"]]; - assert_float_eq(&expected, &actual); - Ok(()) -} - #[tokio::test] async fn nested_subquery() -> Result<()> { let ctx = SessionContext::new(); diff --git a/datafusion/core/tests/sql/group_by.rs b/datafusion/core/tests/sql/group_by.rs index 862d2275afc2..58f0ac21d951 100644 --- a/datafusion/core/tests/sql/group_by.rs +++ b/datafusion/core/tests/sql/group_by.rs @@ -82,7 +82,7 @@ async fn group_by_limit() -> Result<()> { let physical_plan = dataframe.create_physical_plan().await?; let mut expected_physical_plan = r#" GlobalLimitExec: skip=0, fetch=4 - SortExec: fetch=4, expr=[MAX(traces.ts)@1 DESC] + SortExec: TopK(fetch=4), expr=[MAX(traces.ts)@1 DESC] AggregateExec: mode=Single, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.ts)], lim=[4] "#.trim().to_string(); let actual_phys_plan = @@ -149,7 +149,7 @@ async fn create_groupby_context(tmp_dir: &TempDir) -> Result { } let cfg = SessionConfig::new().with_target_partitions(1); - let ctx = SessionContext::with_config(cfg); + let ctx = SessionContext::new_with_config(cfg); ctx.register_csv( "traces", tmp_dir.path().to_str().unwrap(), @@ -231,13 +231,13 @@ async fn group_by_dictionary() { .expect("ran plan correctly"); let expected = [ - "+-------+------------------------+", - "| t.val | COUNT(DISTINCT t.dict) |", - "+-------+------------------------+", - "| 1 | 2 |", - "| 2 | 2 |", - "| 4 | 1 |", - "+-------+------------------------+", + "+-----+------------------------+", + "| val | COUNT(DISTINCT t.dict) |", + "+-----+------------------------+", + "| 1 | 2 |", + "| 2 | 2 |", + "| 4 | 1 |", + "+-----+------------------------+", ]; assert_batches_sorted_eq!(expected, &results); } diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index d08f09d3b6e1..d1f270b540b5 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use datafusion::datasource::stream::{StreamConfig, StreamTable}; use datafusion::test_util::register_unbounded_file_with_ordering; use super::*; @@ -81,7 +82,7 @@ async fn null_aware_left_anti_join() -> Result<()> { #[tokio::test] async fn join_change_in_planner() -> Result<()> { let config = SessionConfig::new().with_target_partitions(8); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); let tmp_dir = TempDir::new().unwrap(); let left_file_path = tmp_dir.path().join("left.csv"); File::create(left_file_path.clone()).unwrap(); @@ -105,9 +106,7 @@ async fn join_change_in_planner() -> Result<()> { &left_file_path, "left", file_sort_order.clone(), - true, - ) - .await?; + )?; let right_file_path = tmp_dir.path().join("right.csv"); File::create(right_file_path.clone()).unwrap(); register_unbounded_file_with_ordering( @@ -116,9 +115,7 @@ async fn join_change_in_planner() -> Result<()> { &right_file_path, "right", file_sort_order, - true, - ) - .await?; + )?; let sql = "SELECT t1.a1, t1.a2, t2.a1, t2.a2 FROM left as t1 FULL JOIN right as t2 ON t1.a2 = t2.a2 AND t1.a1 > t2.a1 + 3 AND t1.a1 < t2.a1 + 10"; let dataframe = ctx.sql(sql).await?; let physical_plan = dataframe.create_physical_plan().await?; @@ -152,7 +149,7 @@ async fn join_change_in_planner() -> Result<()> { #[tokio::test] async fn join_change_in_planner_without_sort() -> Result<()> { let config = SessionConfig::new().with_target_partitions(8); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); let tmp_dir = TempDir::new()?; let left_file_path = tmp_dir.path().join("left.csv"); File::create(left_file_path.clone())?; @@ -160,20 +157,13 @@ async fn join_change_in_planner_without_sort() -> Result<()> { Field::new("a1", DataType::UInt32, false), Field::new("a2", DataType::UInt32, false), ])); - ctx.register_csv( - "left", - left_file_path.as_os_str().to_str().unwrap(), - CsvReadOptions::new().schema(&schema).mark_infinite(true), - ) - .await?; + let left = StreamConfig::new_file(schema.clone(), left_file_path); + ctx.register_table("left", Arc::new(StreamTable::new(Arc::new(left))))?; + let right_file_path = tmp_dir.path().join("right.csv"); File::create(right_file_path.clone())?; - ctx.register_csv( - "right", - right_file_path.as_os_str().to_str().unwrap(), - CsvReadOptions::new().schema(&schema).mark_infinite(true), - ) - .await?; + let right = StreamConfig::new_file(schema, right_file_path); + ctx.register_table("right", Arc::new(StreamTable::new(Arc::new(right))))?; let sql = "SELECT t1.a1, t1.a2, t2.a1, t2.a2 FROM left as t1 FULL JOIN right as t2 ON t1.a2 = t2.a2 AND t1.a1 > t2.a1 + 3 AND t1.a1 < t2.a1 + 10"; let dataframe = ctx.sql(sql).await?; let physical_plan = dataframe.create_physical_plan().await?; @@ -209,7 +199,7 @@ async fn join_change_in_planner_without_sort_not_allowed() -> Result<()> { let config = SessionConfig::new() .with_target_partitions(8) .with_allow_symmetric_joins_without_pruning(false); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); let tmp_dir = TempDir::new()?; let left_file_path = tmp_dir.path().join("left.csv"); File::create(left_file_path.clone())?; @@ -217,20 +207,12 @@ async fn join_change_in_planner_without_sort_not_allowed() -> Result<()> { Field::new("a1", DataType::UInt32, false), Field::new("a2", DataType::UInt32, false), ])); - ctx.register_csv( - "left", - left_file_path.as_os_str().to_str().unwrap(), - CsvReadOptions::new().schema(&schema).mark_infinite(true), - ) - .await?; + let left = StreamConfig::new_file(schema.clone(), left_file_path); + ctx.register_table("left", Arc::new(StreamTable::new(Arc::new(left))))?; let right_file_path = tmp_dir.path().join("right.csv"); File::create(right_file_path.clone())?; - ctx.register_csv( - "right", - right_file_path.as_os_str().to_str().unwrap(), - CsvReadOptions::new().schema(&schema).mark_infinite(true), - ) - .await?; + let right = StreamConfig::new_file(schema.clone(), right_file_path); + ctx.register_table("right", Arc::new(StreamTable::new(Arc::new(right))))?; let df = ctx.sql("SELECT t1.a1, t1.a2, t2.a1, t2.a2 FROM left as t1 FULL JOIN right as t2 ON t1.a2 = t2.a2 AND t1.a1 > t2.a1 + 3 AND t1.a1 < t2.a1 + 10").await?; match df.create_physical_plan().await { Ok(_) => panic!("Expecting error."), diff --git a/datafusion/core/tests/sql/limit.rs b/datafusion/core/tests/sql/limit.rs deleted file mode 100644 index 1c8ea4fd3468..000000000000 --- a/datafusion/core/tests/sql/limit.rs +++ /dev/null @@ -1,101 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use super::*; - -#[tokio::test] -async fn limit() -> Result<()> { - let tmp_dir = TempDir::new()?; - let ctx = create_ctx_with_partition(&tmp_dir, 1).await?; - ctx.register_table("t", table_with_sequence(1, 1000).unwrap()) - .unwrap(); - - let results = plan_and_collect(&ctx, "SELECT i FROM t ORDER BY i DESC limit 3") - .await - .unwrap(); - - #[rustfmt::skip] - let expected = ["+------+", - "| i |", - "+------+", - "| 1000 |", - "| 999 |", - "| 998 |", - "+------+"]; - - assert_batches_eq!(expected, &results); - - let results = plan_and_collect(&ctx, "SELECT i FROM t ORDER BY i limit 3") - .await - .unwrap(); - - #[rustfmt::skip] - let expected = ["+---+", - "| i |", - "+---+", - "| 1 |", - "| 2 |", - "| 3 |", - "+---+"]; - - assert_batches_eq!(expected, &results); - - let results = plan_and_collect(&ctx, "SELECT i FROM t limit 3") - .await - .unwrap(); - - // the actual rows are not guaranteed, so only check the count (should be 3) - let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum(); - assert_eq!(num_rows, 3); - - Ok(()) -} - -#[tokio::test] -async fn limit_multi_partitions() -> Result<()> { - let tmp_dir = TempDir::new()?; - let ctx = create_ctx_with_partition(&tmp_dir, 1).await?; - - let partitions = vec![ - vec![make_partition(0)], - vec![make_partition(1)], - vec![make_partition(2)], - vec![make_partition(3)], - vec![make_partition(4)], - vec![make_partition(5)], - ]; - let schema = partitions[0][0].schema(); - let provider = Arc::new(MemTable::try_new(schema, partitions).unwrap()); - - ctx.register_table("t", provider).unwrap(); - - // select all rows - let results = plan_and_collect(&ctx, "SELECT i FROM t").await.unwrap(); - - let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum(); - assert_eq!(num_rows, 15); - - for limit in 1..10 { - let query = format!("SELECT i FROM t limit {limit}"); - let results = plan_and_collect(&ctx, &query).await.unwrap(); - - let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum(); - assert_eq!(num_rows, limit, "mismatch with query {query}"); - } - - Ok(()) -} diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 7d175b65260f..849d85dec6bf 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use std::convert::TryFrom; use std::sync::Arc; use arrow::{ @@ -26,6 +25,7 @@ use chrono::prelude::*; use chrono::Duration; use datafusion::datasource::TableProvider; +use datafusion::error::{DataFusionError, Result}; use datafusion::logical_expr::{Aggregate, LogicalPlan, TableScan}; use datafusion::physical_plan::metrics::MetricValue; use datafusion::physical_plan::ExecutionPlan; @@ -34,15 +34,9 @@ use datafusion::prelude::*; use datafusion::test_util; use datafusion::{assert_batches_eq, assert_batches_sorted_eq}; use datafusion::{datasource::MemTable, physical_plan::collect}; -use datafusion::{ - error::{DataFusionError, Result}, - physical_plan::ColumnarValue, -}; use datafusion::{execution::context::SessionContext, physical_plan::displayable}; -use datafusion_common::cast::as_float64_array; use datafusion_common::plan_err; use datafusion_common::{assert_contains, assert_not_contains}; -use datafusion_expr::Volatility; use object_store::path::Path; use std::fs::File; use std::io::Write; @@ -78,84 +72,27 @@ macro_rules! test_expression { } pub mod aggregates; -pub mod arrow_files; -#[cfg(feature = "avro")] pub mod create_drop; pub mod csv_files; -pub mod describe; -pub mod displayable; pub mod explain_analyze; pub mod expr; pub mod group_by; pub mod joins; -pub mod limit; pub mod order; -pub mod parquet; -pub mod parquet_schema; pub mod partitioned_csv; pub mod predicates; -pub mod projection; pub mod references; pub mod repartition; pub mod select; mod sql_api; -pub mod subqueries; pub mod timestamp; -pub mod udf; - -fn assert_float_eq(expected: &[Vec], received: &[Vec]) -where - T: AsRef, -{ - expected - .iter() - .flatten() - .zip(received.iter().flatten()) - .for_each(|(l, r)| { - let (l, r) = ( - l.as_ref().parse::().unwrap(), - r.as_str().parse::().unwrap(), - ); - if l.is_nan() || r.is_nan() { - assert!(l.is_nan() && r.is_nan()); - } else if (l - r).abs() > 2.0 * f64::EPSILON { - panic!("{l} != {r}") - } - }); -} - -fn create_ctx() -> SessionContext { - let ctx = SessionContext::new(); - - // register a custom UDF - ctx.register_udf(create_udf( - "custom_sqrt", - vec![DataType::Float64], - Arc::new(DataType::Float64), - Volatility::Immutable, - Arc::new(custom_sqrt), - )); - - ctx -} - -fn custom_sqrt(args: &[ColumnarValue]) -> Result { - let arg = &args[0]; - if let ColumnarValue::Array(v) = arg { - let input = as_float64_array(v).expect("cast failed"); - let array: Float64Array = input.iter().map(|v| v.map(|x| x.sqrt())).collect(); - Ok(ColumnarValue::Array(Arc::new(array))) - } else { - unimplemented!() - } -} fn create_join_context( column_left: &str, column_right: &str, repartition_joins: bool, ) -> Result { - let ctx = SessionContext::with_config( + let ctx = SessionContext::new_with_config( SessionConfig::new() .with_repartition_joins(repartition_joins) .with_target_partitions(2) @@ -210,7 +147,7 @@ fn create_left_semi_anti_join_context_with_null_ids( column_right: &str, repartition_joins: bool, ) -> Result { - let ctx = SessionContext::with_config( + let ctx = SessionContext::new_with_config( SessionConfig::new() .with_repartition_joins(repartition_joins) .with_target_partitions(2) @@ -512,23 +449,6 @@ async fn register_aggregate_csv_by_sql(ctx: &SessionContext) { ); } -async fn register_aggregate_simple_csv(ctx: &SessionContext) -> Result<()> { - // It's not possible to use aggregate_test_100 as it doesn't have enough similar values to test grouping on floats. - let schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::Float32, false), - Field::new("c2", DataType::Float64, false), - Field::new("c3", DataType::Boolean, false), - ])); - - ctx.register_csv( - "aggregate_simple", - "tests/data/aggregate_simple.csv", - CsvReadOptions::new().schema(&schema), - ) - .await?; - Ok(()) -} - async fn register_aggregate_csv(ctx: &SessionContext) -> Result<()> { let testdata = datafusion::test_util::arrow_test_data(); let schema = test_util::aggr_test_schema(); @@ -577,7 +497,8 @@ async fn create_ctx_with_partition( tmp_dir: &TempDir, partition_count: usize, ) -> Result { - let ctx = SessionContext::with_config(SessionConfig::new().with_target_partitions(8)); + let ctx = + SessionContext::new_with_config(SessionConfig::new().with_target_partitions(8)); let schema = populate_csv_partitions(tmp_dir, partition_count, ".csv")?; @@ -621,18 +542,6 @@ fn populate_csv_partitions( Ok(schema) } -/// Return a RecordBatch with a single Int32 array with values (0..sz) -pub fn make_partition(sz: i32) -> RecordBatch { - let seq_start = 0; - let seq_end = sz; - let values = (seq_start..seq_end).collect::>(); - let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)])); - let arr = Arc::new(Int32Array::from(values)); - let arr = arr as ArrayRef; - - RecordBatch::try_new(schema, vec![arr]).unwrap() -} - /// Specialised String representation fn col_str(column: &ArrayRef, row_index: usize) -> String { // NullArray::is_null() does not work on NullArray. diff --git a/datafusion/core/tests/sql/order.rs b/datafusion/core/tests/sql/order.rs index c5497b4cc0f9..6e3f6319e119 100644 --- a/datafusion/core/tests/sql/order.rs +++ b/datafusion/core/tests/sql/order.rs @@ -180,7 +180,7 @@ async fn test_issue5970_mini() -> Result<()> { let config = SessionConfig::new() .with_target_partitions(2) .with_repartition_sorts(true); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); let sql = " WITH m0(t) AS ( @@ -209,16 +209,16 @@ ORDER BY 1, 2; " AggregateExec: mode=FinalPartitioned, gby=[Int64(0)@0 as Int64(0), t@1 as t], aggr=[]", " CoalesceBatchesExec: target_batch_size=8192", " RepartitionExec: partitioning=Hash([Int64(0)@0, t@1], 2), input_partitions=2", - " AggregateExec: mode=Partial, gby=[0 as Int64(0), t@0 as t], aggr=[]", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", + " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", + " AggregateExec: mode=Partial, gby=[0 as Int64(0), t@0 as t], aggr=[]", " ProjectionExec: expr=[column1@0 as t]", " ValuesExec", " ProjectionExec: expr=[Int64(1)@0 as m, t@1 as t]", " AggregateExec: mode=FinalPartitioned, gby=[Int64(1)@0 as Int64(1), t@1 as t], aggr=[]", " CoalesceBatchesExec: target_batch_size=8192", " RepartitionExec: partitioning=Hash([Int64(1)@0, t@1], 2), input_partitions=2", - " AggregateExec: mode=Partial, gby=[1 as Int64(1), t@0 as t], aggr=[]", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", + " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", + " AggregateExec: mode=Partial, gby=[1 as Int64(1), t@0 as t], aggr=[]", " ProjectionExec: expr=[column1@0 as t]", " ValuesExec", ]; diff --git a/datafusion/core/tests/sql/parquet.rs b/datafusion/core/tests/sql/parquet.rs deleted file mode 100644 index c2844a2b762a..000000000000 --- a/datafusion/core/tests/sql/parquet.rs +++ /dev/null @@ -1,383 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::{fs, path::Path}; - -use ::parquet::arrow::ArrowWriter; -use datafusion::{datasource::listing::ListingOptions, execution::options::ReadOptions}; -use datafusion_common::cast::{as_list_array, as_primitive_array, as_string_array}; -use tempfile::TempDir; - -use super::*; - -#[tokio::test] -async fn parquet_query() { - let ctx = SessionContext::new(); - register_alltypes_parquet(&ctx).await; - // NOTE that string_col is actually a binary column and does not have the UTF8 logical type - // so we need an explicit cast - let sql = "SELECT id, CAST(string_col AS varchar) FROM alltypes_plain"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = [ - "+----+---------------------------+", - "| id | alltypes_plain.string_col |", - "+----+---------------------------+", - "| 4 | 0 |", - "| 5 | 1 |", - "| 6 | 0 |", - "| 7 | 1 |", - "| 2 | 0 |", - "| 3 | 1 |", - "| 0 | 0 |", - "| 1 | 1 |", - "+----+---------------------------+", - ]; - - assert_batches_eq!(expected, &actual); -} - -#[tokio::test] -/// Test that if sort order is specified in ListingOptions, the sort -/// expressions make it all the way down to the ParquetExec -async fn parquet_with_sort_order_specified() { - let parquet_read_options = ParquetReadOptions::default(); - let session_config = SessionConfig::new().with_target_partitions(2); - - // The sort order is not specified - let options_no_sort = parquet_read_options.to_listing_options(&session_config); - - // The sort order is specified (not actually correct in this case) - let file_sort_order = [col("string_col"), col("int_col")] - .into_iter() - .map(|e| { - let ascending = true; - let nulls_first = false; - e.sort(ascending, nulls_first) - }) - .collect::>(); - - let options_sort = parquet_read_options - .to_listing_options(&session_config) - .with_file_sort_order(vec![file_sort_order]); - - // This string appears in ParquetExec if the output ordering is - // specified - let expected_output_ordering = - "output_ordering=[string_col@1 ASC NULLS LAST, int_col@0 ASC NULLS LAST]"; - - // when sort not specified, should not appear in the explain plan - let num_files = 1; - assert_not_contains!( - run_query_with_options(options_no_sort, num_files).await, - expected_output_ordering - ); - - // when sort IS specified, SHOULD appear in the explain plan - let num_files = 1; - assert_contains!( - run_query_with_options(options_sort.clone(), num_files).await, - expected_output_ordering - ); - - // when sort IS specified, but there are too many files (greater - // than the number of partitions) sort should not appear - let num_files = 3; - assert_not_contains!( - run_query_with_options(options_sort, num_files).await, - expected_output_ordering - ); -} - -/// Runs a limit query against a parquet file that was registered from -/// options on num_files copies of all_types_plain.parquet -async fn run_query_with_options(options: ListingOptions, num_files: usize) -> String { - let ctx = SessionContext::new(); - - let testdata = datafusion::test_util::parquet_test_data(); - let file_path = format!("{testdata}/alltypes_plain.parquet"); - - // Create a directory of parquet files with names - // 0.parquet - // 1.parquet - let tmpdir = TempDir::new().unwrap(); - for i in 0..num_files { - let target_file = tmpdir.path().join(format!("{i}.parquet")); - println!("Copying {file_path} to {target_file:?}"); - std::fs::copy(&file_path, target_file).unwrap(); - } - - let provided_schema = None; - let sql_definition = None; - ctx.register_listing_table( - "t", - tmpdir.path().to_string_lossy(), - options.clone(), - provided_schema, - sql_definition, - ) - .await - .unwrap(); - - let batches = ctx.sql("explain select int_col, string_col from t order by string_col, int_col limit 10") - .await - .expect("planing worked") - .collect() - .await - .expect("execution worked"); - - arrow::util::pretty::pretty_format_batches(&batches) - .unwrap() - .to_string() -} - -#[tokio::test] -async fn fixed_size_binary_columns() { - let ctx = SessionContext::new(); - ctx.register_parquet( - "t0", - "tests/data/test_binary.parquet", - ParquetReadOptions::default(), - ) - .await - .unwrap(); - let sql = "SELECT ids FROM t0 ORDER BY ids"; - let dataframe = ctx.sql(sql).await.unwrap(); - let results = dataframe.collect().await.unwrap(); - for batch in results { - assert_eq!(466, batch.num_rows()); - assert_eq!(1, batch.num_columns()); - } -} - -#[tokio::test] -async fn window_fn_timestamp_tz() { - let ctx = SessionContext::new(); - ctx.register_parquet( - "t0", - "tests/data/timestamp_with_tz.parquet", - ParquetReadOptions::default(), - ) - .await - .unwrap(); - - let sql = "SELECT count, LAG(timestamp, 1) OVER (ORDER BY timestamp) FROM t0"; - let dataframe = ctx.sql(sql).await.unwrap(); - let results = dataframe.collect().await.unwrap(); - - let mut num_rows = 0; - for batch in results { - num_rows += batch.num_rows(); - assert_eq!(2, batch.num_columns()); - - let ty = batch.column(0).data_type().clone(); - assert_eq!(DataType::Int64, ty); - - let ty = batch.column(1).data_type().clone(); - assert_eq!( - DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())), - ty - ); - } - - assert_eq!(131072, num_rows); -} - -#[tokio::test] -async fn parquet_single_nan_schema() { - let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); - ctx.register_parquet( - "single_nan", - &format!("{testdata}/single_nan.parquet"), - ParquetReadOptions::default(), - ) - .await - .unwrap(); - let sql = "SELECT mycol FROM single_nan"; - let dataframe = ctx.sql(sql).await.unwrap(); - let results = dataframe.collect().await.unwrap(); - for batch in results { - assert_eq!(1, batch.num_rows()); - assert_eq!(1, batch.num_columns()); - } -} - -#[tokio::test] -#[ignore = "Test ignored, will be enabled as part of the nested Parquet reader"] -async fn parquet_list_columns() { - let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); - ctx.register_parquet( - "list_columns", - &format!("{testdata}/list_columns.parquet"), - ParquetReadOptions::default(), - ) - .await - .unwrap(); - - let schema = Arc::new(Schema::new(vec![ - Field::new_list( - "int64_list", - Field::new("item", DataType::Int64, true), - true, - ), - Field::new_list("utf8_list", Field::new("item", DataType::Utf8, true), true), - ])); - - let sql = "SELECT int64_list, utf8_list FROM list_columns"; - let dataframe = ctx.sql(sql).await.unwrap(); - let results = dataframe.collect().await.unwrap(); - - // int64_list utf8_list - // 0 [1, 2, 3] [abc, efg, hij] - // 1 [None, 1] None - // 2 [4] [efg, None, hij, xyz] - - assert_eq!(1, results.len()); - let batch = &results[0]; - assert_eq!(3, batch.num_rows()); - assert_eq!(2, batch.num_columns()); - assert_eq!(schema, batch.schema()); - - let int_list_array = as_list_array(batch.column(0)).unwrap(); - let utf8_list_array = as_list_array(batch.column(1)).unwrap(); - - assert_eq!( - as_primitive_array::(&int_list_array.value(0)).unwrap(), - &PrimitiveArray::::from(vec![Some(1), Some(2), Some(3),]) - ); - - assert_eq!( - as_string_array(&utf8_list_array.value(0)).unwrap(), - &StringArray::try_from(vec![Some("abc"), Some("efg"), Some("hij"),]).unwrap() - ); - - assert_eq!( - as_primitive_array::(&int_list_array.value(1)).unwrap(), - &PrimitiveArray::::from(vec![None, Some(1),]) - ); - - assert!(utf8_list_array.is_null(1)); - - assert_eq!( - as_primitive_array::(&int_list_array.value(2)).unwrap(), - &PrimitiveArray::::from(vec![Some(4),]) - ); - - let result = utf8_list_array.value(2); - let result = as_string_array(&result).unwrap(); - - assert_eq!(result.value(0), "efg"); - assert!(result.is_null(1)); - assert_eq!(result.value(2), "hij"); - assert_eq!(result.value(3), "xyz"); -} - -#[tokio::test] -async fn parquet_query_with_max_min() { - let tmp_dir = TempDir::new().unwrap(); - let table_dir = tmp_dir.path().join("parquet_test"); - let table_path = Path::new(&table_dir); - - let fields = vec![ - Field::new("c1", DataType::Int32, true), - Field::new("c2", DataType::Utf8, true), - Field::new("c3", DataType::Int64, true), - Field::new("c4", DataType::Date32, true), - ]; - - let schema = Arc::new(Schema::new(fields.clone())); - - if let Ok(()) = fs::create_dir(table_path) { - let filename = "foo.parquet"; - let path = table_path.join(filename); - let file = fs::File::create(path).unwrap(); - let mut writer = - ArrowWriter::try_new(file.try_clone().unwrap(), schema.clone(), None) - .unwrap(); - - // create mock record batch - let c1s = Arc::new(Int32Array::from(vec![1, 2, 3])); - let c2s = Arc::new(StringArray::from(vec!["aaa", "bbb", "ccc"])); - let c3s = Arc::new(Int64Array::from(vec![100, 200, 300])); - let c4s = Arc::new(Date32Array::from(vec![Some(1), Some(2), Some(3)])); - let rec_batch = - RecordBatch::try_new(schema.clone(), vec![c1s, c2s, c3s, c4s]).unwrap(); - - writer.write(&rec_batch).unwrap(); - writer.close().unwrap(); - } - - // query parquet - let ctx = SessionContext::new(); - - ctx.register_parquet( - "foo", - &format!("{}/foo.parquet", table_dir.to_str().unwrap()), - ParquetReadOptions::default(), - ) - .await - .unwrap(); - - let sql = "SELECT max(c1) FROM foo"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = [ - "+-------------+", - "| MAX(foo.c1) |", - "+-------------+", - "| 3 |", - "+-------------+", - ]; - - assert_batches_eq!(expected, &actual); - - let sql = "SELECT min(c2) FROM foo"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = [ - "+-------------+", - "| MIN(foo.c2) |", - "+-------------+", - "| aaa |", - "+-------------+", - ]; - - assert_batches_eq!(expected, &actual); - - let sql = "SELECT max(c3) FROM foo"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = [ - "+-------------+", - "| MAX(foo.c3) |", - "+-------------+", - "| 300 |", - "+-------------+", - ]; - - assert_batches_eq!(expected, &actual); - - let sql = "SELECT min(c4) FROM foo"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = [ - "+-------------+", - "| MIN(foo.c4) |", - "+-------------+", - "| 1970-01-02 |", - "+-------------+", - ]; - - assert_batches_eq!(expected, &actual); -} diff --git a/datafusion/core/tests/sql/partitioned_csv.rs b/datafusion/core/tests/sql/partitioned_csv.rs index 98cb3b189361..b77557a66cd8 100644 --- a/datafusion/core/tests/sql/partitioned_csv.rs +++ b/datafusion/core/tests/sql/partitioned_csv.rs @@ -19,31 +19,13 @@ use std::{io::Write, sync::Arc}; -use arrow::{ - datatypes::{DataType, Field, Schema, SchemaRef}, - record_batch::RecordBatch, -}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::{ error::Result, prelude::{CsvReadOptions, SessionConfig, SessionContext}, }; use tempfile::TempDir; -/// Execute SQL and return results -async fn plan_and_collect( - ctx: &mut SessionContext, - sql: &str, -) -> Result> { - ctx.sql(sql).await?.collect().await -} - -/// Execute SQL and return results -pub async fn execute(sql: &str, partition_count: usize) -> Result> { - let tmp_dir = TempDir::new()?; - let mut ctx = create_ctx(&tmp_dir, partition_count).await?; - plan_and_collect(&mut ctx, sql).await -} - /// Generate CSV partitions within the supplied directory fn populate_csv_partitions( tmp_dir: &TempDir, @@ -78,7 +60,8 @@ pub async fn create_ctx( tmp_dir: &TempDir, partition_count: usize, ) -> Result { - let ctx = SessionContext::with_config(SessionConfig::new().with_target_partitions(8)); + let ctx = + SessionContext::new_with_config(SessionConfig::new().with_target_partitions(8)); let schema = populate_csv_partitions(tmp_dir, partition_count, ".csv")?; diff --git a/datafusion/core/tests/sql/projection.rs b/datafusion/core/tests/sql/projection.rs deleted file mode 100644 index b31cb34f5210..000000000000 --- a/datafusion/core/tests/sql/projection.rs +++ /dev/null @@ -1,373 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use datafusion::datasource::provider_as_source; -use datafusion::test_util::scan_empty; -use datafusion_expr::{when, LogicalPlanBuilder, UNNAMED_TABLE}; -use tempfile::TempDir; - -use super::*; - -#[tokio::test] -async fn projection_same_fields() -> Result<()> { - let ctx = SessionContext::new(); - - let sql = "select (1+1) as a from (select 1 as a) as b;"; - let actual = execute_to_batches(&ctx, sql).await; - - #[rustfmt::skip] - let expected = ["+---+", - "| a |", - "+---+", - "| 2 |", - "+---+"]; - assert_batches_eq!(expected, &actual); - - Ok(()) -} - -#[tokio::test] -async fn projection_type_alias() -> Result<()> { - let ctx = SessionContext::new(); - register_aggregate_simple_csv(&ctx).await?; - - // Query that aliases one column to the name of a different column - // that also has a different type (c1 == float32, c3 == boolean) - let sql = "SELECT c1 as c3 FROM aggregate_simple ORDER BY c3 LIMIT 2"; - let actual = execute_to_batches(&ctx, sql).await; - - let expected = [ - "+---------+", - "| c3 |", - "+---------+", - "| 0.00001 |", - "| 0.00002 |", - "+---------+", - ]; - assert_batches_eq!(expected, &actual); - - Ok(()) -} - -#[tokio::test] -async fn csv_query_group_by_avg_with_projection() -> Result<()> { - let ctx = SessionContext::new(); - register_aggregate_csv(&ctx).await?; - let sql = "SELECT avg(c12), c1 FROM aggregate_test_100 GROUP BY c1"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = [ - "+-----------------------------+----+", - "| AVG(aggregate_test_100.c12) | c1 |", - "+-----------------------------+----+", - "| 0.41040709263815384 | b |", - "| 0.48600669271341534 | e |", - "| 0.48754517466109415 | a |", - "| 0.48855379387549824 | d |", - "| 0.6600456536439784 | c |", - "+-----------------------------+----+", - ]; - assert_batches_sorted_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn parallel_projection() -> Result<()> { - let partition_count = 4; - let results = - partitioned_csv::execute("SELECT c1, c2 FROM test", partition_count).await?; - - let expected = vec![ - "+----+----+", - "| c1 | c2 |", - "+----+----+", - "| 3 | 1 |", - "| 3 | 2 |", - "| 3 | 3 |", - "| 3 | 4 |", - "| 3 | 5 |", - "| 3 | 6 |", - "| 3 | 7 |", - "| 3 | 8 |", - "| 3 | 9 |", - "| 3 | 10 |", - "| 2 | 1 |", - "| 2 | 2 |", - "| 2 | 3 |", - "| 2 | 4 |", - "| 2 | 5 |", - "| 2 | 6 |", - "| 2 | 7 |", - "| 2 | 8 |", - "| 2 | 9 |", - "| 2 | 10 |", - "| 1 | 1 |", - "| 1 | 2 |", - "| 1 | 3 |", - "| 1 | 4 |", - "| 1 | 5 |", - "| 1 | 6 |", - "| 1 | 7 |", - "| 1 | 8 |", - "| 1 | 9 |", - "| 1 | 10 |", - "| 0 | 1 |", - "| 0 | 2 |", - "| 0 | 3 |", - "| 0 | 4 |", - "| 0 | 5 |", - "| 0 | 6 |", - "| 0 | 7 |", - "| 0 | 8 |", - "| 0 | 9 |", - "| 0 | 10 |", - "+----+----+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn subquery_alias_case_insensitive() -> Result<()> { - let partition_count = 1; - let results = - partitioned_csv::execute("SELECT V1.c1, v1.C2 FROM (SELECT test.C1, TEST.c2 FROM test) V1 ORDER BY v1.c1, V1.C2 LIMIT 1", partition_count).await?; - - let expected = [ - "+----+----+", - "| c1 | c2 |", - "+----+----+", - "| 0 | 1 |", - "+----+----+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn projection_on_table_scan() -> Result<()> { - let tmp_dir = TempDir::new()?; - let partition_count = 4; - let ctx = partitioned_csv::create_ctx(&tmp_dir, partition_count).await?; - - let table = ctx.table("test").await?; - let logical_plan = LogicalPlanBuilder::from(table.into_optimized_plan()?) - .project(vec![col("c2")])? - .build()?; - - let state = ctx.state(); - let optimized_plan = state.optimize(&logical_plan)?; - match &optimized_plan { - LogicalPlan::TableScan(TableScan { - source, - projected_schema, - .. - }) => { - assert_eq!(source.schema().fields().len(), 3); - assert_eq!(projected_schema.fields().len(), 1); - } - _ => panic!("input to projection should be TableScan"), - } - - let expected = "TableScan: test projection=[c2]"; - assert_eq!(format!("{optimized_plan:?}"), expected); - - let physical_plan = state.create_physical_plan(&optimized_plan).await?; - - assert_eq!(1, physical_plan.schema().fields().len()); - assert_eq!("c2", physical_plan.schema().field(0).name().as_str()); - let batches = collect(physical_plan, state.task_ctx()).await?; - assert_eq!(40, batches.iter().map(|x| x.num_rows()).sum::()); - - Ok(()) -} - -#[tokio::test] -async fn preserve_nullability_on_projection() -> Result<()> { - let tmp_dir = TempDir::new()?; - let ctx = partitioned_csv::create_ctx(&tmp_dir, 1).await?; - - let schema: Schema = ctx.table("test").await.unwrap().schema().clone().into(); - assert!(!schema.field_with_name("c1")?.is_nullable()); - - let plan = scan_empty(None, &schema, None)? - .project(vec![col("c1")])? - .build()?; - - let dataframe = DataFrame::new(ctx.state(), plan); - let physical_plan = dataframe.create_physical_plan().await?; - assert!(!physical_plan.schema().field_with_name("c1")?.is_nullable()); - Ok(()) -} - -#[tokio::test] -async fn project_cast_dictionary() { - let ctx = SessionContext::new(); - - let host: DictionaryArray = vec![Some("host1"), None, Some("host2")] - .into_iter() - .collect(); - - let batch = RecordBatch::try_from_iter(vec![("host", Arc::new(host) as _)]).unwrap(); - - let t = MemTable::try_new(batch.schema(), vec![vec![batch]]).unwrap(); - - // Note that `host` is a dictionary array but `lit("")` is a DataType::Utf8 that needs to be cast - let expr = when(col("host").is_null(), lit("")) - .otherwise(col("host")) - .unwrap(); - - let projection = None; - let builder = LogicalPlanBuilder::scan( - "cpu_load_short", - provider_as_source(Arc::new(t)), - projection, - ) - .unwrap(); - - let logical_plan = builder.project(vec![expr]).unwrap().build().unwrap(); - let df = DataFrame::new(ctx.state(), logical_plan); - let actual = df.collect().await.unwrap(); - - let expected = ["+----------------------------------------------------------------------------------+", - "| CASE WHEN cpu_load_short.host IS NULL THEN Utf8(\"\") ELSE cpu_load_short.host END |", - "+----------------------------------------------------------------------------------+", - "| host1 |", - "| |", - "| host2 |", - "+----------------------------------------------------------------------------------+"]; - assert_batches_eq!(expected, &actual); -} - -#[tokio::test] -async fn projection_on_memory_scan() -> Result<()> { - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Int32, false), - Field::new("c", DataType::Int32, false), - ]); - let schema = SchemaRef::new(schema); - - let partitions = vec![vec![RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from(vec![1, 10, 10, 100])), - Arc::new(Int32Array::from(vec![2, 12, 12, 120])), - Arc::new(Int32Array::from(vec![3, 12, 12, 120])), - ], - )?]]; - - let provider = Arc::new(MemTable::try_new(schema, partitions)?); - let plan = - LogicalPlanBuilder::scan(UNNAMED_TABLE, provider_as_source(provider), None)? - .project(vec![col("b")])? - .build()?; - assert_fields_eq(&plan, vec!["b"]); - - let ctx = SessionContext::new(); - let state = ctx.state(); - let optimized_plan = state.optimize(&plan)?; - match &optimized_plan { - LogicalPlan::TableScan(TableScan { - source, - projected_schema, - .. - }) => { - assert_eq!(source.schema().fields().len(), 3); - assert_eq!(projected_schema.fields().len(), 1); - } - _ => panic!("input to projection should be InMemoryScan"), - } - - let expected = format!("TableScan: {UNNAMED_TABLE} projection=[b]"); - assert_eq!(format!("{optimized_plan:?}"), expected); - - let physical_plan = state.create_physical_plan(&optimized_plan).await?; - - assert_eq!(1, physical_plan.schema().fields().len()); - assert_eq!("b", physical_plan.schema().field(0).name().as_str()); - - let batches = collect(physical_plan, state.task_ctx()).await?; - assert_eq!(1, batches.len()); - assert_eq!(1, batches[0].num_columns()); - assert_eq!(4, batches[0].num_rows()); - - Ok(()) -} - -fn assert_fields_eq(plan: &LogicalPlan, expected: Vec<&str>) { - let actual: Vec = plan - .schema() - .fields() - .iter() - .map(|f| f.name().clone()) - .collect(); - assert_eq!(actual, expected); -} - -#[tokio::test] -async fn project_column_with_same_name_as_relation() -> Result<()> { - let ctx = SessionContext::new(); - - let sql = "select a.a from (select 1 as a) as a;"; - let actual = execute_to_batches(&ctx, sql).await; - - let expected = ["+---+", "| a |", "+---+", "| 1 |", "+---+"]; - assert_batches_sorted_eq!(expected, &actual); - - Ok(()) -} - -#[tokio::test] -async fn project_column_with_filters_that_cant_pushed_down_always_false() -> Result<()> { - let ctx = SessionContext::new(); - - let sql = "select * from (select 1 as a) f where f.a=2;"; - let actual = execute_to_batches(&ctx, sql).await; - - let expected = ["++", "++"]; - assert_batches_sorted_eq!(expected, &actual); - - Ok(()) -} - -#[tokio::test] -async fn project_column_with_filters_that_cant_pushed_down_always_true() -> Result<()> { - let ctx = SessionContext::new(); - - let sql = "select * from (select 1 as a) f where f.a=1;"; - let actual = execute_to_batches(&ctx, sql).await; - - let expected = ["+---+", "| a |", "+---+", "| 1 |", "+---+"]; - assert_batches_sorted_eq!(expected, &actual); - - Ok(()) -} - -#[tokio::test] -async fn project_columns_in_memory_without_propagation() -> Result<()> { - let ctx = SessionContext::new(); - - let sql = "select column1 as a from (values (1), (2)) f where f.column1 = 2;"; - let actual = execute_to_batches(&ctx, sql).await; - - let expected = ["+---+", "| a |", "+---+", "| 2 |", "+---+"]; - assert_batches_sorted_eq!(expected, &actual); - - Ok(()) -} diff --git a/datafusion/core/tests/sql/repartition.rs b/datafusion/core/tests/sql/repartition.rs index 20e64b2eeefc..332f18e941aa 100644 --- a/datafusion/core/tests/sql/repartition.rs +++ b/datafusion/core/tests/sql/repartition.rs @@ -33,7 +33,7 @@ use std::sync::Arc; #[tokio::test] async fn unbounded_repartition() -> Result<()> { let config = SessionConfig::new(); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); let task = ctx.task_ctx(); let schema = Arc::new(Schema::new(vec![Field::new("a2", DataType::UInt32, false)])); let batch = RecordBatch::try_new( diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs index cda5fba8051e..cbdea9d72948 100644 --- a/datafusion/core/tests/sql/select.rs +++ b/datafusion/core/tests/sql/select.rs @@ -407,7 +407,8 @@ async fn sort_on_window_null_string() -> Result<()> { ]) .unwrap(); - let ctx = SessionContext::with_config(SessionConfig::new().with_target_partitions(1)); + let ctx = + SessionContext::new_with_config(SessionConfig::new().with_target_partitions(1)); ctx.register_batch("test", batch)?; let sql = @@ -524,6 +525,53 @@ async fn test_prepare_statement() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_named_query_parameters() -> Result<()> { + let tmp_dir = TempDir::new()?; + let partition_count = 4; + let ctx = partitioned_csv::create_ctx(&tmp_dir, partition_count).await?; + + // sql to statement then to logical plan with parameters + // c1 defined as UINT32, c2 defined as UInt64 + let results = ctx + .sql("SELECT c1, c2 FROM test WHERE c1 > $coo AND c1 < $foo") + .await? + .with_param_values(vec![ + ("foo", ScalarValue::UInt32(Some(3))), + ("coo", ScalarValue::UInt32(Some(0))), + ])? + .collect() + .await?; + let expected = vec![ + "+----+----+", + "| c1 | c2 |", + "+----+----+", + "| 1 | 1 |", + "| 1 | 2 |", + "| 1 | 3 |", + "| 1 | 4 |", + "| 1 | 5 |", + "| 1 | 6 |", + "| 1 | 7 |", + "| 1 | 8 |", + "| 1 | 9 |", + "| 1 | 10 |", + "| 2 | 1 |", + "| 2 | 2 |", + "| 2 | 3 |", + "| 2 | 4 |", + "| 2 | 5 |", + "| 2 | 6 |", + "| 2 | 7 |", + "| 2 | 8 |", + "| 2 | 9 |", + "| 2 | 10 |", + "+----+----+", + ]; + assert_batches_sorted_eq!(expected, &results); + Ok(()) +} + #[tokio::test] async fn parallel_query_with_filter() -> Result<()> { let tmp_dir = TempDir::new()?; @@ -590,7 +638,7 @@ async fn boolean_literal() -> Result<()> { #[tokio::test] async fn unprojected_filter() { let config = SessionConfig::new(); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); let df = ctx.read_table(table_with_sequence(1, 3).unwrap()).unwrap(); let df = df diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs deleted file mode 100644 index 01f8dd684b23..000000000000 --- a/datafusion/core/tests/sql/subqueries.rs +++ /dev/null @@ -1,63 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use super::*; -use crate::sql::execute_to_batches; - -#[tokio::test] -#[ignore] -async fn correlated_scalar_subquery_sum_agg_bug() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", true)?; - - let sql = "select t1.t1_int from t1 where (select sum(t2_int) is null from t2 where t1.t1_id = t2.t2_id)"; - - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - - let expected = vec![ - "Projection: t1.t1_int [t1_int:UInt32;N]", - " Inner Join: t1.t1_id = __scalar_sq_1.t2_id [t1_id:UInt32;N, t1_int:UInt32;N, t2_id:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_int] [t1_id:UInt32;N, t1_int:UInt32;N]", - " SubqueryAlias: __scalar_sq_1 [t2_id:UInt32;N]", - " Projection: t2.t2_id [t2_id:UInt32;N]", - " Filter: SUM(t2.t2_int) IS NULL [t2_id:UInt32;N, SUM(t2.t2_int):UInt64;N]", - " Aggregate: groupBy=[[t2.t2_id]], aggr=[[SUM(t2.t2_int)]] [t2_id:UInt32;N, SUM(t2.t2_int):UInt64;N]", - " TableScan: t2 projection=[t2_id, t2_int] [t2_id:UInt32;N, t2_int:UInt32;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - // assert data - let results = execute_to_batches(&ctx, sql).await; - let expected = [ - "+--------+", - "| t1_int |", - "+--------+", - "| 2 |", - "| 4 |", - "| 3 |", - "+--------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} diff --git a/datafusion/core/tests/sql/timestamp.rs b/datafusion/core/tests/sql/timestamp.rs index 09bbd1754276..ada66503a181 100644 --- a/datafusion/core/tests/sql/timestamp.rs +++ b/datafusion/core/tests/sql/timestamp.rs @@ -567,25 +567,30 @@ async fn timestamp_sub_interval_days() -> Result<()> { #[tokio::test] async fn timestamp_add_interval_months() -> Result<()> { let ctx = SessionContext::new(); + let table_a = + make_timestamp_tz_table::(Some("+00:00".into()))?; + ctx.register_table("table_a", table_a)?; - let sql = "SELECT NOW(), NOW() + INTERVAL '17' MONTH;"; + let sql = "SELECT ts, ts + INTERVAL '17' MONTH FROM table_a;"; let results = execute_to_batches(&ctx, sql).await; - let actual = result_vec(&results); + let actual_vec = result_vec(&results); - let res1 = actual[0][0].as_str(); - let res2 = actual[0][1].as_str(); + for actual in actual_vec { + let res1 = actual[0].as_str(); + let res2 = actual[1].as_str(); - let format = "%Y-%m-%dT%H:%M:%S%.6fZ"; - let t1_naive = NaiveDateTime::parse_from_str(res1, format).unwrap(); - let t2_naive = NaiveDateTime::parse_from_str(res2, format).unwrap(); + let format = "%Y-%m-%dT%H:%M:%S%.6fZ"; + let t1_naive = NaiveDateTime::parse_from_str(res1, format).unwrap(); + let t2_naive = NaiveDateTime::parse_from_str(res2, format).unwrap(); - let year = t1_naive.year() + (t1_naive.month0() as i32 + 17) / 12; - let month = (t1_naive.month0() + 17) % 12 + 1; + let year = t1_naive.year() + (t1_naive.month0() as i32 + 17) / 12; + let month = (t1_naive.month0() + 17) % 12 + 1; - assert_eq!( - t1_naive.with_year(year).unwrap().with_month(month).unwrap(), - t2_naive - ); + assert_eq!( + t1_naive.with_year(year).unwrap().with_month(month).unwrap(), + t2_naive + ); + } Ok(()) } diff --git a/datafusion/core/tests/tpcds_planning.rs b/datafusion/core/tests/tpcds_planning.rs index 3f55049ecd3c..4db97c75cb33 100644 --- a/datafusion/core/tests/tpcds_planning.rs +++ b/datafusion/core/tests/tpcds_planning.rs @@ -1045,7 +1045,7 @@ async fn regression_test(query_no: u8, create_physical: bool) -> Result<()> { let sql = fs::read_to_string(filename).expect("Could not read query"); let config = SessionConfig::default(); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); let tables = get_table_definitions(); for table in &tables { ctx.register_table( diff --git a/datafusion/core/tests/user_defined/mod.rs b/datafusion/core/tests/user_defined/mod.rs index ab6f51c47ba7..6c6d966cc3aa 100644 --- a/datafusion/core/tests/user_defined/mod.rs +++ b/datafusion/core/tests/user_defined/mod.rs @@ -15,6 +15,9 @@ // specific language governing permissions and limitations // under the License. +/// Tests for user defined Scalar functions +mod user_defined_scalar_functions; + /// Tests for User Defined Aggregate Functions mod user_defined_aggregates; @@ -23,3 +26,6 @@ mod user_defined_plan; /// Tests for User Defined Window Functions mod user_defined_window_functions; + +/// Tests for User Defined Table Functions +mod user_defined_table_functions; diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 64547bbdfa36..fb0ecd02c6b0 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -19,11 +19,14 @@ //! user defined aggregate functions use arrow::{array::AsArray, datatypes::Fields}; +use arrow_array::Int32Array; +use arrow_schema::Schema; use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, }; +use datafusion::datasource::MemTable; use datafusion::{ arrow::{ array::{ArrayRef, Float64Array, TimestampNanosecondArray}, @@ -43,6 +46,8 @@ use datafusion::{ use datafusion_common::{ assert_contains, cast::as_primitive_array, exec_err, DataFusionError, }; +use datafusion_expr::create_udaf; +use datafusion_physical_expr::expressions::AvgAccumulator; /// Test to show the contents of the setup #[tokio::test] @@ -169,10 +174,130 @@ async fn test_udaf_returning_struct_subquery() { assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap()); } +#[tokio::test] +async fn test_udaf_shadows_builtin_fn() { + let TestContext { + mut ctx, + test_state, + } = TestContext::new(); + let sql = "SELECT sum(arrow_cast(time, 'Int64')) from t"; + + // compute with builtin `sum` aggregator + let expected = [ + "+-------------+", + "| SUM(t.time) |", + "+-------------+", + "| 19000 |", + "+-------------+", + ]; + assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap()); + + // Register `TimeSum` with name `sum`. This will shadow the builtin one + let sql = "SELECT sum(time) from t"; + TimeSum::register(&mut ctx, test_state.clone(), "sum"); + let expected = [ + "+----------------------------+", + "| sum(t.time) |", + "+----------------------------+", + "| 1970-01-01T00:00:00.000019 |", + "+----------------------------+", + ]; + assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap()); +} + async fn execute(ctx: &SessionContext, sql: &str) -> Result> { ctx.sql(sql).await?.collect().await } +/// tests the creation, registration and usage of a UDAF +#[tokio::test] +async fn simple_udaf() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + + let batch1 = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + )?; + let batch2 = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(Int32Array::from(vec![4, 5]))], + )?; + + let ctx = SessionContext::new(); + + let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?; + ctx.register_table("t", Arc::new(provider))?; + + // define a udaf, using a DataFusion's accumulator + let my_avg = create_udaf( + "my_avg", + vec![DataType::Float64], + Arc::new(DataType::Float64), + Volatility::Immutable, + Arc::new(|_| Ok(Box::::default())), + Arc::new(vec![DataType::UInt64, DataType::Float64]), + ); + + ctx.register_udaf(my_avg); + + let result = ctx.sql("SELECT MY_AVG(a) FROM t").await?.collect().await?; + + let expected = [ + "+-------------+", + "| my_avg(t.a) |", + "+-------------+", + "| 3.0 |", + "+-------------+", + ]; + assert_batches_eq!(expected, &result); + + Ok(()) +} + +#[tokio::test] +async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> { + let ctx = SessionContext::new(); + let arr = Int32Array::from(vec![1]); + let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?; + ctx.register_batch("t", batch).unwrap(); + + // Note capitalization + let my_avg = create_udaf( + "MY_AVG", + vec![DataType::Float64], + Arc::new(DataType::Float64), + Volatility::Immutable, + Arc::new(|_| Ok(Box::::default())), + Arc::new(vec![DataType::UInt64, DataType::Float64]), + ); + + ctx.register_udaf(my_avg); + + // doesn't work as it was registered as non lowercase + let err = ctx.sql("SELECT MY_AVG(i) FROM t").await.unwrap_err(); + assert!(err + .to_string() + .contains("Error during planning: Invalid function \'my_avg\'")); + + // Can call it if you put quotes + let result = ctx + .sql("SELECT \"MY_AVG\"(i) FROM t") + .await? + .collect() + .await?; + + let expected = [ + "+-------------+", + "| MY_AVG(t.i) |", + "+-------------+", + "| 1.0 |", + "+-------------+", + ]; + assert_batches_eq!(expected, &result); + + Ok(()) +} + /// Returns an context with a table "t" and the "first" and "time_sum" /// aggregate functions registered. /// @@ -214,7 +339,7 @@ impl TestContext { // Tell DataFusion about the "first" function FirstSelector::register(&mut ctx); // Tell DataFusion about the "time_sum" function - TimeSum::register(&mut ctx, Arc::clone(&test_state)); + TimeSum::register(&mut ctx, Arc::clone(&test_state), "time_sum"); Self { ctx, test_state } } @@ -281,7 +406,7 @@ impl TimeSum { Self { sum: 0, test_state } } - fn register(ctx: &mut SessionContext, test_state: Arc) { + fn register(ctx: &mut SessionContext, test_state: Arc, name: &str) { let timestamp_type = DataType::Timestamp(TimeUnit::Nanosecond, None); // Returns the same type as its input @@ -301,8 +426,6 @@ impl TimeSum { let accumulator: AccumulatorFactoryFunction = Arc::new(move |_| Ok(Box::new(Self::new(Arc::clone(&captured_state))))); - let name = "time_sum"; - let time_sum = AggregateUDF::new(name, &signature, &return_type, &accumulator, &state_type); diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index 21ec20f0d4d6..29708c4422ca 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -58,7 +58,9 @@ //! N elements, reducing the total amount of required buffer memory. //! -use futures::{Stream, StreamExt}; +use std::fmt::Debug; +use std::task::{Context, Poll}; +use std::{any::Any, collections::BTreeMap, fmt, sync::Arc}; use arrow::{ array::{Int64Array, StringArray}, @@ -68,8 +70,7 @@ use arrow::{ }; use datafusion::{ common::cast::{as_int64_array, as_string_array}, - common::internal_err, - common::DFSchemaRef, + common::{internal_err, DFSchemaRef}, error::{DataFusionError, Result}, execution::{ context::{QueryPlanner, SessionState, TaskContext}, @@ -89,11 +90,9 @@ use datafusion::{ prelude::{SessionConfig, SessionContext}, }; -use fmt::Debug; -use std::task::{Context, Poll}; -use std::{any::Any, collections::BTreeMap, fmt, sync::Arc}; - use async_trait::async_trait; +use datafusion_common::arrow_datafusion_err; +use futures::{Stream, StreamExt}; /// Execute the specified sql and return the resulting record batches /// pretty printed as a String. @@ -101,7 +100,7 @@ async fn exec_sql(ctx: &mut SessionContext, sql: &str) -> Result { let df = ctx.sql(sql).await?; let batches = df.collect().await?; pretty_format_batches(&batches) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) .map(|d| d.to_string()) } @@ -247,10 +246,10 @@ async fn topk_plan() -> Result<()> { fn make_topk_context() -> SessionContext { let config = SessionConfig::new().with_target_partitions(48); let runtime = Arc::new(RuntimeEnv::default()); - let state = SessionState::with_config_rt(config, runtime) + let state = SessionState::new_with_config_rt(config, runtime) .with_query_planner(Arc::new(TopKQueryPlanner {})) .add_optimizer_rule(Arc::new(TopKOptimizerRule {})); - SessionContext::with_state(state) + SessionContext::new_with_state(state) } // ------ The implementation of the TopK code follows ----- @@ -490,10 +489,10 @@ impl ExecutionPlan for TopKExec { })) } - fn statistics(&self) -> Statistics { + fn statistics(&self) -> Result { // to improve the optimizability of this plan // better statistics inference could be provided - Statistics::default() + Ok(Statistics::new_unknown(&self.schema())) } } diff --git a/datafusion/core/tests/sql/udf.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs similarity index 63% rename from datafusion/core/tests/sql/udf.rs rename to datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 97512d0249c4..985b0bd5bc76 100644 --- a/datafusion/core/tests/sql/udf.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -15,26 +15,56 @@ // specific language governing permissions and limitations // under the License. -use super::*; use arrow::compute::kernels::numeric::add; +use arrow_array::{ArrayRef, Float64Array, Int32Array, RecordBatch}; +use arrow_schema::{DataType, Field, Schema}; +use datafusion::prelude::*; use datafusion::{ execution::registry::FunctionRegistry, - physical_plan::{expressions::AvgAccumulator, functions::make_scalar_function}, + physical_plan::functions::make_scalar_function, test_util, }; -use datafusion_common::{cast::as_int32_array, ScalarValue}; -use datafusion_expr::{create_udaf, Accumulator, LogicalPlanBuilder}; +use datafusion_common::cast::as_float64_array; +use datafusion_common::{assert_batches_eq, cast::as_int32_array, Result, ScalarValue}; +use datafusion_expr::{ + create_udaf, create_udf, Accumulator, ColumnarValue, LogicalPlanBuilder, Volatility, +}; +use std::sync::Arc; /// test that casting happens on udfs. /// c11 is f32, but `custom_sqrt` requires f64. Casting happens but the logical plan and /// physical plan have the same schema. #[tokio::test] async fn csv_query_custom_udf_with_cast() -> Result<()> { - let ctx = create_ctx(); + let ctx = create_udf_context(); register_aggregate_csv(&ctx).await?; let sql = "SELECT avg(custom_sqrt(c11)) FROM aggregate_test_100"; - let actual = execute(&ctx, sql).await; - let expected = vec![vec!["0.6584408483418833"]]; - assert_float_eq(&expected, &actual); + let actual = plan_and_collect(&ctx, sql).await.unwrap(); + let expected = [ + "+------------------------------------------+", + "| AVG(custom_sqrt(aggregate_test_100.c11)) |", + "+------------------------------------------+", + "| 0.6584408483418833 |", + "+------------------------------------------+", + ]; + assert_batches_eq!(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_avg_sqrt() -> Result<()> { + let ctx = create_udf_context(); + register_aggregate_csv(&ctx).await?; + // Note it is a different column (c12) than above (c11) + let sql = "SELECT avg(custom_sqrt(c12)) FROM aggregate_test_100"; + let actual = plan_and_collect(&ctx, sql).await.unwrap(); + let expected = [ + "+------------------------------------------+", + "| AVG(custom_sqrt(aggregate_test_100.c12)) |", + "+------------------------------------------+", + "| 0.6706002946036462 |", + "+------------------------------------------+", + ]; + assert_batches_eq!(&expected, &actual); Ok(()) } @@ -212,51 +242,6 @@ async fn scalar_udf_override_built_in_scalar_function() -> Result<()> { Ok(()) } -/// tests the creation, registration and usage of a UDAF -#[tokio::test] -async fn simple_udaf() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - - let batch1 = RecordBatch::try_new( - Arc::new(schema.clone()), - vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], - )?; - let batch2 = RecordBatch::try_new( - Arc::new(schema.clone()), - vec![Arc::new(Int32Array::from(vec![4, 5]))], - )?; - - let ctx = SessionContext::new(); - - let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?; - ctx.register_table("t", Arc::new(provider))?; - - // define a udaf, using a DataFusion's accumulator - let my_avg = create_udaf( - "my_avg", - vec![DataType::Float64], - Arc::new(DataType::Float64), - Volatility::Immutable, - Arc::new(|_| Ok(Box::::default())), - Arc::new(vec![DataType::UInt64, DataType::Float64]), - ); - - ctx.register_udaf(my_avg); - - let result = plan_and_collect(&ctx, "SELECT MY_AVG(a) FROM t").await?; - - let expected = [ - "+-------------+", - "| my_avg(t.a) |", - "+-------------+", - "| 3.0 |", - "+-------------+", - ]; - assert_batches_eq!(expected, &result); - - Ok(()) -} - #[tokio::test] async fn udaf_as_window_func() -> Result<()> { #[derive(Debug)] @@ -314,3 +299,123 @@ async fn udaf_as_window_func() -> Result<()> { assert_eq!(format!("{:?}", dataframe.logical_plan()), expected); Ok(()) } + +#[tokio::test] +async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> { + let ctx = SessionContext::new(); + let arr = Int32Array::from(vec![1]); + let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?; + ctx.register_batch("t", batch).unwrap(); + + let myfunc = |args: &[ArrayRef]| Ok(Arc::clone(&args[0])); + let myfunc = make_scalar_function(myfunc); + + ctx.register_udf(create_udf( + "MY_FUNC", + vec![DataType::Int32], + Arc::new(DataType::Int32), + Volatility::Immutable, + myfunc, + )); + + // doesn't work as it was registered with non lowercase + let err = plan_and_collect(&ctx, "SELECT MY_FUNC(i) FROM t") + .await + .unwrap_err(); + assert!(err + .to_string() + .contains("Error during planning: Invalid function \'my_func\'")); + + // Can call it if you put quotes + let result = plan_and_collect(&ctx, "SELECT \"MY_FUNC\"(i) FROM t").await?; + + let expected = [ + "+--------------+", + "| MY_FUNC(t.i) |", + "+--------------+", + "| 1 |", + "+--------------+", + ]; + assert_batches_eq!(expected, &result); + + Ok(()) +} + +#[tokio::test] +async fn test_user_defined_functions_with_alias() -> Result<()> { + let ctx = SessionContext::new(); + let arr = Int32Array::from(vec![1]); + let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?; + ctx.register_batch("t", batch).unwrap(); + + let myfunc = |args: &[ArrayRef]| Ok(Arc::clone(&args[0])); + let myfunc = make_scalar_function(myfunc); + + let udf = create_udf( + "dummy", + vec![DataType::Int32], + Arc::new(DataType::Int32), + Volatility::Immutable, + myfunc, + ) + .with_aliases(vec!["dummy_alias"]); + + ctx.register_udf(udf); + + let expected = [ + "+------------+", + "| dummy(t.i) |", + "+------------+", + "| 1 |", + "+------------+", + ]; + let result = plan_and_collect(&ctx, "SELECT dummy(i) FROM t").await?; + assert_batches_eq!(expected, &result); + + let alias_result = plan_and_collect(&ctx, "SELECT dummy_alias(i) FROM t").await?; + assert_batches_eq!(expected, &alias_result); + + Ok(()) +} + +fn create_udf_context() -> SessionContext { + let ctx = SessionContext::new(); + // register a custom UDF + ctx.register_udf(create_udf( + "custom_sqrt", + vec![DataType::Float64], + Arc::new(DataType::Float64), + Volatility::Immutable, + Arc::new(custom_sqrt), + )); + + ctx +} + +fn custom_sqrt(args: &[ColumnarValue]) -> Result { + let arg = &args[0]; + if let ColumnarValue::Array(v) = arg { + let input = as_float64_array(v).expect("cast failed"); + let array: Float64Array = input.iter().map(|v| v.map(|x| x.sqrt())).collect(); + Ok(ColumnarValue::Array(Arc::new(array))) + } else { + unimplemented!() + } +} + +async fn register_aggregate_csv(ctx: &SessionContext) -> Result<()> { + let testdata = datafusion::test_util::arrow_test_data(); + let schema = test_util::aggr_test_schema(); + ctx.register_csv( + "aggregate_test_100", + &format!("{testdata}/csv/aggregate_test_100.csv"), + CsvReadOptions::new().schema(&schema), + ) + .await?; + Ok(()) +} + +/// Execute SQL and return results as a RecordBatch +async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result> { + ctx.sql(sql).await?.collect().await +} diff --git a/datafusion/core/tests/user_defined/user_defined_table_functions.rs b/datafusion/core/tests/user_defined/user_defined_table_functions.rs new file mode 100644 index 000000000000..b5d10b1c5b9b --- /dev/null +++ b/datafusion/core/tests/user_defined/user_defined_table_functions.rs @@ -0,0 +1,219 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::Int64Array; +use arrow::csv::reader::Format; +use arrow::csv::ReaderBuilder; +use async_trait::async_trait; +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::arrow::record_batch::RecordBatch; +use datafusion::datasource::function::TableFunctionImpl; +use datafusion::datasource::TableProvider; +use datafusion::error::Result; +use datafusion::execution::context::SessionState; +use datafusion::execution::TaskContext; +use datafusion::physical_plan::memory::MemoryExec; +use datafusion::physical_plan::{collect, ExecutionPlan}; +use datafusion::prelude::SessionContext; +use datafusion_common::{assert_batches_eq, DFSchema, ScalarValue}; +use datafusion_expr::{EmptyRelation, Expr, LogicalPlan, Projection, TableType}; +use std::fs::File; +use std::io::Seek; +use std::path::Path; +use std::sync::Arc; + +/// test simple udtf with define read_csv with parameters +#[tokio::test] +async fn test_simple_read_csv_udtf() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_udtf("read_csv", Arc::new(SimpleCsvTableFunc {})); + + let csv_file = "tests/tpch-csv/nation.csv"; + // read csv with at most 5 rows + let rbs = ctx + .sql(format!("SELECT * FROM read_csv('{csv_file}', 5);").as_str()) + .await? + .collect() + .await?; + + let excepted = [ + "+-------------+-----------+-------------+-------------------------------------------------------------------------------------------------------------+", + "| n_nationkey | n_name | n_regionkey | n_comment |", + "+-------------+-----------+-------------+-------------------------------------------------------------------------------------------------------------+", + "| 1 | ARGENTINA | 1 | al foxes promise slyly according to the regular accounts. bold requests alon |", + "| 2 | BRAZIL | 1 | y alongside of the pending deposits. carefully special packages are about the ironic forges. slyly special |", + "| 3 | CANADA | 1 | eas hang ironic, silent packages. slyly regular packages are furiously over the tithes. fluffily bold |", + "| 4 | EGYPT | 4 | y above the carefully unusual theodolites. final dugouts are quickly across the furiously regular d |", + "| 5 | ETHIOPIA | 0 | ven packages wake quickly. regu |", + "+-------------+-----------+-------------+-------------------------------------------------------------------------------------------------------------+", ]; + assert_batches_eq!(excepted, &rbs); + + // just run, return all rows + let rbs = ctx + .sql(format!("SELECT * FROM read_csv('{csv_file}');").as_str()) + .await? + .collect() + .await?; + let excepted = [ + "+-------------+-----------+-------------+--------------------------------------------------------------------------------------------------------------------+", + "| n_nationkey | n_name | n_regionkey | n_comment |", + "+-------------+-----------+-------------+--------------------------------------------------------------------------------------------------------------------+", + "| 1 | ARGENTINA | 1 | al foxes promise slyly according to the regular accounts. bold requests alon |", + "| 2 | BRAZIL | 1 | y alongside of the pending deposits. carefully special packages are about the ironic forges. slyly special |", + "| 3 | CANADA | 1 | eas hang ironic, silent packages. slyly regular packages are furiously over the tithes. fluffily bold |", + "| 4 | EGYPT | 4 | y above the carefully unusual theodolites. final dugouts are quickly across the furiously regular d |", + "| 5 | ETHIOPIA | 0 | ven packages wake quickly. regu |", + "| 6 | FRANCE | 3 | refully final requests. regular, ironi |", + "| 7 | GERMANY | 3 | l platelets. regular accounts x-ray: unusual, regular acco |", + "| 8 | INDIA | 2 | ss excuses cajole slyly across the packages. deposits print aroun |", + "| 9 | INDONESIA | 2 | slyly express asymptotes. regular deposits haggle slyly. carefully ironic hockey players sleep blithely. carefull |", + "| 10 | IRAN | 4 | efully alongside of the slyly final dependencies. |", + "+-------------+-----------+-------------+--------------------------------------------------------------------------------------------------------------------+" + ]; + assert_batches_eq!(excepted, &rbs); + + Ok(()) +} + +struct SimpleCsvTable { + schema: SchemaRef, + exprs: Vec, + batches: Vec, +} + +#[async_trait] +impl TableProvider for SimpleCsvTable { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + state: &SessionState, + projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + let batches = if !self.exprs.is_empty() { + let max_return_lines = self.interpreter_expr(state).await?; + // get max return rows from self.batches + let mut batches = vec![]; + let mut lines = 0; + for batch in &self.batches { + let batch_lines = batch.num_rows(); + if lines + batch_lines > max_return_lines as usize { + let batch_lines = max_return_lines as usize - lines; + batches.push(batch.slice(0, batch_lines)); + break; + } else { + batches.push(batch.clone()); + lines += batch_lines; + } + } + batches + } else { + self.batches.clone() + }; + Ok(Arc::new(MemoryExec::try_new( + &[batches], + TableProvider::schema(self), + projection.cloned(), + )?)) + } +} + +impl SimpleCsvTable { + async fn interpreter_expr(&self, state: &SessionState) -> Result { + use datafusion::logical_expr::expr_rewriter::normalize_col; + use datafusion::logical_expr::utils::columnize_expr; + let plan = LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: true, + schema: Arc::new(DFSchema::empty()), + }); + let logical_plan = Projection::try_new( + vec![columnize_expr( + normalize_col(self.exprs[0].clone(), &plan)?, + plan.schema(), + )], + Arc::new(plan), + ) + .map(LogicalPlan::Projection)?; + let rbs = collect( + state.create_physical_plan(&logical_plan).await?, + Arc::new(TaskContext::from(state)), + ) + .await?; + let limit = rbs[0] + .column(0) + .as_any() + .downcast_ref::() + .unwrap() + .value(0); + Ok(limit) + } +} + +struct SimpleCsvTableFunc {} + +impl TableFunctionImpl for SimpleCsvTableFunc { + fn call(&self, exprs: &[Expr]) -> Result> { + let mut new_exprs = vec![]; + let mut filepath = String::new(); + for expr in exprs { + match expr { + Expr::Literal(ScalarValue::Utf8(Some(ref path))) => { + filepath = path.clone() + } + expr => new_exprs.push(expr.clone()), + } + } + let (schema, batches) = read_csv_batches(filepath)?; + let table = SimpleCsvTable { + schema, + exprs: new_exprs.clone(), + batches, + }; + Ok(Arc::new(table)) + } +} + +fn read_csv_batches(csv_path: impl AsRef) -> Result<(SchemaRef, Vec)> { + let mut file = File::open(csv_path)?; + let (schema, _) = Format::default() + .with_header(true) + .infer_schema(&mut file, None)?; + file.rewind()?; + + let reader = ReaderBuilder::new(Arc::new(schema.clone())) + .with_header(true) + .build(file)?; + let mut batches = vec![]; + for bacth in reader { + batches.push(bacth?); + } + let schema = Arc::new(schema); + Ok((schema, batches)) +} diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs index 5f9939157217..54eab4315a97 100644 --- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs @@ -19,6 +19,7 @@ //! user defined window functions use std::{ + any::Any, ops::Range, sync::{ atomic::{AtomicUsize, Ordering}, @@ -32,8 +33,7 @@ use arrow_schema::DataType; use datafusion::{assert_batches_eq, prelude::SessionContext}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ - function::PartitionEvaluatorFactory, PartitionEvaluator, ReturnTypeFunction, - Signature, Volatility, WindowUDF, + PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl, }; /// A query with a window function evaluated over the entire partition @@ -471,24 +471,49 @@ impl OddCounter { } fn register(ctx: &mut SessionContext, test_state: Arc) { - let name = "odd_counter"; - let volatility = Volatility::Immutable; - - let signature = Signature::exact(vec![DataType::Int64], volatility); - - let return_type = Arc::new(DataType::Int64); - let return_type: ReturnTypeFunction = - Arc::new(move |_| Ok(Arc::clone(&return_type))); - - let partition_evaluator_factory: PartitionEvaluatorFactory = - Arc::new(move || Ok(Box::new(OddCounter::new(Arc::clone(&test_state))))); - - ctx.register_udwf(WindowUDF::new( - name, - &signature, - &return_type, - &partition_evaluator_factory, - )) + #[derive(Debug, Clone)] + struct SimpleWindowUDF { + signature: Signature, + return_type: DataType, + test_state: Arc, + } + + impl SimpleWindowUDF { + fn new(test_state: Arc) -> Self { + let signature = + Signature::exact(vec![DataType::Float64], Volatility::Immutable); + let return_type = DataType::Int64; + Self { + signature, + return_type, + test_state, + } + } + } + + impl WindowUDFImpl for SimpleWindowUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "odd_counter" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(self.return_type.clone()) + } + + fn partition_evaluator(&self) -> Result> { + Ok(Box::new(OddCounter::new(Arc::clone(&self.test_state)))) + } + } + + ctx.register_udwf(WindowUDF::from(SimpleWindowUDF::new(test_state))) } } diff --git a/datafusion/execution/Cargo.toml b/datafusion/execution/Cargo.toml index cf4eb5ef1f25..e9bb87e9f8ac 100644 --- a/datafusion/execution/Cargo.toml +++ b/datafusion/execution/Cargo.toml @@ -19,9 +19,9 @@ name = "datafusion-execution" description = "Execution configuration support for DataFusion query engine" keywords = ["arrow", "query", "sql"] +readme = "README.md" version = { workspace = true } edition = { workspace = true } -readme = { workspace = true } homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } @@ -35,14 +35,14 @@ path = "src/lib.rs" [dependencies] arrow = { workspace = true } chrono = { version = "0.4", default-features = false } -dashmap = "5.4.0" -datafusion-common = { path = "../common", version = "31.0.0" } -datafusion-expr = { path = "../expr", version = "31.0.0" } -futures = "0.3" +dashmap = { workspace = true } +datafusion-common = { workspace = true } +datafusion-expr = { workspace = true } +futures = { workspace = true } hashbrown = { version = "0.14", features = ["raw"] } -log = "^0.4" -object_store = "0.7.0" -parking_lot = "0.12" -rand = "0.8" -tempfile = "3" -url = "2.2" +log = { workspace = true } +object_store = { workspace = true } +parking_lot = { workspace = true } +rand = { workspace = true } +tempfile = { workspace = true } +url = { workspace = true } diff --git a/datafusion/execution/README.md b/datafusion/execution/README.md new file mode 100644 index 000000000000..67aac6be82b3 --- /dev/null +++ b/datafusion/execution/README.md @@ -0,0 +1,26 @@ + + +# DataFusion Common + +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. + +This crate is a submodule of DataFusion that provides execution runtime such as the memory pools and disk manager. + +[df]: https://crates.io/crates/datafusion diff --git a/datafusion/execution/src/cache/cache_manager.rs b/datafusion/execution/src/cache/cache_manager.rs index 987b47bbb8f5..97529263688b 100644 --- a/datafusion/execution/src/cache/cache_manager.rs +++ b/datafusion/execution/src/cache/cache_manager.rs @@ -29,15 +29,25 @@ use std::sync::Arc; pub type FileStatisticsCache = Arc, Extra = ObjectMeta>>; +pub type ListFilesCache = + Arc>, Extra = ObjectMeta>>; + impl Debug for dyn CacheAccessor, Extra = ObjectMeta> { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "Cache name: {} with length: {}", self.name(), self.len()) } } +impl Debug for dyn CacheAccessor>, Extra = ObjectMeta> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "Cache name: {} with length: {}", self.name(), self.len()) + } +} + #[derive(Default, Debug)] pub struct CacheManager { file_statistic_cache: Option, + list_files_cache: Option, } impl CacheManager { @@ -46,6 +56,9 @@ impl CacheManager { if let Some(cc) = &config.table_files_statistics_cache { manager.file_statistic_cache = Some(cc.clone()) } + if let Some(lc) = &config.list_files_cache { + manager.list_files_cache = Some(lc.clone()) + } Ok(Arc::new(manager)) } @@ -53,6 +66,11 @@ impl CacheManager { pub fn get_file_statistic_cache(&self) -> Option { self.file_statistic_cache.clone() } + + /// Get the cache of objectMeta under same path. + pub fn get_list_files_cache(&self) -> Option { + self.list_files_cache.clone() + } } #[derive(Clone, Default)] @@ -61,6 +79,13 @@ pub struct CacheManagerConfig { /// Avoid get same file statistics repeatedly in same datafusion session. /// Default is disable. Fow now only supports Parquet files. pub table_files_statistics_cache: Option, + /// Enable cache of file metadata when listing files. + /// This setting avoids listing file meta of the same path repeatedly + /// in same session, which may be expensive in certain situations (e.g. remote object storage). + /// Note that if this option is enabled, DataFusion will not see any updates to the underlying + /// location. + /// Default is disable. + pub list_files_cache: Option, } impl CacheManagerConfig { @@ -71,4 +96,9 @@ impl CacheManagerConfig { self.table_files_statistics_cache = cache; self } + + pub fn with_list_files_cache(mut self, cache: Option) -> Self { + self.list_files_cache = cache; + self + } } diff --git a/datafusion/execution/src/cache/cache_unit.rs b/datafusion/execution/src/cache/cache_unit.rs index 3ef699ac2360..25f9b9fa4d68 100644 --- a/datafusion/execution/src/cache/cache_unit.rs +++ b/datafusion/execution/src/cache/cache_unit.rs @@ -15,12 +15,15 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; + use crate::cache::CacheAccessor; -use dashmap::DashMap; + use datafusion_common::Statistics; + +use dashmap::DashMap; use object_store::path::Path; use object_store::ObjectMeta; -use std::sync::Arc; /// Collected statistics for files /// Cache is invalided when file size or last modification has changed @@ -94,10 +97,71 @@ impl CacheAccessor> for DefaultFileStatisticsCache { } } +/// Collected files metadata for listing files. +/// Cache will not invalided until user call remove or clear. +#[derive(Default)] +pub struct DefaultListFilesCache { + statistics: DashMap>>, +} + +impl CacheAccessor>> for DefaultListFilesCache { + type Extra = ObjectMeta; + + fn get(&self, k: &Path) -> Option>> { + self.statistics.get(k).map(|x| x.value().clone()) + } + + fn get_with_extra( + &self, + _k: &Path, + _e: &Self::Extra, + ) -> Option>> { + panic!("Not supported DefaultListFilesCache get_with_extra") + } + + fn put( + &self, + key: &Path, + value: Arc>, + ) -> Option>> { + self.statistics.insert(key.clone(), value) + } + + fn put_with_extra( + &self, + _key: &Path, + _value: Arc>, + _e: &Self::Extra, + ) -> Option>> { + panic!("Not supported DefaultListFilesCache put_with_extra") + } + + fn remove(&mut self, k: &Path) -> Option>> { + self.statistics.remove(k).map(|x| x.1) + } + + fn contains_key(&self, k: &Path) -> bool { + self.statistics.contains_key(k) + } + + fn len(&self) -> usize { + self.statistics.len() + } + + fn clear(&self) { + self.statistics.clear() + } + + fn name(&self) -> String { + "DefaultListFilesCache".to_string() + } +} + #[cfg(test)] mod tests { - use crate::cache::cache_unit::DefaultFileStatisticsCache; + use crate::cache::cache_unit::{DefaultFileStatisticsCache, DefaultListFilesCache}; use crate::cache::CacheAccessor; + use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use chrono::DateTime; use datafusion_common::Statistics; use object_store::path::Path; @@ -112,12 +176,21 @@ mod tests { .into(), size: 1024, e_tag: None, + version: None, }; - let cache = DefaultFileStatisticsCache::default(); assert!(cache.get_with_extra(&meta.location, &meta).is_none()); - cache.put_with_extra(&meta.location, Statistics::default().into(), &meta); + cache.put_with_extra( + &meta.location, + Statistics::new_unknown(&Schema::new(vec![Field::new( + "test_column", + DataType::Timestamp(TimeUnit::Second, None), + false, + )])) + .into(), + &meta, + ); assert!(cache.get_with_extra(&meta.location, &meta).is_some()); // file size changed @@ -137,4 +210,26 @@ mod tests { meta2.location = Path::from("test2"); assert!(cache.get_with_extra(&meta2.location, &meta2).is_none()); } + + #[test] + fn test_list_file_cache() { + let meta = ObjectMeta { + location: Path::from("test"), + last_modified: DateTime::parse_from_rfc3339("2022-09-27T22:36:00+02:00") + .unwrap() + .into(), + size: 1024, + e_tag: None, + version: None, + }; + + let cache = DefaultListFilesCache::default(); + assert!(cache.get(&meta.location).is_none()); + + cache.put(&meta.location, vec![meta.clone()].into()); + assert_eq!( + cache.get(&meta.location).unwrap().first().unwrap().clone(), + meta.clone() + ); + } } diff --git a/datafusion/execution/src/config.rs b/datafusion/execution/src/config.rs index 44fcc2ab49b4..8556335b395a 100644 --- a/datafusion/execution/src/config.rs +++ b/datafusion/execution/src/config.rs @@ -86,7 +86,7 @@ impl SessionConfig { /// Set a generic `str` configuration option pub fn set_str(self, key: &str, value: &str) -> Self { - self.set(key, ScalarValue::Utf8(Some(value.to_string()))) + self.set(key, ScalarValue::from(value)) } /// Customize batch size @@ -145,10 +145,12 @@ impl SessionConfig { self.options.optimizer.repartition_sorts } - /// Remove sorts by replacing with order-preserving variants of operators, - /// even when query is bounded? - pub fn bounded_order_preserving_variants(&self) -> bool { - self.options.optimizer.bounded_order_preserving_variants + /// Prefer existing sort (true) or maximize parallelism (false). See + /// [prefer_existing_sort] for more details + /// + /// [prefer_existing_sort]: datafusion_common::config::OptimizerOptions::prefer_existing_sort + pub fn prefer_existing_sort(&self) -> bool { + self.options.optimizer.prefer_existing_sort } /// Are statistics collected during execution? @@ -221,10 +223,12 @@ impl SessionConfig { self } - /// Enables or disables the use of order-preserving variants of `CoalescePartitions` - /// and `RepartitionExec` operators, even when the query is bounded - pub fn with_bounded_order_preserving_variants(mut self, enabled: bool) -> Self { - self.options.optimizer.bounded_order_preserving_variants = enabled; + /// Prefer existing sort (true) or maximize parallelism (false). See + /// [prefer_existing_sort] for more details + /// + /// [prefer_existing_sort]: datafusion_common::config::OptimizerOptions::prefer_existing_sort + pub fn with_prefer_existing_sort(mut self, enabled: bool) -> Self { + self.options.optimizer.prefer_existing_sort = enabled; self } diff --git a/datafusion/execution/src/memory_pool/mod.rs b/datafusion/execution/src/memory_pool/mod.rs index f8fc9fcdbbbb..55555014f2ef 100644 --- a/datafusion/execution/src/memory_pool/mod.rs +++ b/datafusion/execution/src/memory_pool/mod.rs @@ -46,9 +46,9 @@ pub use pool::*; /// /// The following memory pool implementations are available: /// -/// * [`UnboundedMemoryPool`](pool::UnboundedMemoryPool) -/// * [`GreedyMemoryPool`](pool::GreedyMemoryPool) -/// * [`FairSpillPool`](pool::FairSpillPool) +/// * [`UnboundedMemoryPool`] +/// * [`GreedyMemoryPool`] +/// * [`FairSpillPool`] pub trait MemoryPool: Send + Sync + std::fmt::Debug { /// Registers a new [`MemoryConsumer`] /// @@ -157,6 +157,11 @@ impl MemoryReservation { self.size } + /// Returns [MemoryConsumer] for this [MemoryReservation] + pub fn consumer(&self) -> &MemoryConsumer { + &self.registration.consumer + } + /// Frees all bytes from this reservation back to the underlying /// pool, returning the number of bytes freed. pub fn free(&mut self) -> usize { @@ -230,7 +235,7 @@ impl MemoryReservation { } } - /// Returns a new empty [`MemoryReservation`] with the same [`MemoryConsumer`] + /// Returns a new empty [`MemoryReservation`] with the same [`MemoryConsumer`] pub fn new_empty(&self) -> Self { Self { size: 0, diff --git a/datafusion/execution/src/memory_pool/pool.rs b/datafusion/execution/src/memory_pool/pool.rs index fc49c5fa94c7..4a491630fe20 100644 --- a/datafusion/execution/src/memory_pool/pool.rs +++ b/datafusion/execution/src/memory_pool/pool.rs @@ -49,7 +49,7 @@ impl MemoryPool for UnboundedMemoryPool { /// A [`MemoryPool`] that implements a greedy first-come first-serve limit. /// /// This pool works well for queries that do not need to spill or have -/// a single spillable operator. See [`GreedyMemoryPool`] if there are +/// a single spillable operator. See [`FairSpillPool`] if there are /// multiple spillable operators that all will spill. #[derive(Debug)] pub struct GreedyMemoryPool { diff --git a/datafusion/execution/src/stream.rs b/datafusion/execution/src/stream.rs index 5a1a9aaa2590..7fc5e458b86b 100644 --- a/datafusion/execution/src/stream.rs +++ b/datafusion/execution/src/stream.rs @@ -29,5 +29,5 @@ pub trait RecordBatchStream: Stream> { fn schema(&self) -> SchemaRef; } -/// Trait for a [`Stream`](futures::stream::Stream) of [`RecordBatch`]es +/// Trait for a [`Stream`] of [`RecordBatch`]es pub type SendableRecordBatchStream = Pin>; diff --git a/datafusion/execution/src/task.rs b/datafusion/execution/src/task.rs index 72d804d7bb9a..52c183b1612c 100644 --- a/datafusion/execution/src/task.rs +++ b/datafusion/execution/src/task.rs @@ -22,7 +22,7 @@ use std::{ use datafusion_common::{ config::{ConfigOptions, Extensions}, - DataFusionError, Result, + plan_datafusion_err, DataFusionError, Result, }; use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; @@ -182,9 +182,7 @@ impl FunctionRegistry for TaskContext { let result = self.scalar_functions.get(name); result.cloned().ok_or_else(|| { - DataFusionError::Plan(format!( - "There is no UDF named \"{name}\" in the TaskContext" - )) + plan_datafusion_err!("There is no UDF named \"{name}\" in the TaskContext") }) } @@ -192,9 +190,7 @@ impl FunctionRegistry for TaskContext { let result = self.aggregate_functions.get(name); result.cloned().ok_or_else(|| { - DataFusionError::Plan(format!( - "There is no UDAF named \"{name}\" in the TaskContext" - )) + plan_datafusion_err!("There is no UDAF named \"{name}\" in the TaskContext") }) } diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index 4d69ce747518..3e05dae61954 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -19,9 +19,9 @@ name = "datafusion-expr" description = "Logical plan and expression representation for DataFusion query engine" keywords = ["datafusion", "logical", "plan", "expressions"] +readme = "README.md" version = { workspace = true } edition = { workspace = true } -readme = { workspace = true } homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } @@ -35,14 +35,17 @@ path = "src/lib.rs" [features] [dependencies] -ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] } +ahash = { version = "0.8", default-features = false, features = [ + "runtime-rng", +] } arrow = { workspace = true } arrow-array = { workspace = true } -datafusion-common = { path = "../common", version = "31.0.0", default-features = false } +datafusion-common = { workspace = true } +paste = "^1.0" sqlparser = { workspace = true } strum = { version = "0.25.0", features = ["derive"] } strum_macros = "0.25.0" [dev-dependencies] -ctor = "0.2.0" -env_logger = "0.10" +ctor = { workspace = true } +env_logger = { workspace = true } diff --git a/datafusion/expr/README.md b/datafusion/expr/README.md index bcce30be39d9..b086f930e871 100644 --- a/datafusion/expr/README.md +++ b/datafusion/expr/README.md @@ -19,7 +19,7 @@ # DataFusion Logical Plan and Expressions -[DataFusion](df) is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. This crate is a submodule of DataFusion that provides data types and utilities for logical plans and expressions. diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index 1c8f34ec1d02..cea72c3cb5e6 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -20,7 +20,7 @@ use crate::utils; use crate::{type_coercion::aggregates::*, Signature, TypeSignature, Volatility}; use arrow::datatypes::{DataType, Field}; -use datafusion_common::{plan_err, DataFusionError, Result}; +use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result}; use std::sync::Arc; use std::{fmt, str::FromStr}; use strum_macros::EnumIter; @@ -100,10 +100,12 @@ pub enum AggregateFunction { BoolAnd, /// Bool Or BoolOr, + /// string_agg + StringAgg, } impl AggregateFunction { - fn name(&self) -> &str { + pub fn name(&self) -> &str { use AggregateFunction::*; match self { Count => "COUNT", @@ -116,13 +118,13 @@ impl AggregateFunction { ArrayAgg => "ARRAY_AGG", FirstValue => "FIRST_VALUE", LastValue => "LAST_VALUE", - Variance => "VARIANCE", - VariancePop => "VARIANCE_POP", + Variance => "VAR", + VariancePop => "VAR_POP", Stddev => "STDDEV", StddevPop => "STDDEV_POP", - Covariance => "COVARIANCE", - CovariancePop => "COVARIANCE_POP", - Correlation => "CORRELATION", + Covariance => "COVAR", + CovariancePop => "COVAR_POP", + Correlation => "CORR", RegrSlope => "REGR_SLOPE", RegrIntercept => "REGR_INTERCEPT", RegrCount => "REGR_COUNT", @@ -141,6 +143,7 @@ impl AggregateFunction { BitXor => "BIT_XOR", BoolAnd => "BOOL_AND", BoolOr => "BOOL_OR", + StringAgg => "STRING_AGG", } } } @@ -171,6 +174,7 @@ impl FromStr for AggregateFunction { "array_agg" => AggregateFunction::ArrayAgg, "first_value" => AggregateFunction::FirstValue, "last_value" => AggregateFunction::LastValue, + "string_agg" => AggregateFunction::StringAgg, // statistical "corr" => AggregateFunction::Correlation, "covar" => AggregateFunction::Covariance, @@ -232,11 +236,14 @@ impl AggregateFunction { // original errors are all related to wrong function signature // aggregate them for better error message .map_err(|_| { - DataFusionError::Plan(utils::generate_signature_error_msg( - &format!("{self}"), - self.signature(), - input_expr_types, - )) + plan_datafusion_err!( + "{}", + utils::generate_signature_error_msg( + &format!("{self}"), + self.signature(), + input_expr_types, + ) + ) })?; match self { @@ -296,6 +303,7 @@ impl AggregateFunction { AggregateFunction::FirstValue | AggregateFunction::LastValue => { Ok(coerced_data_types[0].clone()) } + AggregateFunction::StringAgg => Ok(DataType::LargeUtf8), } } } @@ -405,6 +413,30 @@ impl AggregateFunction { .collect(), Volatility::Immutable, ), + AggregateFunction::StringAgg => { + Signature::uniform(2, STRINGS.to_vec(), Volatility::Immutable) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use strum::IntoEnumIterator; + + #[test] + // Test for AggregateFuncion's Display and from_str() implementations. + // For each variant in AggregateFuncion, it converts the variant to a string + // and then back to a variant. The test asserts that the original variant and + // the reconstructed variant are the same. This assertion is also necessary for + // function suggestion. See https://github.com/apache/arrow-datafusion/issues/8082 + fn test_display_and_from_str() { + for func_original in AggregateFunction::iter() { + let func_name = func_original.to_string(); + let func_from_str = + AggregateFunction::from_str(func_name.to_lowercase().as_str()).unwrap(); + assert_eq!(func_from_str, func_original); } } } diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index a42963495617..e642dae06e4f 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -17,18 +17,23 @@ //! Built-in functions module contains all the built-in functions definitions. +use std::cmp::Ordering; +use std::collections::HashMap; +use std::fmt; +use std::str::FromStr; +use std::sync::{Arc, OnceLock}; + use crate::nullif::SUPPORTED_NULLIF_TYPES; +use crate::signature::TIMEZONE_WILDCARD; +use crate::type_coercion::binary::get_wider_type; use crate::type_coercion::functions::data_types; use crate::{ - conditional_expressions, struct_expressions, utils, Signature, TypeSignature, - Volatility, + conditional_expressions, FuncMonotonicity, Signature, TypeSignature, Volatility, }; + use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit}; use datafusion_common::{internal_err, plan_err, DataFusionError, Result}; -use std::collections::HashMap; -use std::fmt; -use std::str::FromStr; -use std::sync::{Arc, OnceLock}; + use strum::IntoEnumIterator; use strum_macros::EnumIter; @@ -124,6 +129,8 @@ pub enum BuiltinScalarFunction { // array functions /// array_append ArrayAppend, + /// array_sort + ArraySort, /// array_concat ArrayConcat, /// array_has @@ -132,10 +139,14 @@ pub enum BuiltinScalarFunction { ArrayHasAll, /// array_has_any ArrayHasAny, + /// array_pop_front + ArrayPopFront, /// array_pop_back ArrayPopBack, /// array_dims ArrayDims, + /// array_distinct + ArrayDistinct, /// array_element ArrayElement, /// array_empty @@ -168,12 +179,20 @@ pub enum BuiltinScalarFunction { ArraySlice, /// array_to_string ArrayToString, + /// array_intersect + ArrayIntersect, + /// array_union + ArrayUnion, + /// array_except + ArrayExcept, /// cardinality Cardinality, /// construct an array from columns MakeArray, /// Flatten Flatten, + /// Range + Range, // struct functions /// struct @@ -258,6 +277,8 @@ pub enum BuiltinScalarFunction { ToTimestampMillis, /// to_timestamp_micros ToTimestampMicros, + /// to_timestamp_nanos + ToTimestampNanos, /// to_timestamp_seconds ToTimestampSeconds, /// from_unixtime @@ -280,6 +301,14 @@ pub enum BuiltinScalarFunction { RegexpMatch, /// arrow_typeof ArrowTypeof, + /// overlay + OverLay, + /// levenshtein + Levenshtein, + /// substr_index + SubstrIndex, + /// find_in_set + FindInSet, } /// Maps the sql function name to `BuiltinScalarFunction` @@ -289,8 +318,7 @@ fn name_to_function() -> &'static HashMap<&'static str, BuiltinScalarFunction> { NAME_TO_FUNCTION_LOCK.get_or_init(|| { let mut map = HashMap::new(); BuiltinScalarFunction::iter().for_each(|func| { - let a = aliases(&func); - a.iter().for_each(|&a| { + func.aliases().iter().for_each(|&a| { map.insert(a, func); }); }); @@ -306,7 +334,7 @@ fn function_to_name() -> &'static HashMap { FUNCTION_TO_NAME_LOCK.get_or_init(|| { let mut map = HashMap::new(); BuiltinScalarFunction::iter().for_each(|func| { - map.insert(func, *aliases(&func).first().unwrap_or(&"NO_ALIAS")); + map.insert(func, *func.aliases().first().unwrap_or(&"NO_ALIAS")); }); map }) @@ -315,18 +343,20 @@ fn function_to_name() -> &'static HashMap { impl BuiltinScalarFunction { /// an allowlist of functions to take zero arguments, so that they will get special treatment /// while executing. + #[deprecated( + since = "32.0.0", + note = "please use TypeSignature::supports_zero_argument instead" + )] pub fn supports_zero_argument(&self) -> bool { - matches!( - self, - BuiltinScalarFunction::Pi - | BuiltinScalarFunction::Random - | BuiltinScalarFunction::Now - | BuiltinScalarFunction::CurrentDate - | BuiltinScalarFunction::CurrentTime - | BuiltinScalarFunction::Uuid - | BuiltinScalarFunction::MakeArray - ) + self.signature().type_signature.supports_zero_argument() + } + + /// Returns the name of this function + pub fn name(&self) -> &str { + // .unwrap is safe here because compiler makes sure the map will have matches for each BuiltinScalarFunction + function_to_name().get(self).unwrap() } + /// Returns the [Volatility] of the builtin function. pub fn volatility(&self) -> Volatility { match self { @@ -371,15 +401,19 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Tanh => Volatility::Immutable, BuiltinScalarFunction::Trunc => Volatility::Immutable, BuiltinScalarFunction::ArrayAppend => Volatility::Immutable, + BuiltinScalarFunction::ArraySort => Volatility::Immutable, BuiltinScalarFunction::ArrayConcat => Volatility::Immutable, BuiltinScalarFunction::ArrayEmpty => Volatility::Immutable, BuiltinScalarFunction::ArrayHasAll => Volatility::Immutable, BuiltinScalarFunction::ArrayHasAny => Volatility::Immutable, BuiltinScalarFunction::ArrayHas => Volatility::Immutable, BuiltinScalarFunction::ArrayDims => Volatility::Immutable, + BuiltinScalarFunction::ArrayDistinct => Volatility::Immutable, BuiltinScalarFunction::ArrayElement => Volatility::Immutable, + BuiltinScalarFunction::ArrayExcept => Volatility::Immutable, BuiltinScalarFunction::ArrayLength => Volatility::Immutable, BuiltinScalarFunction::ArrayNdims => Volatility::Immutable, + BuiltinScalarFunction::ArrayPopFront => Volatility::Immutable, BuiltinScalarFunction::ArrayPopBack => Volatility::Immutable, BuiltinScalarFunction::ArrayPosition => Volatility::Immutable, BuiltinScalarFunction::ArrayPositions => Volatility::Immutable, @@ -394,6 +428,9 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Flatten => Volatility::Immutable, BuiltinScalarFunction::ArraySlice => Volatility::Immutable, BuiltinScalarFunction::ArrayToString => Volatility::Immutable, + BuiltinScalarFunction::ArrayIntersect => Volatility::Immutable, + BuiltinScalarFunction::ArrayUnion => Volatility::Immutable, + BuiltinScalarFunction::Range => Volatility::Immutable, BuiltinScalarFunction::Cardinality => Volatility::Immutable, BuiltinScalarFunction::MakeArray => Volatility::Immutable, BuiltinScalarFunction::Ascii => Volatility::Immutable, @@ -436,6 +473,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ToTimestamp => Volatility::Immutable, BuiltinScalarFunction::ToTimestampMillis => Volatility::Immutable, BuiltinScalarFunction::ToTimestampMicros => Volatility::Immutable, + BuiltinScalarFunction::ToTimestampNanos => Volatility::Immutable, BuiltinScalarFunction::ToTimestampSeconds => Volatility::Immutable, BuiltinScalarFunction::Translate => Volatility::Immutable, BuiltinScalarFunction::Trim => Volatility::Immutable, @@ -444,6 +482,10 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Struct => Volatility::Immutable, BuiltinScalarFunction::FromUnixtime => Volatility::Immutable, BuiltinScalarFunction::ArrowTypeof => Volatility::Immutable, + BuiltinScalarFunction::OverLay => Volatility::Immutable, + BuiltinScalarFunction::Levenshtein => Volatility::Immutable, + BuiltinScalarFunction::SubstrIndex => Volatility::Immutable, + BuiltinScalarFunction::FindInSet => Volatility::Immutable, // Stable builtin functions BuiltinScalarFunction::Now => Volatility::Stable, @@ -465,21 +507,24 @@ impl BuiltinScalarFunction { /// * `List(Int64)` has dimension 2 /// * `List(List(Int64))` has dimension 3 /// * etc. - fn return_dimension(self, input_expr_type: DataType) -> u64 { - let mut res: u64 = 1; + fn return_dimension(self, input_expr_type: &DataType) -> u64 { + let mut result: u64 = 1; let mut current_data_type = input_expr_type; - loop { - match current_data_type { - DataType::List(field) => { - current_data_type = field.data_type().clone(); - res += 1; - } - _ => return res, - } + while let DataType::List(field) = current_data_type { + current_data_type = field.data_type(); + result += 1; } + result } /// Returns the output [`DataType`] of this function + /// + /// This method should be invoked only after `input_expr_types` have been validated + /// against the function's `TypeSignature` using `type_coercion::functions::data_types()`. + /// + /// This method will: + /// 1. Perform additional checks on `input_expr_types` that are beyond the scope of `TypeSignature` validation. + /// 2. Deduce the output `DataType` based on the provided `input_expr_types`. pub fn return_type(self, input_expr_types: &[DataType]) -> Result { use DataType::*; use TimeUnit::*; @@ -487,26 +532,6 @@ impl BuiltinScalarFunction { // Note that this function *must* return the same type that the respective physical expression returns // or the execution panics. - if input_expr_types.is_empty() && !self.supports_zero_argument() { - return plan_err!( - "{}", - utils::generate_signature_error_msg( - &format!("{self}"), - self.signature(), - input_expr_types - ) - ); - } - - // verify that this is a valid set of data types for this function - data_types(input_expr_types, &self.signature()).map_err(|_| { - DataFusionError::Plan(utils::generate_signature_error_msg( - &format!("{self}"), - self.signature(), - input_expr_types, - )) - })?; - // the return type of the built in function. // Some built-in functions' return type depends on the incoming type. match self { @@ -525,6 +550,7 @@ impl BuiltinScalarFunction { Ok(data_type) } BuiltinScalarFunction::ArrayAppend => Ok(input_expr_types[0].clone()), + BuiltinScalarFunction::ArraySort => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayConcat => { let mut expr_type = Null; let mut max_dims = 0; @@ -532,11 +558,17 @@ impl BuiltinScalarFunction { match input_expr_type { List(field) => { if !field.data_type().equals_datatype(&Null) { - let dims = self.return_dimension(input_expr_type.clone()); - if max_dims < dims { - max_dims = dims; - expr_type = input_expr_type.clone(); - } + let dims = self.return_dimension(input_expr_type); + expr_type = match max_dims.cmp(&dims) { + Ordering::Greater => expr_type, + Ordering::Equal => { + get_wider_type(&expr_type, input_expr_type)? + } + Ordering::Less => { + max_dims = dims; + input_expr_type.clone() + } + }; } } _ => { @@ -556,14 +588,17 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayDims => { Ok(List(Arc::new(Field::new("item", UInt64, true)))) } + BuiltinScalarFunction::ArrayDistinct => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayElement => match &input_expr_types[0] { List(field) => Ok(field.data_type().clone()), + LargeList(field) => Ok(field.data_type().clone()), _ => plan_err!( - "The {self} function can only accept list as the first argument" + "The {self} function can only accept list or largelist as the first argument" ), }, BuiltinScalarFunction::ArrayLength => Ok(UInt64), BuiltinScalarFunction::ArrayNdims => Ok(UInt64), + BuiltinScalarFunction::ArrayPopFront => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayPopBack => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayPosition => Ok(UInt64), BuiltinScalarFunction::ArrayPositions => { @@ -583,6 +618,35 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayReplaceAll => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArraySlice => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayToString => Ok(Utf8), + BuiltinScalarFunction::ArrayIntersect => { + match (input_expr_types[0].clone(), input_expr_types[1].clone()) { + (DataType::Null, DataType::Null) | (DataType::Null, _) => { + Ok(DataType::Null) + } + (_, DataType::Null) => { + Ok(List(Arc::new(Field::new("item", Null, true)))) + } + (dt, _) => Ok(dt), + } + } + BuiltinScalarFunction::ArrayUnion => { + match (input_expr_types[0].clone(), input_expr_types[1].clone()) { + (DataType::Null, dt) => Ok(dt), + (dt, DataType::Null) => Ok(dt), + (dt, _) => Ok(dt), + } + } + BuiltinScalarFunction::Range => { + Ok(List(Arc::new(Field::new("item", Int64, true)))) + } + BuiltinScalarFunction::ArrayExcept => { + match (input_expr_types[0].clone(), input_expr_types[1].clone()) { + (DataType::Null, _) | (_, DataType::Null) => { + Ok(input_expr_types[0].clone()) + } + (dt, _) => Ok(dt), + } + } BuiltinScalarFunction::Cardinality => Ok(UInt64), BuiltinScalarFunction::MakeArray => match input_expr_types.len() { 0 => Ok(List(Arc::new(Field::new("item", Null, true)))), @@ -618,13 +682,20 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ConcatWithSeparator => Ok(Utf8), BuiltinScalarFunction::DatePart => Ok(Float64), BuiltinScalarFunction::DateBin | BuiltinScalarFunction::DateTrunc => { - match input_expr_types[1] { - Timestamp(Nanosecond, _) | Utf8 | Null => { + match &input_expr_types[1] { + Timestamp(Nanosecond, None) | Utf8 | Null => { Ok(Timestamp(Nanosecond, None)) } - Timestamp(Microsecond, _) => Ok(Timestamp(Microsecond, None)), - Timestamp(Millisecond, _) => Ok(Timestamp(Millisecond, None)), - Timestamp(Second, _) => Ok(Timestamp(Second, None)), + Timestamp(Nanosecond, tz_opt) => { + Ok(Timestamp(Nanosecond, tz_opt.clone())) + } + Timestamp(Microsecond, tz_opt) => { + Ok(Timestamp(Microsecond, tz_opt.clone())) + } + Timestamp(Millisecond, tz_opt) => { + Ok(Timestamp(Millisecond, tz_opt.clone())) + } + Timestamp(Second, tz_opt) => Ok(Timestamp(Second, tz_opt.clone())), _ => plan_err!( "The {self} function can only accept timestamp as the second arg." ), @@ -732,7 +803,14 @@ impl BuiltinScalarFunction { return plan_err!("The to_hex function can only accept integers."); } }), - BuiltinScalarFunction::ToTimestamp => Ok(Timestamp(Nanosecond, None)), + BuiltinScalarFunction::SubstrIndex => { + utf8_to_str_type(&input_expr_types[0], "substr_index") + } + BuiltinScalarFunction::FindInSet => { + utf8_to_int_type(&input_expr_types[0], "find_in_set") + } + BuiltinScalarFunction::ToTimestamp + | BuiltinScalarFunction::ToTimestampNanos => Ok(Timestamp(Nanosecond, None)), BuiltinScalarFunction::ToTimestampMillis => Ok(Timestamp(Millisecond, None)), BuiltinScalarFunction::ToTimestampMicros => Ok(Timestamp(Microsecond, None)), BuiltinScalarFunction::ToTimestampSeconds => Ok(Timestamp(Second, None)), @@ -799,6 +877,14 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Abs => Ok(input_expr_types[0].clone()), + BuiltinScalarFunction::OverLay => { + utf8_to_str_type(&input_expr_types[0], "overlay") + } + + BuiltinScalarFunction::Levenshtein => { + utf8_to_int_type(&input_expr_types[0], "levenshtein") + } + BuiltinScalarFunction::Acos | BuiltinScalarFunction::Asin | BuiltinScalarFunction::Atan @@ -841,7 +927,18 @@ impl BuiltinScalarFunction { // for now, the list is small, as we do not have many built-in functions. match self { - BuiltinScalarFunction::ArrayAppend => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArraySort => { + Signature::variadic_any(self.volatility()) + } + BuiltinScalarFunction::ArrayAppend => Signature { + type_signature: ArrayAndElement, + volatility: self.volatility(), + }, + BuiltinScalarFunction::MakeArray => { + // 0 or more arguments of arbitrary type + Signature::one_of(vec![VariadicEqual, Any(0)], self.volatility()) + } + BuiltinScalarFunction::ArrayPopFront => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayPopBack => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayConcat => { Signature::variadic_any(self.volatility()) @@ -849,6 +946,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayDims => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayEmpty => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayElement => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayExcept => Signature::any(2, self.volatility()), BuiltinScalarFunction::Flatten => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayHasAll | BuiltinScalarFunction::ArrayHasAny @@ -857,11 +955,15 @@ impl BuiltinScalarFunction { Signature::variadic_any(self.volatility()) } BuiltinScalarFunction::ArrayNdims => Signature::any(1, self.volatility()), + BuiltinScalarFunction::ArrayDistinct => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayPosition => { Signature::variadic_any(self.volatility()) } BuiltinScalarFunction::ArrayPositions => Signature::any(2, self.volatility()), - BuiltinScalarFunction::ArrayPrepend => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayPrepend => Signature { + type_signature: ElementAndArray, + volatility: self.volatility(), + }, BuiltinScalarFunction::ArrayRepeat => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayRemove => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayRemoveN => Signature::any(3, self.volatility()), @@ -875,14 +977,18 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayToString => { Signature::variadic_any(self.volatility()) } + BuiltinScalarFunction::ArrayIntersect => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayUnion => Signature::any(2, self.volatility()), BuiltinScalarFunction::Cardinality => Signature::any(1, self.volatility()), - BuiltinScalarFunction::MakeArray => { - Signature::variadic_any(self.volatility()) - } - BuiltinScalarFunction::Struct => Signature::variadic( - struct_expressions::SUPPORTED_STRUCT_TYPES.to_vec(), + BuiltinScalarFunction::Range => Signature::one_of( + vec![ + Exact(vec![Int64]), + Exact(vec![Int64, Int64]), + Exact(vec![Int64, Int64, Int64]), + ], self.volatility(), ), + BuiltinScalarFunction::Struct => Signature::variadic_any(self.volatility()), BuiltinScalarFunction::Concat | BuiltinScalarFunction::ConcatWithSeparator => { Signature::variadic(vec![Utf8], self.volatility()) @@ -943,6 +1049,7 @@ impl BuiltinScalarFunction { 1, vec![ Int64, + Float64, Timestamp(Nanosecond, None), Timestamp(Microsecond, None), Timestamp(Millisecond, None), @@ -975,6 +1082,18 @@ impl BuiltinScalarFunction { ], self.volatility(), ), + BuiltinScalarFunction::ToTimestampNanos => Signature::uniform( + 1, + vec![ + Int64, + Timestamp(Nanosecond, None), + Timestamp(Microsecond, None), + Timestamp(Millisecond, None), + Timestamp(Second, None), + Utf8, + ], + self.volatility(), + ), BuiltinScalarFunction::ToTimestampSeconds => Signature::uniform( 1, vec![ @@ -1020,13 +1139,25 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::DateTrunc => Signature::one_of( vec![ Exact(vec![Utf8, Timestamp(Nanosecond, None)]), - Exact(vec![Utf8, Timestamp(Nanosecond, Some("+TZ".into()))]), + Exact(vec![ + Utf8, + Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())), + ]), Exact(vec![Utf8, Timestamp(Microsecond, None)]), - Exact(vec![Utf8, Timestamp(Microsecond, Some("+TZ".into()))]), + Exact(vec![ + Utf8, + Timestamp(Microsecond, Some(TIMEZONE_WILDCARD.into())), + ]), Exact(vec![Utf8, Timestamp(Millisecond, None)]), - Exact(vec![Utf8, Timestamp(Millisecond, Some("+TZ".into()))]), + Exact(vec![ + Utf8, + Timestamp(Millisecond, Some(TIMEZONE_WILDCARD.into())), + ]), Exact(vec![Utf8, Timestamp(Second, None)]), - Exact(vec![Utf8, Timestamp(Second, Some("+TZ".into()))]), + Exact(vec![ + Utf8, + Timestamp(Second, Some(TIMEZONE_WILDCARD.into())), + ]), ], self.volatility(), ), @@ -1040,8 +1171,8 @@ impl BuiltinScalarFunction { ]), Exact(vec![ Interval(MonthDayNano), - Timestamp(array_type.clone(), Some("+TZ".into())), - Timestamp(Nanosecond, Some("+TZ".into())), + Timestamp(array_type.clone(), Some(TIMEZONE_WILDCARD.into())), + Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())), ]), Exact(vec![ Interval(DayTime), @@ -1050,8 +1181,8 @@ impl BuiltinScalarFunction { ]), Exact(vec![ Interval(DayTime), - Timestamp(array_type.clone(), Some("+TZ".into())), - Timestamp(Nanosecond, Some("+TZ".into())), + Timestamp(array_type.clone(), Some(TIMEZONE_WILDCARD.into())), + Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())), ]), Exact(vec![ Interval(MonthDayNano), @@ -1059,7 +1190,7 @@ impl BuiltinScalarFunction { ]), Exact(vec![ Interval(MonthDayNano), - Timestamp(array_type.clone(), Some("+TZ".into())), + Timestamp(array_type.clone(), Some(TIMEZONE_WILDCARD.into())), ]), Exact(vec![ Interval(DayTime), @@ -1067,7 +1198,7 @@ impl BuiltinScalarFunction { ]), Exact(vec![ Interval(DayTime), - Timestamp(array_type, Some("+TZ".into())), + Timestamp(array_type, Some(TIMEZONE_WILDCARD.into())), ]), ] }; @@ -1082,16 +1213,28 @@ impl BuiltinScalarFunction { } BuiltinScalarFunction::DatePart => Signature::one_of( vec![ - Exact(vec![Utf8, Date32]), - Exact(vec![Utf8, Date64]), - Exact(vec![Utf8, Timestamp(Second, None)]), - Exact(vec![Utf8, Timestamp(Second, Some("+TZ".into()))]), - Exact(vec![Utf8, Timestamp(Microsecond, None)]), - Exact(vec![Utf8, Timestamp(Microsecond, Some("+TZ".into()))]), - Exact(vec![Utf8, Timestamp(Millisecond, None)]), - Exact(vec![Utf8, Timestamp(Millisecond, Some("+TZ".into()))]), Exact(vec![Utf8, Timestamp(Nanosecond, None)]), - Exact(vec![Utf8, Timestamp(Nanosecond, Some("+TZ".into()))]), + Exact(vec![ + Utf8, + Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())), + ]), + Exact(vec![Utf8, Timestamp(Millisecond, None)]), + Exact(vec![ + Utf8, + Timestamp(Millisecond, Some(TIMEZONE_WILDCARD.into())), + ]), + Exact(vec![Utf8, Timestamp(Microsecond, None)]), + Exact(vec![ + Utf8, + Timestamp(Microsecond, Some(TIMEZONE_WILDCARD.into())), + ]), + Exact(vec![Utf8, Timestamp(Second, None)]), + Exact(vec![ + Utf8, + Timestamp(Second, Some(TIMEZONE_WILDCARD.into())), + ]), + Exact(vec![Utf8, Date64]), + Exact(vec![Utf8, Date32]), ], self.volatility(), ), @@ -1133,6 +1276,18 @@ impl BuiltinScalarFunction { self.volatility(), ), + BuiltinScalarFunction::SubstrIndex => Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64]), + ], + self.volatility(), + ), + BuiltinScalarFunction::FindInSet => Signature::one_of( + vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])], + self.volatility(), + ), + BuiltinScalarFunction::Replace | BuiltinScalarFunction::Translate => { Signature::one_of(vec![Exact(vec![Utf8, Utf8, Utf8])], self.volatility()) } @@ -1206,7 +1361,19 @@ impl BuiltinScalarFunction { } BuiltinScalarFunction::ArrowTypeof => Signature::any(1, self.volatility()), BuiltinScalarFunction::Abs => Signature::any(1, self.volatility()), - + BuiltinScalarFunction::OverLay => Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8, Int64, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]), + Exact(vec![Utf8, Utf8, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64]), + ], + self.volatility(), + ), + BuiltinScalarFunction::Levenshtein => Signature::one_of( + vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])], + self.volatility(), + ), BuiltinScalarFunction::Acos | BuiltinScalarFunction::Asin | BuiltinScalarFunction::Atan @@ -1251,186 +1418,251 @@ impl BuiltinScalarFunction { } } } -} -fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { - match func { - BuiltinScalarFunction::Abs => &["abs"], - BuiltinScalarFunction::Acos => &["acos"], - BuiltinScalarFunction::Acosh => &["acosh"], - BuiltinScalarFunction::Asin => &["asin"], - BuiltinScalarFunction::Asinh => &["asinh"], - BuiltinScalarFunction::Atan => &["atan"], - BuiltinScalarFunction::Atanh => &["atanh"], - BuiltinScalarFunction::Atan2 => &["atan2"], - BuiltinScalarFunction::Cbrt => &["cbrt"], - BuiltinScalarFunction::Ceil => &["ceil"], - BuiltinScalarFunction::Cos => &["cos"], - BuiltinScalarFunction::Cot => &["cot"], - BuiltinScalarFunction::Cosh => &["cosh"], - BuiltinScalarFunction::Degrees => &["degrees"], - BuiltinScalarFunction::Exp => &["exp"], - BuiltinScalarFunction::Factorial => &["factorial"], - BuiltinScalarFunction::Floor => &["floor"], - BuiltinScalarFunction::Gcd => &["gcd"], - BuiltinScalarFunction::Isnan => &["isnan"], - BuiltinScalarFunction::Iszero => &["iszero"], - BuiltinScalarFunction::Lcm => &["lcm"], - BuiltinScalarFunction::Ln => &["ln"], - BuiltinScalarFunction::Log => &["log"], - BuiltinScalarFunction::Log10 => &["log10"], - BuiltinScalarFunction::Log2 => &["log2"], - BuiltinScalarFunction::Nanvl => &["nanvl"], - BuiltinScalarFunction::Pi => &["pi"], - BuiltinScalarFunction::Power => &["power", "pow"], - BuiltinScalarFunction::Radians => &["radians"], - BuiltinScalarFunction::Random => &["random"], - BuiltinScalarFunction::Round => &["round"], - BuiltinScalarFunction::Signum => &["signum"], - BuiltinScalarFunction::Sin => &["sin"], - BuiltinScalarFunction::Sinh => &["sinh"], - BuiltinScalarFunction::Sqrt => &["sqrt"], - BuiltinScalarFunction::Tan => &["tan"], - BuiltinScalarFunction::Tanh => &["tanh"], - BuiltinScalarFunction::Trunc => &["trunc"], + /// This function specifies monotonicity behaviors for built-in scalar functions. + /// The list can be extended, only mathematical and datetime functions are + /// considered for the initial implementation of this feature. + pub fn monotonicity(&self) -> Option { + if matches!( + &self, + BuiltinScalarFunction::Atan + | BuiltinScalarFunction::Acosh + | BuiltinScalarFunction::Asinh + | BuiltinScalarFunction::Atanh + | BuiltinScalarFunction::Ceil + | BuiltinScalarFunction::Degrees + | BuiltinScalarFunction::Exp + | BuiltinScalarFunction::Factorial + | BuiltinScalarFunction::Floor + | BuiltinScalarFunction::Ln + | BuiltinScalarFunction::Log10 + | BuiltinScalarFunction::Log2 + | BuiltinScalarFunction::Radians + | BuiltinScalarFunction::Round + | BuiltinScalarFunction::Signum + | BuiltinScalarFunction::Sinh + | BuiltinScalarFunction::Sqrt + | BuiltinScalarFunction::Cbrt + | BuiltinScalarFunction::Tanh + | BuiltinScalarFunction::Trunc + | BuiltinScalarFunction::Pi + ) { + Some(vec![Some(true)]) + } else if matches!( + &self, + BuiltinScalarFunction::DateTrunc | BuiltinScalarFunction::DateBin + ) { + Some(vec![None, Some(true)]) + } else if *self == BuiltinScalarFunction::Log { + Some(vec![Some(true), Some(false)]) + } else { + None + } + } - // conditional functions - BuiltinScalarFunction::Coalesce => &["coalesce"], - BuiltinScalarFunction::NullIf => &["nullif"], + /// Returns all names that can be used to call this function + pub fn aliases(&self) -> &'static [&'static str] { + match self { + BuiltinScalarFunction::Abs => &["abs"], + BuiltinScalarFunction::Acos => &["acos"], + BuiltinScalarFunction::Acosh => &["acosh"], + BuiltinScalarFunction::Asin => &["asin"], + BuiltinScalarFunction::Asinh => &["asinh"], + BuiltinScalarFunction::Atan => &["atan"], + BuiltinScalarFunction::Atanh => &["atanh"], + BuiltinScalarFunction::Atan2 => &["atan2"], + BuiltinScalarFunction::Cbrt => &["cbrt"], + BuiltinScalarFunction::Ceil => &["ceil"], + BuiltinScalarFunction::Cos => &["cos"], + BuiltinScalarFunction::Cot => &["cot"], + BuiltinScalarFunction::Cosh => &["cosh"], + BuiltinScalarFunction::Degrees => &["degrees"], + BuiltinScalarFunction::Exp => &["exp"], + BuiltinScalarFunction::Factorial => &["factorial"], + BuiltinScalarFunction::Floor => &["floor"], + BuiltinScalarFunction::Gcd => &["gcd"], + BuiltinScalarFunction::Isnan => &["isnan"], + BuiltinScalarFunction::Iszero => &["iszero"], + BuiltinScalarFunction::Lcm => &["lcm"], + BuiltinScalarFunction::Ln => &["ln"], + BuiltinScalarFunction::Log => &["log"], + BuiltinScalarFunction::Log10 => &["log10"], + BuiltinScalarFunction::Log2 => &["log2"], + BuiltinScalarFunction::Nanvl => &["nanvl"], + BuiltinScalarFunction::Pi => &["pi"], + BuiltinScalarFunction::Power => &["power", "pow"], + BuiltinScalarFunction::Radians => &["radians"], + BuiltinScalarFunction::Random => &["random"], + BuiltinScalarFunction::Round => &["round"], + BuiltinScalarFunction::Signum => &["signum"], + BuiltinScalarFunction::Sin => &["sin"], + BuiltinScalarFunction::Sinh => &["sinh"], + BuiltinScalarFunction::Sqrt => &["sqrt"], + BuiltinScalarFunction::Tan => &["tan"], + BuiltinScalarFunction::Tanh => &["tanh"], + BuiltinScalarFunction::Trunc => &["trunc"], - // string functions - BuiltinScalarFunction::Ascii => &["ascii"], - BuiltinScalarFunction::BitLength => &["bit_length"], - BuiltinScalarFunction::Btrim => &["btrim"], - BuiltinScalarFunction::CharacterLength => { - &["character_length", "char_length", "length"] - } - BuiltinScalarFunction::Concat => &["concat"], - BuiltinScalarFunction::ConcatWithSeparator => &["concat_ws"], - BuiltinScalarFunction::Chr => &["chr"], - BuiltinScalarFunction::InitCap => &["initcap"], - BuiltinScalarFunction::Left => &["left"], - BuiltinScalarFunction::Lower => &["lower"], - BuiltinScalarFunction::Lpad => &["lpad"], - BuiltinScalarFunction::Ltrim => &["ltrim"], - BuiltinScalarFunction::OctetLength => &["octet_length"], - BuiltinScalarFunction::Repeat => &["repeat"], - BuiltinScalarFunction::Replace => &["replace"], - BuiltinScalarFunction::Reverse => &["reverse"], - BuiltinScalarFunction::Right => &["right"], - BuiltinScalarFunction::Rpad => &["rpad"], - BuiltinScalarFunction::Rtrim => &["rtrim"], - BuiltinScalarFunction::SplitPart => &["split_part"], - BuiltinScalarFunction::StringToArray => &["string_to_array", "string_to_list"], - BuiltinScalarFunction::StartsWith => &["starts_with"], - BuiltinScalarFunction::Strpos => &["strpos"], - BuiltinScalarFunction::Substr => &["substr"], - BuiltinScalarFunction::ToHex => &["to_hex"], - BuiltinScalarFunction::Translate => &["translate"], - BuiltinScalarFunction::Trim => &["trim"], - BuiltinScalarFunction::Upper => &["upper"], - BuiltinScalarFunction::Uuid => &["uuid"], + // conditional functions + BuiltinScalarFunction::Coalesce => &["coalesce"], + BuiltinScalarFunction::NullIf => &["nullif"], - // regex functions - BuiltinScalarFunction::RegexpMatch => &["regexp_match"], - BuiltinScalarFunction::RegexpReplace => &["regexp_replace"], + // string functions + BuiltinScalarFunction::Ascii => &["ascii"], + BuiltinScalarFunction::BitLength => &["bit_length"], + BuiltinScalarFunction::Btrim => &["btrim"], + BuiltinScalarFunction::CharacterLength => { + &["character_length", "char_length", "length"] + } + BuiltinScalarFunction::Concat => &["concat"], + BuiltinScalarFunction::ConcatWithSeparator => &["concat_ws"], + BuiltinScalarFunction::Chr => &["chr"], + BuiltinScalarFunction::InitCap => &["initcap"], + BuiltinScalarFunction::Left => &["left"], + BuiltinScalarFunction::Lower => &["lower"], + BuiltinScalarFunction::Lpad => &["lpad"], + BuiltinScalarFunction::Ltrim => &["ltrim"], + BuiltinScalarFunction::OctetLength => &["octet_length"], + BuiltinScalarFunction::Repeat => &["repeat"], + BuiltinScalarFunction::Replace => &["replace"], + BuiltinScalarFunction::Reverse => &["reverse"], + BuiltinScalarFunction::Right => &["right"], + BuiltinScalarFunction::Rpad => &["rpad"], + BuiltinScalarFunction::Rtrim => &["rtrim"], + BuiltinScalarFunction::SplitPart => &["split_part"], + BuiltinScalarFunction::StringToArray => { + &["string_to_array", "string_to_list"] + } + BuiltinScalarFunction::StartsWith => &["starts_with"], + BuiltinScalarFunction::Strpos => &["strpos"], + BuiltinScalarFunction::Substr => &["substr"], + BuiltinScalarFunction::ToHex => &["to_hex"], + BuiltinScalarFunction::Translate => &["translate"], + BuiltinScalarFunction::Trim => &["trim"], + BuiltinScalarFunction::Upper => &["upper"], + BuiltinScalarFunction::Uuid => &["uuid"], + BuiltinScalarFunction::Levenshtein => &["levenshtein"], + BuiltinScalarFunction::SubstrIndex => &["substr_index", "substring_index"], + BuiltinScalarFunction::FindInSet => &["find_in_set"], - // time/date functions - BuiltinScalarFunction::Now => &["now"], - BuiltinScalarFunction::CurrentDate => &["current_date"], - BuiltinScalarFunction::CurrentTime => &["current_time"], - BuiltinScalarFunction::DateBin => &["date_bin"], - BuiltinScalarFunction::DateTrunc => &["date_trunc", "datetrunc"], - BuiltinScalarFunction::DatePart => &["date_part", "datepart"], - BuiltinScalarFunction::ToTimestamp => &["to_timestamp"], - BuiltinScalarFunction::ToTimestampMillis => &["to_timestamp_millis"], - BuiltinScalarFunction::ToTimestampMicros => &["to_timestamp_micros"], - BuiltinScalarFunction::ToTimestampSeconds => &["to_timestamp_seconds"], - BuiltinScalarFunction::FromUnixtime => &["from_unixtime"], + // regex functions + BuiltinScalarFunction::RegexpMatch => &["regexp_match"], + BuiltinScalarFunction::RegexpReplace => &["regexp_replace"], - // hashing functions - BuiltinScalarFunction::Digest => &["digest"], - BuiltinScalarFunction::MD5 => &["md5"], - BuiltinScalarFunction::SHA224 => &["sha224"], - BuiltinScalarFunction::SHA256 => &["sha256"], - BuiltinScalarFunction::SHA384 => &["sha384"], - BuiltinScalarFunction::SHA512 => &["sha512"], + // time/date functions + BuiltinScalarFunction::Now => &["now"], + BuiltinScalarFunction::CurrentDate => &["current_date", "today"], + BuiltinScalarFunction::CurrentTime => &["current_time"], + BuiltinScalarFunction::DateBin => &["date_bin"], + BuiltinScalarFunction::DateTrunc => &["date_trunc", "datetrunc"], + BuiltinScalarFunction::DatePart => &["date_part", "datepart"], + BuiltinScalarFunction::ToTimestamp => &["to_timestamp"], + BuiltinScalarFunction::ToTimestampMillis => &["to_timestamp_millis"], + BuiltinScalarFunction::ToTimestampMicros => &["to_timestamp_micros"], + BuiltinScalarFunction::ToTimestampSeconds => &["to_timestamp_seconds"], + BuiltinScalarFunction::ToTimestampNanos => &["to_timestamp_nanos"], + BuiltinScalarFunction::FromUnixtime => &["from_unixtime"], - // encode/decode - BuiltinScalarFunction::Encode => &["encode"], - BuiltinScalarFunction::Decode => &["decode"], + // hashing functions + BuiltinScalarFunction::Digest => &["digest"], + BuiltinScalarFunction::MD5 => &["md5"], + BuiltinScalarFunction::SHA224 => &["sha224"], + BuiltinScalarFunction::SHA256 => &["sha256"], + BuiltinScalarFunction::SHA384 => &["sha384"], + BuiltinScalarFunction::SHA512 => &["sha512"], - // other functions - BuiltinScalarFunction::ArrowTypeof => &["arrow_typeof"], + // encode/decode + BuiltinScalarFunction::Encode => &["encode"], + BuiltinScalarFunction::Decode => &["decode"], - // array functions - BuiltinScalarFunction::ArrayAppend => &[ - "array_append", - "list_append", - "array_push_back", - "list_push_back", - ], - BuiltinScalarFunction::ArrayConcat => { - &["array_concat", "array_cat", "list_concat", "list_cat"] - } - BuiltinScalarFunction::ArrayDims => &["array_dims", "list_dims"], - BuiltinScalarFunction::ArrayEmpty => &["empty"], - BuiltinScalarFunction::ArrayElement => &[ - "array_element", - "array_extract", - "list_element", - "list_extract", - ], - BuiltinScalarFunction::Flatten => &["flatten"], - BuiltinScalarFunction::ArrayHasAll => &["array_has_all", "list_has_all"], - BuiltinScalarFunction::ArrayHasAny => &["array_has_any", "list_has_any"], - BuiltinScalarFunction::ArrayHas => { - &["array_has", "list_has", "array_contains", "list_contains"] - } - BuiltinScalarFunction::ArrayLength => &["array_length", "list_length"], - BuiltinScalarFunction::ArrayNdims => &["array_ndims", "list_ndims"], - BuiltinScalarFunction::ArrayPopBack => &["array_pop_back", "list_pop_back"], - BuiltinScalarFunction::ArrayPosition => &[ - "array_position", - "list_position", - "array_indexof", - "list_indexof", - ], - BuiltinScalarFunction::ArrayPositions => &["array_positions", "list_positions"], - BuiltinScalarFunction::ArrayPrepend => &[ - "array_prepend", - "list_prepend", - "array_push_front", - "list_push_front", - ], - BuiltinScalarFunction::ArrayRepeat => &["array_repeat", "list_repeat"], - BuiltinScalarFunction::ArrayRemove => &["array_remove", "list_remove"], - BuiltinScalarFunction::ArrayRemoveN => &["array_remove_n", "list_remove_n"], - BuiltinScalarFunction::ArrayRemoveAll => &["array_remove_all", "list_remove_all"], - BuiltinScalarFunction::ArrayReplace => &["array_replace", "list_replace"], - BuiltinScalarFunction::ArrayReplaceN => &["array_replace_n", "list_replace_n"], - BuiltinScalarFunction::ArrayReplaceAll => { - &["array_replace_all", "list_replace_all"] - } - BuiltinScalarFunction::ArraySlice => &["array_slice", "list_slice"], - BuiltinScalarFunction::ArrayToString => &[ - "array_to_string", - "list_to_string", - "array_join", - "list_join", - ], - BuiltinScalarFunction::Cardinality => &["cardinality"], - BuiltinScalarFunction::MakeArray => &["make_array", "make_list"], + // other functions + BuiltinScalarFunction::ArrowTypeof => &["arrow_typeof"], + + // array functions + BuiltinScalarFunction::ArrayAppend => &[ + "array_append", + "list_append", + "array_push_back", + "list_push_back", + ], + BuiltinScalarFunction::ArraySort => &["array_sort", "list_sort"], + BuiltinScalarFunction::ArrayConcat => { + &["array_concat", "array_cat", "list_concat", "list_cat"] + } + BuiltinScalarFunction::ArrayDims => &["array_dims", "list_dims"], + BuiltinScalarFunction::ArrayDistinct => &["array_distinct", "list_distinct"], + BuiltinScalarFunction::ArrayEmpty => &["empty"], + BuiltinScalarFunction::ArrayElement => &[ + "array_element", + "array_extract", + "list_element", + "list_extract", + ], + BuiltinScalarFunction::ArrayExcept => &["array_except", "list_except"], + BuiltinScalarFunction::Flatten => &["flatten"], + BuiltinScalarFunction::ArrayHasAll => &["array_has_all", "list_has_all"], + BuiltinScalarFunction::ArrayHasAny => &["array_has_any", "list_has_any"], + BuiltinScalarFunction::ArrayHas => { + &["array_has", "list_has", "array_contains", "list_contains"] + } + BuiltinScalarFunction::ArrayLength => &["array_length", "list_length"], + BuiltinScalarFunction::ArrayNdims => &["array_ndims", "list_ndims"], + BuiltinScalarFunction::ArrayPopFront => { + &["array_pop_front", "list_pop_front"] + } + BuiltinScalarFunction::ArrayPopBack => &["array_pop_back", "list_pop_back"], + BuiltinScalarFunction::ArrayPosition => &[ + "array_position", + "list_position", + "array_indexof", + "list_indexof", + ], + BuiltinScalarFunction::ArrayPositions => { + &["array_positions", "list_positions"] + } + BuiltinScalarFunction::ArrayPrepend => &[ + "array_prepend", + "list_prepend", + "array_push_front", + "list_push_front", + ], + BuiltinScalarFunction::ArrayRepeat => &["array_repeat", "list_repeat"], + BuiltinScalarFunction::ArrayRemove => &["array_remove", "list_remove"], + BuiltinScalarFunction::ArrayRemoveN => &["array_remove_n", "list_remove_n"], + BuiltinScalarFunction::ArrayRemoveAll => { + &["array_remove_all", "list_remove_all"] + } + BuiltinScalarFunction::ArrayReplace => &["array_replace", "list_replace"], + BuiltinScalarFunction::ArrayReplaceN => { + &["array_replace_n", "list_replace_n"] + } + BuiltinScalarFunction::ArrayReplaceAll => { + &["array_replace_all", "list_replace_all"] + } + BuiltinScalarFunction::ArraySlice => &["array_slice", "list_slice"], + BuiltinScalarFunction::ArrayToString => &[ + "array_to_string", + "list_to_string", + "array_join", + "list_join", + ], + BuiltinScalarFunction::ArrayUnion => &["array_union", "list_union"], + BuiltinScalarFunction::Cardinality => &["cardinality"], + BuiltinScalarFunction::MakeArray => &["make_array", "make_list"], + BuiltinScalarFunction::ArrayIntersect => { + &["array_intersect", "list_intersect"] + } + BuiltinScalarFunction::OverLay => &["overlay"], + BuiltinScalarFunction::Range => &["range", "generate_series"], - // struct functions - BuiltinScalarFunction::Struct => &["struct"], + // struct functions + BuiltinScalarFunction::Struct => &["struct"], + } } } impl fmt::Display for BuiltinScalarFunction { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - // .unwrap is safe here because compiler makes sure the map will have matches for each BuiltinScalarFunction - write!(f, "{}", function_to_name().get(self).unwrap()) + write!(f, "{}", self.name()) } } @@ -1445,16 +1677,29 @@ impl FromStr for BuiltinScalarFunction { } } +/// Creates a function that returns the return type of a string function given +/// the type of its first argument. +/// +/// If the input type is `LargeUtf8` or `LargeBinary` the return type is +/// `$largeUtf8Type`, +/// +/// If the input type is `Utf8` or `Binary` the return type is `$utf8Type`, macro_rules! make_utf8_to_return_type { ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => { fn $FUNC(arg_type: &DataType, name: &str) -> Result { Ok(match arg_type { - DataType::LargeUtf8 => $largeUtf8Type, - DataType::Utf8 => $utf8Type, + DataType::LargeUtf8 => $largeUtf8Type, + // LargeBinary inputs are automatically coerced to Utf8 + DataType::LargeBinary => $largeUtf8Type, + DataType::Utf8 => $utf8Type, + // Binary inputs are automatically coerced to Utf8 + DataType::Binary => $utf8Type, DataType::Null => DataType::Null, DataType::Dictionary(_, value_type) => match **value_type { - DataType::LargeUtf8 => $largeUtf8Type, + DataType::LargeUtf8 => $largeUtf8Type, + DataType::LargeBinary => $largeUtf8Type, DataType::Utf8 => $utf8Type, + DataType::Binary => $utf8Type, DataType::Null => DataType::Null, _ => { return plan_err!( @@ -1475,8 +1720,10 @@ macro_rules! make_utf8_to_return_type { } }; } - +// `utf8_to_str_type`: returns either a Utf8 or LargeUtf8 based on the input type size. make_utf8_to_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8); + +// `utf8_to_int_type`: returns either a Int32 or Int64 based on the input type size. make_utf8_to_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32); fn utf8_or_binary_to_binary_type(arg_type: &DataType, name: &str) -> Result { @@ -1502,7 +1749,8 @@ mod tests { // Test for BuiltinScalarFunction's Display and from_str() implementations. // For each variant in BuiltinScalarFunction, it converts the variant to a string // and then back to a variant. The test asserts that the original variant and - // the reconstructed variant are the same. + // the reconstructed variant are the same. This assertion is also necessary for + // function suggestion. See https://github.com/apache/arrow-datafusion/issues/8082 fn test_display_and_from_str() { for (_, func_original) in name_to_function().iter() { let func_name = func_original.to_string(); diff --git a/datafusion/expr/src/built_in_window_function.rs b/datafusion/expr/src/built_in_window_function.rs new file mode 100644 index 000000000000..a03e3d2d24a9 --- /dev/null +++ b/datafusion/expr/src/built_in_window_function.rs @@ -0,0 +1,207 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Built-in functions module contains all the built-in functions definitions. + +use std::fmt; +use std::str::FromStr; + +use crate::type_coercion::functions::data_types; +use crate::utils; +use crate::{Signature, TypeSignature, Volatility}; +use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result}; + +use arrow::datatypes::DataType; + +use strum_macros::EnumIter; + +impl fmt::Display for BuiltInWindowFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.name()) + } +} + +/// A [window function] built in to DataFusion +/// +/// [window function]: https://en.wikipedia.org/wiki/Window_function_(SQL) +#[derive(Debug, Clone, PartialEq, Eq, Hash, EnumIter)] +pub enum BuiltInWindowFunction { + /// number of the current row within its partition, counting from 1 + RowNumber, + /// rank of the current row with gaps; same as row_number of its first peer + Rank, + /// rank of the current row without gaps; this function counts peer groups + DenseRank, + /// relative rank of the current row: (rank - 1) / (total rows - 1) + PercentRank, + /// relative rank of the current row: (number of rows preceding or peer with current row) / (total rows) + CumeDist, + /// integer ranging from 1 to the argument value, dividing the partition as equally as possible + Ntile, + /// returns value evaluated at the row that is offset rows before the current row within the partition; + /// if there is no such row, instead return default (which must be of the same type as value). + /// Both offset and default are evaluated with respect to the current row. + /// If omitted, offset defaults to 1 and default to null + Lag, + /// returns value evaluated at the row that is offset rows after the current row within the partition; + /// if there is no such row, instead return default (which must be of the same type as value). + /// Both offset and default are evaluated with respect to the current row. + /// If omitted, offset defaults to 1 and default to null + Lead, + /// returns value evaluated at the row that is the first row of the window frame + FirstValue, + /// returns value evaluated at the row that is the last row of the window frame + LastValue, + /// returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row + NthValue, +} + +impl BuiltInWindowFunction { + fn name(&self) -> &str { + use BuiltInWindowFunction::*; + match self { + RowNumber => "ROW_NUMBER", + Rank => "RANK", + DenseRank => "DENSE_RANK", + PercentRank => "PERCENT_RANK", + CumeDist => "CUME_DIST", + Ntile => "NTILE", + Lag => "LAG", + Lead => "LEAD", + FirstValue => "FIRST_VALUE", + LastValue => "LAST_VALUE", + NthValue => "NTH_VALUE", + } + } +} + +impl FromStr for BuiltInWindowFunction { + type Err = DataFusionError; + fn from_str(name: &str) -> Result { + Ok(match name.to_uppercase().as_str() { + "ROW_NUMBER" => BuiltInWindowFunction::RowNumber, + "RANK" => BuiltInWindowFunction::Rank, + "DENSE_RANK" => BuiltInWindowFunction::DenseRank, + "PERCENT_RANK" => BuiltInWindowFunction::PercentRank, + "CUME_DIST" => BuiltInWindowFunction::CumeDist, + "NTILE" => BuiltInWindowFunction::Ntile, + "LAG" => BuiltInWindowFunction::Lag, + "LEAD" => BuiltInWindowFunction::Lead, + "FIRST_VALUE" => BuiltInWindowFunction::FirstValue, + "LAST_VALUE" => BuiltInWindowFunction::LastValue, + "NTH_VALUE" => BuiltInWindowFunction::NthValue, + _ => return plan_err!("There is no built-in window function named {name}"), + }) + } +} + +/// Returns the datatype of the built-in window function +impl BuiltInWindowFunction { + pub fn return_type(&self, input_expr_types: &[DataType]) -> Result { + // Note that this function *must* return the same type that the respective physical expression returns + // or the execution panics. + + // verify that this is a valid set of data types for this function + data_types(input_expr_types, &self.signature()) + // original errors are all related to wrong function signature + // aggregate them for better error message + .map_err(|_| { + plan_datafusion_err!( + "{}", + utils::generate_signature_error_msg( + &format!("{self}"), + self.signature(), + input_expr_types, + ) + ) + })?; + + match self { + BuiltInWindowFunction::RowNumber + | BuiltInWindowFunction::Rank + | BuiltInWindowFunction::DenseRank => Ok(DataType::UInt64), + BuiltInWindowFunction::PercentRank | BuiltInWindowFunction::CumeDist => { + Ok(DataType::Float64) + } + BuiltInWindowFunction::Ntile => Ok(DataType::UInt64), + BuiltInWindowFunction::Lag + | BuiltInWindowFunction::Lead + | BuiltInWindowFunction::FirstValue + | BuiltInWindowFunction::LastValue + | BuiltInWindowFunction::NthValue => Ok(input_expr_types[0].clone()), + } + } + + /// the signatures supported by the built-in window function `fun`. + pub fn signature(&self) -> Signature { + // note: the physical expression must accept the type returned by this function or the execution panics. + match self { + BuiltInWindowFunction::RowNumber + | BuiltInWindowFunction::Rank + | BuiltInWindowFunction::DenseRank + | BuiltInWindowFunction::PercentRank + | BuiltInWindowFunction::CumeDist => Signature::any(0, Volatility::Immutable), + BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead => { + Signature::one_of( + vec![ + TypeSignature::Any(1), + TypeSignature::Any(2), + TypeSignature::Any(3), + ], + Volatility::Immutable, + ) + } + BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue => { + Signature::any(1, Volatility::Immutable) + } + BuiltInWindowFunction::Ntile => Signature::uniform( + 1, + vec![ + DataType::UInt64, + DataType::UInt32, + DataType::UInt16, + DataType::UInt8, + DataType::Int64, + DataType::Int32, + DataType::Int16, + DataType::Int8, + ], + Volatility::Immutable, + ), + BuiltInWindowFunction::NthValue => Signature::any(2, Volatility::Immutable), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use strum::IntoEnumIterator; + #[test] + // Test for BuiltInWindowFunction's Display and from_str() implementations. + // For each variant in BuiltInWindowFunction, it converts the variant to a string + // and then back to a variant. The test asserts that the original variant and + // the reconstructed variant are the same. This assertion is also necessary for + // function suggestion. See https://github.com/apache/arrow-datafusion/issues/8082 + fn test_display_and_from_str() { + for func_original in BuiltInWindowFunction::iter() { + let func_name = func_original.to_string(); + let func_from_str = BuiltInWindowFunction::from_str(&func_name).unwrap(); + assert_eq!(func_from_str, func_original); + } + } +} diff --git a/datafusion/expr/src/columnar_value.rs b/datafusion/expr/src/columnar_value.rs index c72aae69c831..7a2883928169 100644 --- a/datafusion/expr/src/columnar_value.rs +++ b/datafusion/expr/src/columnar_value.rs @@ -20,7 +20,7 @@ use arrow::array::ArrayRef; use arrow::array::NullArray; use arrow::datatypes::DataType; -use datafusion_common::ScalarValue; +use datafusion_common::{Result, ScalarValue}; use std::sync::Arc; /// Represents the result of evaluating an expression: either a single @@ -47,11 +47,15 @@ impl ColumnarValue { /// Convert a columnar value into an ArrayRef. [`Self::Scalar`] is /// converted by repeating the same scalar multiple times. - pub fn into_array(self, num_rows: usize) -> ArrayRef { - match self { + /// + /// # Errors + /// + /// Errors if `self` is a Scalar that fails to be converted into an array of size + pub fn into_array(self, num_rows: usize) -> Result { + Ok(match self { ColumnarValue::Array(array) => array, - ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(num_rows), - } + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(num_rows)?, + }) } /// null columnar values are implemented as a null array in order to pass batch diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 94ef69eb7933..5617d217eb9f 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -17,29 +17,33 @@ //! Expr module contains core type definition for `Expr`. -use crate::aggregate_function; -use crate::built_in_function; use crate::expr_fn::binary_expr; use crate::logical_plan::Subquery; -use crate::udaf; use crate::utils::{expr_to_columns, find_out_reference_exprs}; use crate::window_frame; -use crate::window_function; + use crate::Operator; +use crate::{aggregate_function, ExprSchemable}; +use crate::{built_in_function, BuiltinScalarFunction}; +use crate::{built_in_window_function, udaf}; use arrow::datatypes::DataType; -use datafusion_common::internal_err; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{internal_err, DFSchema, OwnedTableReference}; use datafusion_common::{plan_err, Column, DataFusionError, Result, ScalarValue}; use std::collections::HashSet; use std::fmt; use std::fmt::{Display, Formatter, Write}; use std::hash::{BuildHasher, Hash, Hasher}; +use std::str::FromStr; use std::sync::Arc; +use crate::Signature; + /// `Expr` is a central struct of DataFusion's query API, and /// represent logical expressions such as `A + 1`, or `CAST(c1 AS /// int)`. /// -/// An `Expr` can compute its [DataType](arrow::datatypes::DataType) +/// An `Expr` can compute its [DataType] /// and nullability, and has functions for building up complex /// expressions. /// @@ -98,21 +102,21 @@ pub enum Expr { SimilarTo(Like), /// Negation of an expression. The expression's type must be a boolean to make sense. Not(Box), - /// Whether an expression is not Null. This expression is never null. + /// True if argument is not NULL, false otherwise. This expression itself is never NULL. IsNotNull(Box), - /// Whether an expression is Null. This expression is never null. + /// True if argument is NULL, false otherwise. This expression itself is never NULL. IsNull(Box), - /// Whether an expression is True. Boolean operation + /// True if argument is true, false otherwise. This expression itself is never NULL. IsTrue(Box), - /// Whether an expression is False. Boolean operation + /// True if argument is false, false otherwise. This expression itself is never NULL. IsFalse(Box), - /// Whether an expression is Unknown. Boolean operation + /// True if argument is NULL, false otherwise. This expression itself is never NULL. IsUnknown(Box), - /// Whether an expression is not True. Boolean operation + /// True if argument is FALSE or NULL, false otherwise. This expression itself is never NULL. IsNotTrue(Box), - /// Whether an expression is not False. Boolean operation + /// True if argument is TRUE OR NULL, false otherwise. This expression itself is never NULL. IsNotFalse(Box), - /// Whether an expression is not Unknown. Boolean operation + /// True if argument is TRUE or FALSE, false otherwise. This expression itself is never NULL. IsNotUnknown(Box), /// arithmetic negation of an expression, the operand must be of a signed numeric data type Negative(Box), @@ -147,16 +151,12 @@ pub enum Expr { TryCast(TryCast), /// A sort expression, that can be used to sort values. Sort(Sort), - /// Represents the call of a built-in scalar function with a set of arguments. + /// Represents the call of a scalar function with a set of arguments. ScalarFunction(ScalarFunction), - /// Represents the call of a user-defined scalar function with arguments. - ScalarUDF(ScalarUDF), /// Represents the call of an aggregate built-in function with arguments. AggregateFunction(AggregateFunction), /// Represents the call of a window function with arguments. WindowFunction(WindowFunction), - /// aggregate function - AggregateUDF(AggregateUDF), /// Returns whether the list contains the expr value. InList(InList), /// EXISTS subquery @@ -165,10 +165,12 @@ pub enum Expr { InSubquery(InSubquery), /// Scalar subquery ScalarSubquery(Subquery), - /// Represents a reference to all fields in a schema. - Wildcard, - /// Represents a reference to all fields in a specific schema. - QualifiedWildcard { qualifier: String }, + /// Represents a reference to all available fields in a specific schema, + /// with an optional (schema) qualifier. + /// + /// This expr has to be resolved to a list of columns before translating logical + /// plan into physical plan. + Wildcard { qualifier: Option }, /// List of grouping set expressions. Only valid in the context of an aggregate /// GROUP BY expression list GroupingSet(GroupingSet), @@ -184,13 +186,20 @@ pub enum Expr { #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct Alias { pub expr: Box, + pub relation: Option, pub name: String, } impl Alias { - pub fn new(expr: Expr, name: impl Into) -> Self { + /// Create an alias with an optional schema/field qualifier. + pub fn new( + expr: Expr, + relation: Option>, + name: impl Into, + ) -> Self { Self { expr: Box::new(expr), + relation: relation.map(|r| r.into()), name: name.into(), } } @@ -328,35 +337,80 @@ impl Between { } } -/// ScalarFunction expression +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +/// Defines which implementation of a function for DataFusion to call. +pub enum ScalarFunctionDefinition { + /// Resolved to a `BuiltinScalarFunction` + /// There is plan to migrate `BuiltinScalarFunction` to UDF-based implementation (issue#8045) + /// This variant is planned to be removed in long term + BuiltIn(BuiltinScalarFunction), + /// Resolved to a user defined function + UDF(Arc), + /// A scalar function constructed with name. This variant can not be executed directly + /// and instead must be resolved to one of the other variants prior to physical planning. + Name(Arc), +} + +/// ScalarFunction expression invokes a built-in scalar function #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct ScalarFunction { /// The function - pub fun: built_in_function::BuiltinScalarFunction, + pub func_def: ScalarFunctionDefinition, /// List of expressions to feed to the functions as arguments pub args: Vec, } impl ScalarFunction { - /// Create a new ScalarFunction expression - pub fn new(fun: built_in_function::BuiltinScalarFunction, args: Vec) -> Self { - Self { fun, args } + // return the Function's name + pub fn name(&self) -> &str { + self.func_def.name() } } -/// ScalarUDF expression -#[derive(Clone, PartialEq, Eq, Hash, Debug)] -pub struct ScalarUDF { - /// The function - pub fun: Arc, - /// List of expressions to feed to the functions as arguments - pub args: Vec, +impl ScalarFunctionDefinition { + /// Function's name for display + pub fn name(&self) -> &str { + match self { + ScalarFunctionDefinition::BuiltIn(fun) => fun.name(), + ScalarFunctionDefinition::UDF(udf) => udf.name(), + ScalarFunctionDefinition::Name(func_name) => func_name.as_ref(), + } + } + + /// Whether this function is volatile, i.e. whether it can return different results + /// when evaluated multiple times with the same input. + pub fn is_volatile(&self) -> Result { + match self { + ScalarFunctionDefinition::BuiltIn(fun) => { + Ok(fun.volatility() == crate::Volatility::Volatile) + } + ScalarFunctionDefinition::UDF(udf) => { + Ok(udf.signature().volatility == crate::Volatility::Volatile) + } + ScalarFunctionDefinition::Name(func) => { + internal_err!( + "Cannot determine volatility of unresolved function: {func}" + ) + } + } + } } -impl ScalarUDF { - /// Create a new ScalarUDF expression - pub fn new(fun: Arc, args: Vec) -> Self { - Self { fun, args } +impl ScalarFunction { + /// Create a new ScalarFunction expression + pub fn new(fun: built_in_function::BuiltinScalarFunction, args: Vec) -> Self { + Self { + func_def: ScalarFunctionDefinition::BuiltIn(fun), + args, + } + } + + /// Create a new ScalarFunction expression with a user-defined function (UDF) + pub fn new_udf(udf: Arc, args: Vec) -> Self { + Self { + func_def: ScalarFunctionDefinition::UDF(udf), + args, + } } } @@ -443,11 +497,33 @@ impl Sort { } } +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +/// Defines which implementation of an aggregate function DataFusion should call. +pub enum AggregateFunctionDefinition { + BuiltIn(aggregate_function::AggregateFunction), + /// Resolved to a user defined aggregate function + UDF(Arc), + /// A aggregation function constructed with name. This variant can not be executed directly + /// and instead must be resolved to one of the other variants prior to physical planning. + Name(Arc), +} + +impl AggregateFunctionDefinition { + /// Function's name for display + pub fn name(&self) -> &str { + match self { + AggregateFunctionDefinition::BuiltIn(fun) => fun.name(), + AggregateFunctionDefinition::UDF(udf) => udf.name(), + AggregateFunctionDefinition::Name(func_name) => func_name.as_ref(), + } + } +} + /// Aggregate function #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct AggregateFunction { /// Name of the function - pub fun: aggregate_function::AggregateFunction, + pub func_def: AggregateFunctionDefinition, /// List of expressions to feed to the functions as arguments pub args: Vec, /// Whether this is a DISTINCT aggregation or not @@ -467,20 +543,90 @@ impl AggregateFunction { order_by: Option>, ) -> Self { Self { - fun, + func_def: AggregateFunctionDefinition::BuiltIn(fun), args, distinct, filter, order_by, } } + + /// Create a new AggregateFunction expression with a user-defined function (UDF) + pub fn new_udf( + udf: Arc, + args: Vec, + distinct: bool, + filter: Option>, + order_by: Option>, + ) -> Self { + Self { + func_def: AggregateFunctionDefinition::UDF(udf), + args, + distinct, + filter, + order_by, + } + } +} + +/// WindowFunction +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +/// Defines which implementation of an aggregate function DataFusion should call. +pub enum WindowFunctionDefinition { + /// A built in aggregate function that leverages an aggregate function + AggregateFunction(aggregate_function::AggregateFunction), + /// A a built-in window function + BuiltInWindowFunction(built_in_window_function::BuiltInWindowFunction), + /// A user defined aggregate function + AggregateUDF(Arc), + /// A user defined aggregate function + WindowUDF(Arc), +} + +impl WindowFunctionDefinition { + /// Returns the datatype of the window function + pub fn return_type(&self, input_expr_types: &[DataType]) -> Result { + match self { + WindowFunctionDefinition::AggregateFunction(fun) => { + fun.return_type(input_expr_types) + } + WindowFunctionDefinition::BuiltInWindowFunction(fun) => { + fun.return_type(input_expr_types) + } + WindowFunctionDefinition::AggregateUDF(fun) => { + fun.return_type(input_expr_types) + } + WindowFunctionDefinition::WindowUDF(fun) => fun.return_type(input_expr_types), + } + } + + /// the signatures supported by the function `fun`. + pub fn signature(&self) -> Signature { + match self { + WindowFunctionDefinition::AggregateFunction(fun) => fun.signature(), + WindowFunctionDefinition::BuiltInWindowFunction(fun) => fun.signature(), + WindowFunctionDefinition::AggregateUDF(fun) => fun.signature().clone(), + WindowFunctionDefinition::WindowUDF(fun) => fun.signature().clone(), + } + } +} + +impl fmt::Display for WindowFunctionDefinition { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + WindowFunctionDefinition::AggregateFunction(fun) => fun.fmt(f), + WindowFunctionDefinition::BuiltInWindowFunction(fun) => fun.fmt(f), + WindowFunctionDefinition::AggregateUDF(fun) => std::fmt::Debug::fmt(fun, f), + WindowFunctionDefinition::WindowUDF(fun) => fun.fmt(f), + } + } } /// Window function #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct WindowFunction { /// Name of the function - pub fun: window_function::WindowFunction, + pub fun: WindowFunctionDefinition, /// List of expressions to feed to the functions as arguments pub args: Vec, /// List of partition by expressions @@ -494,7 +640,7 @@ pub struct WindowFunction { impl WindowFunction { /// Create a new Window expression pub fn new( - fun: window_function::WindowFunction, + fun: WindowFunctionDefinition, args: Vec, partition_by: Vec, order_by: Vec, @@ -510,6 +656,50 @@ impl WindowFunction { } } +/// Find DataFusion's built-in window function by name. +pub fn find_df_window_func(name: &str) -> Option { + let name = name.to_lowercase(); + // Code paths for window functions leveraging ordinary aggregators and + // built-in window functions are quite different, and the same function + // may have different implementations for these cases. If the sought + // function is not found among built-in window functions, we search for + // it among aggregate functions. + if let Ok(built_in_function) = + built_in_window_function::BuiltInWindowFunction::from_str(name.as_str()) + { + Some(WindowFunctionDefinition::BuiltInWindowFunction( + built_in_function, + )) + } else if let Ok(aggregate) = + aggregate_function::AggregateFunction::from_str(name.as_str()) + { + Some(WindowFunctionDefinition::AggregateFunction(aggregate)) + } else { + None + } +} + +/// Returns the datatype of the window function +#[deprecated( + since = "27.0.0", + note = "please use `WindowFunction::return_type` instead" +)] +pub fn return_type( + fun: &WindowFunctionDefinition, + input_expr_types: &[DataType], +) -> Result { + fun.return_type(input_expr_types) +} + +/// the signatures supported by the function `fun`. +#[deprecated( + since = "27.0.0", + note = "please use `WindowFunction::signature` instead" +)] +pub fn signature(fun: &WindowFunctionDefinition) -> Signature { + fun.signature() +} + // Exists expression. #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct Exists { @@ -599,10 +789,13 @@ impl InSubquery { } } -/// Placeholder +/// Placeholder, representing bind parameter values such as `$1` or `$name`. +/// +/// The type of these parameters is inferred using [`Expr::infer_placeholder_types`] +/// or can be specified directly using `PREPARE` statements. #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct Placeholder { - /// The identifier of the parameter (e.g, $1 or $foo) + /// The identifier of the parameter, including the leading `$` (e.g, `"$1"` or `"$foo"`) pub id: String, /// The type the parameter will be filled in with pub data_type: Option, @@ -692,7 +885,6 @@ impl Expr { pub fn variant_name(&self) -> &str { match self { Expr::AggregateFunction { .. } => "AggregateFunction", - Expr::AggregateUDF { .. } => "AggregateUDF", Expr::Alias(..) => "Alias", Expr::Between { .. } => "Between", Expr::BinaryExpr { .. } => "BinaryExpr", @@ -719,15 +911,13 @@ impl Expr { Expr::Negative(..) => "Negative", Expr::Not(..) => "Not", Expr::Placeholder(_) => "Placeholder", - Expr::QualifiedWildcard { .. } => "QualifiedWildcard", Expr::ScalarFunction(..) => "ScalarFunction", Expr::ScalarSubquery { .. } => "ScalarSubquery", - Expr::ScalarUDF(..) => "ScalarUDF", Expr::ScalarVariable(..) => "ScalarVariable", Expr::Sort { .. } => "Sort", Expr::TryCast { .. } => "TryCast", Expr::WindowFunction { .. } => "WindowFunction", - Expr::Wildcard => "Wildcard", + Expr::Wildcard { .. } => "Wildcard", } } @@ -839,14 +1029,34 @@ impl Expr { asc, nulls_first, }) => Expr::Sort(Sort::new(Box::new(expr.alias(name)), asc, nulls_first)), - _ => Expr::Alias(Alias::new(self, name.into())), + _ => Expr::Alias(Alias::new(self, None::<&str>, name.into())), + } + } + + /// Return `self AS name` alias expression with a specific qualifier + pub fn alias_qualified( + self, + relation: Option>, + name: impl Into, + ) -> Expr { + match self { + Expr::Sort(Sort { + expr, + asc, + nulls_first, + }) => Expr::Sort(Sort::new( + Box::new(expr.alias_qualified(relation, name)), + asc, + nulls_first, + )), + _ => Expr::Alias(Alias::new(self, relation, name.into())), } } /// Remove an alias from an expression if one exists. pub fn unalias(self) -> Expr { match self { - Expr::Alias(alias) => alias.expr.as_ref().clone(), + Expr::Alias(alias) => *alias.expr, _ => self, } } @@ -952,7 +1162,7 @@ impl Expr { Expr::GetIndexedField(GetIndexedField { expr: Box::new(self), field: GetFieldAccess::NamedStructField { - name: ScalarValue::Utf8(Some(name.into())), + name: ScalarValue::from(name.into()), }, }) } @@ -1030,6 +1240,52 @@ impl Expr { pub fn contains_outer(&self) -> bool { !find_out_reference_exprs(self).is_empty() } + + /// Recursively find all [`Expr::Placeholder`] expressions, and + /// to infer their [`DataType`] from the context of their use. + /// + /// For example, gicen an expression like ` = $0` will infer `$0` to + /// have type `int32`. + pub fn infer_placeholder_types(self, schema: &DFSchema) -> Result { + self.transform(&|mut expr| { + // Default to assuming the arguments are the same type + if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = &mut expr { + rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?; + rewrite_placeholder(right.as_mut(), left.as_ref(), schema)?; + }; + if let Expr::Between(Between { + expr, + negated: _, + low, + high, + }) = &mut expr + { + rewrite_placeholder(low.as_mut(), expr.as_ref(), schema)?; + rewrite_placeholder(high.as_mut(), expr.as_ref(), schema)?; + } + Ok(Transformed::Yes(expr)) + }) + } +} + +// modifies expr if it is a placeholder with datatype of right +fn rewrite_placeholder(expr: &mut Expr, other: &Expr, schema: &DFSchema) -> Result<()> { + if let Expr::Placeholder(Placeholder { id: _, data_type }) = expr { + if data_type.is_none() { + let other_dt = other.get_type(schema); + match other_dt { + Err(e) => { + Err(e.context(format!( + "Can not find type of {other} needed to infer type of {expr}" + )))?; + } + Ok(dt) => { + *data_type = Some(dt); + } + } + }; + } + Ok(()) } #[macro_export] @@ -1118,11 +1374,8 @@ impl fmt::Display for Expr { write!(f, " NULLS LAST") } } - Expr::ScalarFunction(func) => { - fmt_function(f, &func.fun.to_string(), false, &func.args, true) - } - Expr::ScalarUDF(ScalarUDF { fun, args }) => { - fmt_function(f, &fun.name, false, args, true) + Expr::ScalarFunction(fun) => { + fmt_function(f, fun.name(), false, &fun.args, true) } Expr::WindowFunction(WindowFunction { fun, @@ -1146,30 +1399,14 @@ impl fmt::Display for Expr { Ok(()) } Expr::AggregateFunction(AggregateFunction { - fun, + func_def, distinct, ref args, filter, order_by, .. }) => { - fmt_function(f, &fun.to_string(), *distinct, args, true)?; - if let Some(fe) = filter { - write!(f, " FILTER (WHERE {fe})")?; - } - if let Some(ob) = order_by { - write!(f, " ORDER BY [{}]", expr_vec_fmt!(ob))?; - } - Ok(()) - } - Expr::AggregateUDF(AggregateUDF { - fun, - ref args, - filter, - order_by, - .. - }) => { - fmt_function(f, &fun.name, false, args, true)?; + fmt_function(f, func_def.name(), *distinct, args, true)?; if let Some(fe) = filter { write!(f, " FILTER (WHERE {fe})")?; } @@ -1236,8 +1473,10 @@ impl fmt::Display for Expr { write!(f, "{expr} IN ([{}])", expr_vec_fmt!(list)) } } - Expr::Wildcard => write!(f, "*"), - Expr::QualifiedWildcard { qualifier } => write!(f, "{qualifier}.*"), + Expr::Wildcard { qualifier } => match qualifier { + Some(qualifier) => write!(f, "{qualifier}.*"), + None => write!(f, "*"), + }, Expr::GetIndexedField(GetIndexedField { field, expr }) => match field { GetFieldAccess::NamedStructField { name } => { write!(f, "({expr})[{name}]") @@ -1452,12 +1691,7 @@ fn create_name(e: &Expr) -> Result { } } } - Expr::ScalarFunction(func) => { - create_function_name(&func.fun.to_string(), false, &func.args) - } - Expr::ScalarUDF(ScalarUDF { fun, args }) => { - create_function_name(&fun.name, false, args) - } + Expr::ScalarFunction(fun) => create_function_name(fun.name(), false, &fun.args), Expr::WindowFunction(WindowFunction { fun, args, @@ -1477,39 +1711,39 @@ fn create_name(e: &Expr) -> Result { Ok(parts.join(" ")) } Expr::AggregateFunction(AggregateFunction { - fun, + func_def, distinct, args, filter, order_by, }) => { - let mut name = create_function_name(&fun.to_string(), *distinct, args)?; - if let Some(fe) = filter { - name = format!("{name} FILTER (WHERE {fe})"); - }; - if let Some(order_by) = order_by { - name = format!("{name} ORDER BY [{}]", expr_vec_fmt!(order_by)); + let name = match func_def { + AggregateFunctionDefinition::BuiltIn(..) + | AggregateFunctionDefinition::Name(..) => { + create_function_name(func_def.name(), *distinct, args)? + } + AggregateFunctionDefinition::UDF(..) => { + let names: Vec = + args.iter().map(create_name).collect::>()?; + names.join(",") + } }; - Ok(name) - } - Expr::AggregateUDF(AggregateUDF { - fun, - args, - filter, - order_by, - }) => { - let mut names = Vec::with_capacity(args.len()); - for e in args { - names.push(create_name(e)?); - } let mut info = String::new(); if let Some(fe) = filter { info += &format!(" FILTER (WHERE {fe})"); + }; + if let Some(order_by) = order_by { + info += &format!(" ORDER BY [{}]", expr_vec_fmt!(order_by)); + }; + match func_def { + AggregateFunctionDefinition::BuiltIn(..) + | AggregateFunctionDefinition::Name(..) => { + Ok(format!("{}{}", name, info)) + } + AggregateFunctionDefinition::UDF(fun) => { + Ok(format!("{}({}){}", fun.name(), name, info)) + } } - if let Some(ob) = order_by { - info += &format!(" ORDER BY ([{}])", expr_vec_fmt!(ob)); - } - Ok(format!("{}({}){}", fun.name, names.join(","), info)) } Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => { @@ -1557,10 +1791,12 @@ fn create_name(e: &Expr) -> Result { Expr::Sort { .. } => { internal_err!("Create name does not support sort expression") } - Expr::Wildcard => Ok("*".to_string()), - Expr::QualifiedWildcard { .. } => { - internal_err!("Create name does not support qualified wildcard") - } + Expr::Wildcard { qualifier } => match qualifier { + Some(qualifier) => internal_err!( + "Create name does not support qualified wildcard, got {qualifier}" + ), + None => Ok("*".to_string()), + }, Expr::Placeholder(Placeholder { id, .. }) => Ok((*id).to_string()), } } @@ -1574,14 +1810,28 @@ fn create_names(exprs: &[Expr]) -> Result { .join(", ")) } +/// Whether the given expression is volatile, i.e. whether it can return different results +/// when evaluated multiple times with the same input. +pub fn is_volatile(expr: &Expr) -> Result { + match expr { + Expr::ScalarFunction(func) => func.func_def.is_volatile(), + _ => Ok(false), + } +} + #[cfg(test)] mod test { use crate::expr::Cast; use crate::expr_fn::col; - use crate::{case, lit, Expr}; + use crate::{ + case, lit, BuiltinScalarFunction, ColumnarValue, Expr, ScalarFunctionDefinition, + ScalarUDF, ScalarUDFImpl, Signature, Volatility, + }; use arrow::datatypes::DataType; use datafusion_common::Column; use datafusion_common::{Result, ScalarValue}; + use std::any::Any; + use std::sync::Arc; #[test] fn format_case_when() -> Result<()> { @@ -1646,4 +1896,282 @@ mod test { Ok(()) } + + #[test] + fn test_logical_ops() { + assert_eq!( + format!("{}", lit(1u32).eq(lit(2u32))), + "UInt32(1) = UInt32(2)" + ); + assert_eq!( + format!("{}", lit(1u32).not_eq(lit(2u32))), + "UInt32(1) != UInt32(2)" + ); + assert_eq!( + format!("{}", lit(1u32).gt(lit(2u32))), + "UInt32(1) > UInt32(2)" + ); + assert_eq!( + format!("{}", lit(1u32).gt_eq(lit(2u32))), + "UInt32(1) >= UInt32(2)" + ); + assert_eq!( + format!("{}", lit(1u32).lt(lit(2u32))), + "UInt32(1) < UInt32(2)" + ); + assert_eq!( + format!("{}", lit(1u32).lt_eq(lit(2u32))), + "UInt32(1) <= UInt32(2)" + ); + assert_eq!( + format!("{}", lit(1u32).and(lit(2u32))), + "UInt32(1) AND UInt32(2)" + ); + assert_eq!( + format!("{}", lit(1u32).or(lit(2u32))), + "UInt32(1) OR UInt32(2)" + ); + } + + #[test] + fn test_is_volatile_scalar_func_definition() { + // BuiltIn + assert!( + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Random) + .is_volatile() + .unwrap() + ); + assert!( + !ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Abs) + .is_volatile() + .unwrap() + ); + + // UDF + #[derive(Debug)] + struct TestScalarUDF { + signature: Signature, + } + impl ScalarUDFImpl for TestScalarUDF { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "TestScalarUDF" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Utf8) + } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + Ok(ColumnarValue::Scalar(ScalarValue::from("a"))) + } + } + let udf = Arc::new(ScalarUDF::from(TestScalarUDF { + signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable), + })); + assert!(!ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap()); + + let udf = Arc::new(ScalarUDF::from(TestScalarUDF { + signature: Signature::uniform( + 1, + vec![DataType::Float32], + Volatility::Volatile, + ), + })); + assert!(ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap()); + + // Unresolved function + ScalarFunctionDefinition::Name(Arc::from("UnresolvedFunc")) + .is_volatile() + .expect_err("Shouldn't determine volatility of unresolved function"); + } + + use super::*; + + #[test] + fn test_count_return_type() -> Result<()> { + let fun = find_df_window_func("count").unwrap(); + let observed = fun.return_type(&[DataType::Utf8])?; + assert_eq!(DataType::Int64, observed); + + let observed = fun.return_type(&[DataType::UInt64])?; + assert_eq!(DataType::Int64, observed); + + Ok(()) + } + + #[test] + fn test_first_value_return_type() -> Result<()> { + let fun = find_df_window_func("first_value").unwrap(); + let observed = fun.return_type(&[DataType::Utf8])?; + assert_eq!(DataType::Utf8, observed); + + let observed = fun.return_type(&[DataType::UInt64])?; + assert_eq!(DataType::UInt64, observed); + + Ok(()) + } + + #[test] + fn test_last_value_return_type() -> Result<()> { + let fun = find_df_window_func("last_value").unwrap(); + let observed = fun.return_type(&[DataType::Utf8])?; + assert_eq!(DataType::Utf8, observed); + + let observed = fun.return_type(&[DataType::Float64])?; + assert_eq!(DataType::Float64, observed); + + Ok(()) + } + + #[test] + fn test_lead_return_type() -> Result<()> { + let fun = find_df_window_func("lead").unwrap(); + let observed = fun.return_type(&[DataType::Utf8])?; + assert_eq!(DataType::Utf8, observed); + + let observed = fun.return_type(&[DataType::Float64])?; + assert_eq!(DataType::Float64, observed); + + Ok(()) + } + + #[test] + fn test_lag_return_type() -> Result<()> { + let fun = find_df_window_func("lag").unwrap(); + let observed = fun.return_type(&[DataType::Utf8])?; + assert_eq!(DataType::Utf8, observed); + + let observed = fun.return_type(&[DataType::Float64])?; + assert_eq!(DataType::Float64, observed); + + Ok(()) + } + + #[test] + fn test_nth_value_return_type() -> Result<()> { + let fun = find_df_window_func("nth_value").unwrap(); + let observed = fun.return_type(&[DataType::Utf8, DataType::UInt64])?; + assert_eq!(DataType::Utf8, observed); + + let observed = fun.return_type(&[DataType::Float64, DataType::UInt64])?; + assert_eq!(DataType::Float64, observed); + + Ok(()) + } + + #[test] + fn test_percent_rank_return_type() -> Result<()> { + let fun = find_df_window_func("percent_rank").unwrap(); + let observed = fun.return_type(&[])?; + assert_eq!(DataType::Float64, observed); + + Ok(()) + } + + #[test] + fn test_cume_dist_return_type() -> Result<()> { + let fun = find_df_window_func("cume_dist").unwrap(); + let observed = fun.return_type(&[])?; + assert_eq!(DataType::Float64, observed); + + Ok(()) + } + + #[test] + fn test_ntile_return_type() -> Result<()> { + let fun = find_df_window_func("ntile").unwrap(); + let observed = fun.return_type(&[DataType::Int16])?; + assert_eq!(DataType::UInt64, observed); + + Ok(()) + } + + #[test] + fn test_window_function_case_insensitive() -> Result<()> { + let names = vec![ + "row_number", + "rank", + "dense_rank", + "percent_rank", + "cume_dist", + "ntile", + "lag", + "lead", + "first_value", + "last_value", + "nth_value", + "min", + "max", + "count", + "avg", + "sum", + ]; + for name in names { + let fun = find_df_window_func(name).unwrap(); + let fun2 = find_df_window_func(name.to_uppercase().as_str()).unwrap(); + assert_eq!(fun, fun2); + assert_eq!(fun.to_string(), name.to_uppercase()); + } + Ok(()) + } + + #[test] + fn test_find_df_window_function() { + assert_eq!( + find_df_window_func("max"), + Some(WindowFunctionDefinition::AggregateFunction( + aggregate_function::AggregateFunction::Max + )) + ); + assert_eq!( + find_df_window_func("min"), + Some(WindowFunctionDefinition::AggregateFunction( + aggregate_function::AggregateFunction::Min + )) + ); + assert_eq!( + find_df_window_func("avg"), + Some(WindowFunctionDefinition::AggregateFunction( + aggregate_function::AggregateFunction::Avg + )) + ); + assert_eq!( + find_df_window_func("cume_dist"), + Some(WindowFunctionDefinition::BuiltInWindowFunction( + built_in_window_function::BuiltInWindowFunction::CumeDist + )) + ); + assert_eq!( + find_df_window_func("first_value"), + Some(WindowFunctionDefinition::BuiltInWindowFunction( + built_in_window_function::BuiltInWindowFunction::FirstValue + )) + ); + assert_eq!( + find_df_window_func("LAST_value"), + Some(WindowFunctionDefinition::BuiltInWindowFunction( + built_in_window_function::BuiltInWindowFunction::LastValue + )) + ); + assert_eq!( + find_df_window_func("LAG"), + Some(WindowFunctionDefinition::BuiltInWindowFunction( + built_in_window_function::BuiltInWindowFunction::Lag + )) + ); + assert_eq!( + find_df_window_func("LEAD"), + Some(WindowFunctionDefinition::BuiltInWindowFunction( + built_in_window_function::BuiltInWindowFunction::Lead + )) + ); + assert_eq!(find_df_window_func("not_exist"), None) + } } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 711dc123a4a4..0491750d18a9 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -19,18 +19,21 @@ use crate::expr::{ AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery, - ScalarFunction, TryCast, + Placeholder, ScalarFunction, TryCast, }; use crate::function::PartitionEvaluatorFactory; -use crate::WindowUDF; use crate::{ aggregate_function, built_in_function, conditional_expressions::CaseBuilder, logical_plan::Subquery, AccumulatorFactoryFunction, AggregateUDF, BuiltinScalarFunction, Expr, LogicalPlan, Operator, ReturnTypeFunction, ScalarFunctionImplementation, ScalarUDF, Signature, StateTypeFunction, Volatility, }; +use crate::{ColumnarValue, ScalarUDFImpl, WindowUDF, WindowUDFImpl}; use arrow::datatypes::DataType; use datafusion_common::{Column, Result}; +use std::any::Any; +use std::fmt::Debug; +use std::ops::Not; use std::sync::Arc; /// Create a column expression based on a qualified or unqualified column name. Will @@ -80,6 +83,37 @@ pub fn ident(name: impl Into) -> Expr { Expr::Column(Column::from_name(name)) } +/// Create placeholder value that will be filled in (such as `$1`) +/// +/// Note the parameter type can be inferred using [`Expr::infer_placeholder_types`] +/// +/// # Example +/// +/// ```rust +/// # use datafusion_expr::{placeholder}; +/// let p = placeholder("$0"); // $0, refers to parameter 1 +/// assert_eq!(p.to_string(), "$0") +/// ``` +pub fn placeholder(id: impl Into) -> Expr { + Expr::Placeholder(Placeholder { + id: id.into(), + data_type: None, + }) +} + +/// Create an '*' [`Expr::Wildcard`] expression that matches all columns +/// +/// # Example +/// +/// ```rust +/// # use datafusion_expr::{wildcard}; +/// let p = wildcard(); +/// assert_eq!(p.to_string(), "*") +/// ``` +pub fn wildcard() -> Expr { + Expr::Wildcard { qualifier: None } +} + /// Return a new expression `left right` pub fn binary_expr(left: Expr, op: Operator, right: Expr) -> Expr { Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right))) @@ -103,6 +137,11 @@ pub fn or(left: Expr, right: Expr) -> Expr { )) } +/// Return a new expression with a logical NOT +pub fn not(expr: Expr) -> Expr { + expr.not() +} + /// Create an expression to represent the min() aggregate function pub fn min(expr: Expr) -> Expr { Expr::AggregateFunction(AggregateFunction::new( @@ -546,6 +585,8 @@ scalar_expr!( "appends an element to the end of an array." ); +scalar_expr!(ArraySort, array_sort, array desc null_first, "returns sorted array."); + scalar_expr!( ArrayPopBack, array_pop_back, @@ -553,6 +594,13 @@ scalar_expr!( "returns the array without the last element." ); +scalar_expr!( + ArrayPopFront, + array_pop_front, + array, + "returns the array without the first element." +); + nary_scalar_expr!(ArrayConcat, array_concat, "concatenates arrays."); scalar_expr!( ArrayHas, @@ -596,6 +644,12 @@ scalar_expr!( array element, "extracts the element with the index n from the array." ); +scalar_expr!( + ArrayExcept, + array_except, + first_array second_array, + "Returns an array of the elements that appear in the first array but not in the second." +); scalar_expr!( ArrayLength, array_length, @@ -608,6 +662,12 @@ scalar_expr!( array, "returns the number of dimensions of the array." ); +scalar_expr!( + ArrayDistinct, + array_distinct, + array, + "return distinct values from the array after removing duplicates." +); scalar_expr!( ArrayPosition, array_position, @@ -680,6 +740,8 @@ scalar_expr!( array delimiter, "converts each element to its text representation." ); +scalar_expr!(ArrayUnion, array_union, array1 array2, "returns an array of the elements in the union of array1 and array2 without duplicates."); + scalar_expr!( Cardinality, cardinality, @@ -691,6 +753,18 @@ nary_scalar_expr!( array, "returns an Arrow array using the specified input expressions." ); +scalar_expr!( + ArrayIntersect, + array_intersect, + first_array second_array, + "Returns an array of the elements in the intersection of array1 and array2." +); + +nary_scalar_expr!( + Range, + gen_range, + "Returns a list of values in the range between start and stop with step." +); // string functions scalar_expr!(Ascii, ascii, chr, "ASCII code value of the character"); @@ -793,6 +867,11 @@ nary_scalar_expr!( "concatenates several strings, placing a seperator between each one" ); nary_scalar_expr!(Concat, concat_expr, "concatenates several strings"); +nary_scalar_expr!( + OverLay, + overlay, + "replace the substring of string that starts at the start'th character and extends for count characters with new substring" +); // date functions scalar_expr!(DatePart, date_part, part date, "extracts a subfield from the date"); @@ -810,6 +889,12 @@ scalar_expr!( date, "converts a string to a `Timestamp(Microseconds, None)`" ); +scalar_expr!( + ToTimestampNanos, + to_timestamp_nanos, + date, + "converts a string to a `Timestamp(Nanoseconds, None)`" +); scalar_expr!( ToTimestampSeconds, to_timestamp_seconds, @@ -840,6 +925,16 @@ scalar_expr!( ); scalar_expr!(ArrowTypeof, arrow_typeof, val, "data type"); +scalar_expr!(Levenshtein, levenshtein, string1 string2, "Returns the Levenshtein distance between the two given strings"); +scalar_expr!(SubstrIndex, substr_index, string delimiter count, "Returns the substring from str before count occurrences of the delimiter"); +scalar_expr!(FindInSet, find_in_set, str strlist, "Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings"); + +scalar_expr!( + Struct, + struct_fun, + val, + "returns a vector of fields from the struct" +); /// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression. pub fn case(expr: Expr) -> CaseBuilder { @@ -851,11 +946,18 @@ pub fn when(when: Expr, then: Expr) -> CaseBuilder { CaseBuilder::new(None, vec![when], vec![then], None) } -/// Creates a new UDF with a specific signature and specific return type. -/// This is a helper function to create a new UDF. -/// The function `create_udf` returns a subset of all possible `ScalarFunction`: -/// * the UDF has a fixed return type -/// * the UDF has a fixed signature (e.g. [f64, f64]) +/// Convenience method to create a new user defined scalar function (UDF) with a +/// specific signature and specific return type. +/// +/// Note this function does not expose all available features of [`ScalarUDF`], +/// such as +/// +/// * computing return types based on input types +/// * multiple [`Signature`]s +/// * aliases +/// +/// See [`ScalarUDF`] for details and examples on how to use the full +/// functionality. pub fn create_udf( name: &str, input_types: Vec, @@ -863,13 +965,76 @@ pub fn create_udf( volatility: Volatility, fun: ScalarFunctionImplementation, ) -> ScalarUDF { - let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone())); - ScalarUDF::new( + let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t| t.as_ref().clone()); + ScalarUDF::from(SimpleScalarUDF::new( name, - &Signature::exact(input_types, volatility), - &return_type, - &fun, - ) + input_types, + return_type, + volatility, + fun, + )) +} + +/// Implements [`ScalarUDFImpl`] for functions that have a single signature and +/// return type. +pub struct SimpleScalarUDF { + name: String, + signature: Signature, + return_type: DataType, + fun: ScalarFunctionImplementation, +} + +impl Debug for SimpleScalarUDF { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("ScalarUDF") + .field("name", &self.name) + .field("signature", &self.signature) + .field("fun", &"") + .finish() + } +} + +impl SimpleScalarUDF { + /// Create a new `SimpleScalarUDF` from a name, input types, return type and + /// implementation. Implementing [`ScalarUDFImpl`] allows more flexibility + pub fn new( + name: impl Into, + input_types: Vec, + return_type: DataType, + volatility: Volatility, + fun: ScalarFunctionImplementation, + ) -> Self { + let name = name.into(); + let signature = Signature::exact(input_types, volatility); + Self { + name, + signature, + return_type, + fun, + } + } +} + +impl ScalarUDFImpl for SimpleScalarUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(self.return_type.clone()) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + (self.fun)(args) + } } /// Creates a new UDAF with a specific signature, state type and return type. @@ -905,13 +1070,77 @@ pub fn create_udwf( volatility: Volatility, partition_evaluator_factory: PartitionEvaluatorFactory, ) -> WindowUDF { - let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone())); - WindowUDF::new( + let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t| t.as_ref().clone()); + WindowUDF::from(SimpleWindowUDF::new( name, - &Signature::exact(vec![input_type], volatility), - &return_type, - &partition_evaluator_factory, - ) + input_type, + return_type, + volatility, + partition_evaluator_factory, + )) +} + +/// Implements [`WindowUDFImpl`] for functions that have a single signature and +/// return type. +pub struct SimpleWindowUDF { + name: String, + signature: Signature, + return_type: DataType, + partition_evaluator_factory: PartitionEvaluatorFactory, +} + +impl Debug for SimpleWindowUDF { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("WindowUDF") + .field("name", &self.name) + .field("signature", &self.signature) + .field("return_type", &"") + .field("partition_evaluator_factory", &"") + .finish() + } +} + +impl SimpleWindowUDF { + /// Create a new `SimpleWindowUDF` from a name, input types, return type and + /// implementation. Implementing [`WindowUDFImpl`] allows more flexibility + pub fn new( + name: impl Into, + input_type: DataType, + return_type: DataType, + volatility: Volatility, + partition_evaluator_factory: PartitionEvaluatorFactory, + ) -> Self { + let name = name.into(); + let signature = Signature::exact([input_type].to_vec(), volatility); + Self { + name, + signature, + return_type, + partition_evaluator_factory, + } + } +} + +impl WindowUDFImpl for SimpleWindowUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(self.return_type.clone()) + } + + fn partition_evaluator(&self) -> Result> { + (self.partition_evaluator_factory)() + } } /// Calls a named built in function @@ -931,7 +1160,7 @@ pub fn call_fn(name: impl AsRef, args: Vec) -> Result { #[cfg(test)] mod test { use super::*; - use crate::lit; + use crate::{lit, ScalarFunctionDefinition}; #[test] fn filter_is_null_and_is_not_null() { @@ -946,8 +1175,10 @@ mod test { macro_rules! test_unary_scalar_expr { ($ENUM:ident, $FUNC:ident) => {{ - if let Expr::ScalarFunction(ScalarFunction { fun, args }) = - $FUNC(col("tableA.a")) + if let Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn(fun), + args, + }) = $FUNC(col("tableA.a")) { let name = built_in_function::BuiltinScalarFunction::$ENUM; assert_eq!(name, fun); @@ -959,42 +1190,42 @@ mod test { } macro_rules! test_scalar_expr { - ($ENUM:ident, $FUNC:ident, $($arg:ident),*) => { - let expected = [$(stringify!($arg)),*]; - let result = $FUNC( + ($ENUM:ident, $FUNC:ident, $($arg:ident),*) => { + let expected = [$(stringify!($arg)),*]; + let result = $FUNC( + $( + col(stringify!($arg.to_string())) + ),* + ); + if let Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn(fun), args }) = result { + let name = built_in_function::BuiltinScalarFunction::$ENUM; + assert_eq!(name, fun); + assert_eq!(expected.len(), args.len()); + } else { + assert!(false, "unexpected: {:?}", result); + } + }; +} + + macro_rules! test_nary_scalar_expr { + ($ENUM:ident, $FUNC:ident, $($arg:ident),*) => { + let expected = [$(stringify!($arg)),*]; + let result = $FUNC( + vec![ $( col(stringify!($arg.to_string())) ),* - ); - if let Expr::ScalarFunction(ScalarFunction { fun, args }) = result { - let name = built_in_function::BuiltinScalarFunction::$ENUM; - assert_eq!(name, fun); - assert_eq!(expected.len(), args.len()); - } else { - assert!(false, "unexpected: {:?}", result); - } - }; - } - - macro_rules! test_nary_scalar_expr { - ($ENUM:ident, $FUNC:ident, $($arg:ident),*) => { - let expected = [$(stringify!($arg)),*]; - let result = $FUNC( - vec![ - $( - col(stringify!($arg.to_string())) - ),* - ] - ); - if let Expr::ScalarFunction(ScalarFunction { fun, args }) = result { - let name = built_in_function::BuiltinScalarFunction::$ENUM; - assert_eq!(name, fun); - assert_eq!(expected.len(), args.len()); - } else { - assert!(false, "unexpected: {:?}", result); - } - }; - } + ] + ); + if let Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn(fun), args }) = result { + let name = built_in_function::BuiltinScalarFunction::$ENUM; + assert_eq!(name, fun); + assert_eq!(expected.len(), args.len()); + } else { + assert!(false, "unexpected: {:?}", result); + } + }; +} #[test] fn scalar_function_definitions() { @@ -1097,6 +1328,8 @@ mod test { test_scalar_expr!(FromUnixtime, from_unixtime, unixtime); test_scalar_expr!(ArrayAppend, array_append, array, element); + test_scalar_expr!(ArraySort, array_sort, array, desc, null_first); + test_scalar_expr!(ArrayPopFront, array_pop_front, array); test_scalar_expr!(ArrayPopBack, array_pop_back, array); test_unary_scalar_expr!(ArrayDims, array_dims); test_scalar_expr!(ArrayLength, array_length, array, dimension); @@ -1116,11 +1349,20 @@ mod test { test_nary_scalar_expr!(MakeArray, array, input); test_unary_scalar_expr!(ArrowTypeof, arrow_typeof); + test_nary_scalar_expr!(OverLay, overlay, string, characters, position, len); + test_nary_scalar_expr!(OverLay, overlay, string, characters, position); + test_scalar_expr!(Levenshtein, levenshtein, string1, string2); + test_scalar_expr!(SubstrIndex, substr_index, string, delimiter, count); + test_scalar_expr!(FindInSet, find_in_set, string, stringlist); } #[test] fn uuid_function_definitions() { - if let Expr::ScalarFunction(ScalarFunction { fun, args }) = uuid() { + if let Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn(fun), + args, + }) = uuid() + { let name = BuiltinScalarFunction::Uuid; assert_eq!(name, fun); assert_eq!(0, args.len()); @@ -1131,8 +1373,10 @@ mod test { #[test] fn digest_function_definitions() { - if let Expr::ScalarFunction(ScalarFunction { fun, args }) = - digest(col("tableA.a"), lit("md5")) + if let Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn(fun), + args, + }) = digest(col("tableA.a"), lit("md5")) { let name = BuiltinScalarFunction::Digest; assert_eq!(name, fun); @@ -1144,8 +1388,10 @@ mod test { #[test] fn encode_function_definitions() { - if let Expr::ScalarFunction(ScalarFunction { fun, args }) = - encode(col("tableA.a"), lit("base64")) + if let Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn(fun), + args, + }) = encode(col("tableA.a"), lit("base64")) { let name = BuiltinScalarFunction::Encode; assert_eq!(name, fun); @@ -1157,8 +1403,10 @@ mod test { #[test] fn decode_function_definitions() { - if let Expr::ScalarFunction(ScalarFunction { fun, args }) = - decode(col("tableA.a"), lit("hex")) + if let Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn(fun), + args, + }) = decode(col("tableA.a"), lit("hex")) { let name = BuiltinScalarFunction::Decode; assert_eq!(name, fun); diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 9651b377c5bd..ba21d09f0619 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -17,18 +17,19 @@ use super::{Between, Expr, Like}; use crate::expr::{ - AggregateFunction, AggregateUDF, Alias, BinaryExpr, Cast, GetFieldAccess, - GetIndexedField, InList, InSubquery, Placeholder, ScalarFunction, ScalarUDF, Sort, - TryCast, WindowFunction, + AggregateFunction, AggregateFunctionDefinition, Alias, BinaryExpr, Cast, + GetFieldAccess, GetIndexedField, InList, InSubquery, Placeholder, ScalarFunction, + ScalarFunctionDefinition, Sort, TryCast, WindowFunction, }; use crate::field_util::GetFieldAccessSchema; use crate::type_coercion::binary::get_result_type; -use crate::{LogicalPlan, Projection, Subquery}; +use crate::type_coercion::functions::data_types; +use crate::{utils, LogicalPlan, Projection, Subquery}; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field}; use datafusion_common::{ - internal_err, plan_err, Column, DFField, DFSchema, DataFusionError, ExprSchema, - Result, + internal_err, plan_datafusion_err, plan_err, Column, DFField, DFSchema, + DataFusionError, ExprSchema, Result, }; use std::collections::HashMap; use std::sync::Arc; @@ -81,20 +82,34 @@ impl ExprSchemable for Expr { Expr::Case(case) => case.when_then_expr[0].1.get_type(schema), Expr::Cast(Cast { data_type, .. }) | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()), - Expr::ScalarUDF(ScalarUDF { fun, args }) => { - let data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - Ok((fun.return_type)(&data_types)?.as_ref().clone()) - } - Expr::ScalarFunction(ScalarFunction { fun, args }) => { - let data_types = args + Expr::ScalarFunction(ScalarFunction { func_def, args }) => { + let arg_data_types = args .iter() .map(|e| e.get_type(schema)) .collect::>>()?; - - fun.return_type(&data_types) + match func_def { + ScalarFunctionDefinition::BuiltIn(fun) => { + // verify that input data types is consistent with function's `TypeSignature` + data_types(&arg_data_types, &fun.signature()).map_err(|_| { + plan_datafusion_err!( + "{}", + utils::generate_signature_error_msg( + &format!("{fun}"), + fun.signature(), + &arg_data_types, + ) + ) + })?; + + fun.return_type(&arg_data_types) + } + ScalarFunctionDefinition::UDF(fun) => { + Ok(fun.return_type(&arg_data_types)?) + } + ScalarFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") + } + } } Expr::WindowFunction(WindowFunction { fun, args, .. }) => { let data_types = args @@ -103,19 +118,22 @@ impl ExprSchemable for Expr { .collect::>>()?; fun.return_type(&data_types) } - Expr::AggregateFunction(AggregateFunction { fun, args, .. }) => { - let data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - fun.return_type(&data_types) - } - Expr::AggregateUDF(AggregateUDF { fun, args, .. }) => { + Expr::AggregateFunction(AggregateFunction { func_def, args, .. }) => { let data_types = args .iter() .map(|e| e.get_type(schema)) .collect::>>()?; - Ok((fun.return_type)(&data_types)?.as_ref().clone()) + match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + fun.return_type(&data_types) + } + AggregateFunctionDefinition::UDF(fun) => { + Ok(fun.return_type(&data_types)?) + } + AggregateFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") + } + } } Expr::Not(_) | Expr::IsNull(_) @@ -141,18 +159,16 @@ impl ExprSchemable for Expr { Expr::Like { .. } | Expr::SimilarTo { .. } => Ok(DataType::Boolean), Expr::Placeholder(Placeholder { data_type, .. }) => { data_type.clone().ok_or_else(|| { - DataFusionError::Plan( - "Placeholder type could not be resolved".to_owned(), - ) + plan_datafusion_err!("Placeholder type could not be resolved") }) } - Expr::Wildcard => { + Expr::Wildcard { qualifier } => { // Wildcard do not really have a type and do not appear in projections - Ok(DataType::Null) + match qualifier { + Some(_) => internal_err!("QualifiedWildcard expressions are not valid in a logical query plan"), + None => Ok(DataType::Null) + } } - Expr::QualifiedWildcard { .. } => internal_err!( - "QualifiedWildcard expressions are not valid in a logical query plan" - ), Expr::GroupingSet(_) => { // grouping sets do not really have a type and do not appear in projections Ok(DataType::Null) @@ -232,10 +248,8 @@ impl ExprSchemable for Expr { Expr::ScalarVariable(_, _) | Expr::TryCast { .. } | Expr::ScalarFunction(..) - | Expr::ScalarUDF(..) | Expr::WindowFunction { .. } | Expr::AggregateFunction { .. } - | Expr::AggregateUDF { .. } | Expr::Placeholder(_) => Ok(true), Expr::IsNull(_) | Expr::IsNotNull(_) @@ -259,13 +273,17 @@ impl ExprSchemable for Expr { | Expr::SimilarTo(Like { expr, pattern, .. }) => { Ok(expr.nullable(input_schema)? || pattern.nullable(input_schema)?) } - Expr::Wildcard => internal_err!( + Expr::Wildcard { .. } => internal_err!( "Wildcard expressions are not valid in a logical query plan" ), - Expr::QualifiedWildcard { .. } => internal_err!( - "QualifiedWildcard expressions are not valid in a logical query plan" - ), Expr::GetIndexedField(GetIndexedField { expr, field }) => { + // If schema is nested, check if parent is nullable + // if it is, return early + if let Expr::Column(col) = expr.as_ref() { + if input_schema.nullable(col)? { + return Ok(true); + } + } field_for_index(expr, field, input_schema).map(|x| x.is_nullable()) } Expr::GroupingSet(_) => { @@ -297,6 +315,13 @@ impl ExprSchemable for Expr { self.nullable(input_schema)?, ) .with_metadata(self.metadata(input_schema)?)), + Expr::Alias(Alias { relation, name, .. }) => Ok(DFField::new( + relation.clone(), + name, + self.get_type(input_schema)?, + self.nullable(input_schema)?, + ) + .with_metadata(self.metadata(input_schema)?)), _ => Ok(DFField::new_unqualified( &self.display_name()?, self.get_type(input_schema)?, @@ -393,8 +418,8 @@ pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result {{ @@ -530,6 +555,27 @@ mod tests { assert_eq!(&meta, expr.to_field(&schema).unwrap().metadata()); } + #[test] + fn test_nested_schema_nullability() { + let fields = DFField::new( + Some(TableReference::Bare { + table: "table_name".into(), + }), + "parent", + DataType::Struct(Fields::from(vec![Field::new( + "child", + DataType::Int64, + false, + )])), + true, + ); + + let schema = DFSchema::new_with_metadata(vec![fields], HashMap::new()).unwrap(); + + let expr = col("parent").field("child"); + assert!(expr.nullable(&schema).unwrap()); + } + #[derive(Debug)] struct MockExprSchema { nullable: bool, diff --git a/datafusion/expr/src/field_util.rs b/datafusion/expr/src/field_util.rs index 23260ea9c270..3829a2086b26 100644 --- a/datafusion/expr/src/field_util.rs +++ b/datafusion/expr/src/field_util.rs @@ -18,7 +18,9 @@ //! Utility functions for complex field access use arrow::datatypes::{DataType, Field}; -use datafusion_common::{plan_err, DataFusionError, Result, ScalarValue}; +use datafusion_common::{ + plan_datafusion_err, plan_err, DataFusionError, Result, ScalarValue, +}; /// Types of the field access expression of a nested type, such as `Field` or `List` pub enum GetFieldAccessSchema { @@ -45,6 +47,19 @@ impl GetFieldAccessSchema { match self { Self::NamedStructField{ name } => { match (data_type, name) { + (DataType::Map(fields, _), _) => { + match fields.data_type() { + DataType::Struct(fields) if fields.len() == 2 => { + // Arrow's MapArray is essentially a ListArray of structs with two columns. They are + // often named "key", and "value", but we don't require any specific naming here; + // instead, we assume that the second columnis the "value" column both here and in + // execution. + let value_field = fields.get(1).expect("fields should have exactly two members"); + Ok(Field::new("map", value_field.data_type().clone(), true)) + }, + _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"), + } + } (DataType::Struct(fields), ScalarValue::Utf8(Some(s))) => { if s.is_empty() { plan_err!( @@ -52,13 +67,13 @@ impl GetFieldAccessSchema { ) } else { let field = fields.iter().find(|f| f.name() == s); - field.ok_or(DataFusionError::Plan(format!("Field {s} not found in struct"))).map(|f| f.as_ref().clone()) + field.ok_or(plan_datafusion_err!("Field {s} not found in struct")).map(|f| f.as_ref().clone()) } } (DataType::Struct(_), _) => plan_err!( "Only utf8 strings are valid as an indexed field in a struct" ), - (other, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"), + (other, _) => plan_err!("The expression to get an indexed field is only valid for `List`, `Struct`, or `Map` types, got {other}"), } } Self::ListIndex{ key_dt } => { diff --git a/datafusion/expr/src/interval_arithmetic.rs b/datafusion/expr/src/interval_arithmetic.rs new file mode 100644 index 000000000000..5d34fe91c3ac --- /dev/null +++ b/datafusion/expr/src/interval_arithmetic.rs @@ -0,0 +1,3307 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Interval arithmetic library + +use std::borrow::Borrow; +use std::fmt::{self, Display, Formatter}; +use std::ops::{AddAssign, SubAssign}; + +use crate::type_coercion::binary::get_result_type; +use crate::Operator; + +use arrow::compute::{cast_with_options, CastOptions}; +use arrow::datatypes::DataType; +use arrow::datatypes::{IntervalUnit, TimeUnit}; +use datafusion_common::rounding::{alter_fp_rounding_mode, next_down, next_up}; +use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; + +macro_rules! get_extreme_value { + ($extreme:ident, $value:expr) => { + match $value { + DataType::UInt8 => ScalarValue::UInt8(Some(u8::$extreme)), + DataType::UInt16 => ScalarValue::UInt16(Some(u16::$extreme)), + DataType::UInt32 => ScalarValue::UInt32(Some(u32::$extreme)), + DataType::UInt64 => ScalarValue::UInt64(Some(u64::$extreme)), + DataType::Int8 => ScalarValue::Int8(Some(i8::$extreme)), + DataType::Int16 => ScalarValue::Int16(Some(i16::$extreme)), + DataType::Int32 => ScalarValue::Int32(Some(i32::$extreme)), + DataType::Int64 => ScalarValue::Int64(Some(i64::$extreme)), + DataType::Float32 => ScalarValue::Float32(Some(f32::$extreme)), + DataType::Float64 => ScalarValue::Float64(Some(f64::$extreme)), + DataType::Duration(TimeUnit::Second) => { + ScalarValue::DurationSecond(Some(i64::$extreme)) + } + DataType::Duration(TimeUnit::Millisecond) => { + ScalarValue::DurationMillisecond(Some(i64::$extreme)) + } + DataType::Duration(TimeUnit::Microsecond) => { + ScalarValue::DurationMicrosecond(Some(i64::$extreme)) + } + DataType::Duration(TimeUnit::Nanosecond) => { + ScalarValue::DurationNanosecond(Some(i64::$extreme)) + } + DataType::Timestamp(TimeUnit::Second, _) => { + ScalarValue::TimestampSecond(Some(i64::$extreme), None) + } + DataType::Timestamp(TimeUnit::Millisecond, _) => { + ScalarValue::TimestampMillisecond(Some(i64::$extreme), None) + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + ScalarValue::TimestampMicrosecond(Some(i64::$extreme), None) + } + DataType::Timestamp(TimeUnit::Nanosecond, _) => { + ScalarValue::TimestampNanosecond(Some(i64::$extreme), None) + } + DataType::Interval(IntervalUnit::YearMonth) => { + ScalarValue::IntervalYearMonth(Some(i32::$extreme)) + } + DataType::Interval(IntervalUnit::DayTime) => { + ScalarValue::IntervalDayTime(Some(i64::$extreme)) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + ScalarValue::IntervalMonthDayNano(Some(i128::$extreme)) + } + _ => unreachable!(), + } + }; +} + +macro_rules! value_transition { + ($bound:ident, $direction:expr, $value:expr) => { + match $value { + UInt8(Some(value)) if value == u8::$bound => UInt8(None), + UInt16(Some(value)) if value == u16::$bound => UInt16(None), + UInt32(Some(value)) if value == u32::$bound => UInt32(None), + UInt64(Some(value)) if value == u64::$bound => UInt64(None), + Int8(Some(value)) if value == i8::$bound => Int8(None), + Int16(Some(value)) if value == i16::$bound => Int16(None), + Int32(Some(value)) if value == i32::$bound => Int32(None), + Int64(Some(value)) if value == i64::$bound => Int64(None), + Float32(Some(value)) if value == f32::$bound => Float32(None), + Float64(Some(value)) if value == f64::$bound => Float64(None), + DurationSecond(Some(value)) if value == i64::$bound => DurationSecond(None), + DurationMillisecond(Some(value)) if value == i64::$bound => { + DurationMillisecond(None) + } + DurationMicrosecond(Some(value)) if value == i64::$bound => { + DurationMicrosecond(None) + } + DurationNanosecond(Some(value)) if value == i64::$bound => { + DurationNanosecond(None) + } + TimestampSecond(Some(value), tz) if value == i64::$bound => { + TimestampSecond(None, tz) + } + TimestampMillisecond(Some(value), tz) if value == i64::$bound => { + TimestampMillisecond(None, tz) + } + TimestampMicrosecond(Some(value), tz) if value == i64::$bound => { + TimestampMicrosecond(None, tz) + } + TimestampNanosecond(Some(value), tz) if value == i64::$bound => { + TimestampNanosecond(None, tz) + } + IntervalYearMonth(Some(value)) if value == i32::$bound => { + IntervalYearMonth(None) + } + IntervalDayTime(Some(value)) if value == i64::$bound => IntervalDayTime(None), + IntervalMonthDayNano(Some(value)) if value == i128::$bound => { + IntervalMonthDayNano(None) + } + _ => next_value_helper::<$direction>($value), + } + }; +} + +/// The `Interval` type represents a closed interval used for computing +/// reliable bounds for mathematical expressions. +/// +/// Conventions: +/// +/// 1. **Closed bounds**: The interval always encompasses its endpoints. We +/// accommodate operations resulting in open intervals by incrementing or +/// decrementing the interval endpoint value to its successor/predecessor. +/// +/// 2. **Unbounded endpoints**: If the `lower` or `upper` bounds are indeterminate, +/// they are labeled as *unbounded*. This is represented using a `NULL`. +/// +/// 3. **Overflow handling**: If the `lower` or `upper` endpoints exceed their +/// limits after any operation, they either become unbounded or they are fixed +/// to the maximum/minimum value of the datatype, depending on the direction +/// of the overflowing endpoint, opting for the safer choice. +/// +/// 4. **Floating-point special cases**: +/// - `INF` values are converted to `NULL`s while constructing an interval to +/// ensure consistency, with other data types. +/// - `NaN` (Not a Number) results are conservatively result in unbounded +/// endpoints. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Interval { + lower: ScalarValue, + upper: ScalarValue, +} + +/// This macro handles the `NaN` and `INF` floating point values. +/// +/// - `NaN` values are always converted to unbounded i.e. `NULL` values. +/// - For lower bounds: +/// - A `NEG_INF` value is converted to a `NULL`. +/// - An `INF` value is conservatively converted to the maximum representable +/// number for the floating-point type in question. In this case, converting +/// to `NULL` doesn't make sense as it would be interpreted as a `NEG_INF`. +/// - For upper bounds: +/// - An `INF` value is converted to a `NULL`. +/// - An `NEG_INF` value is conservatively converted to the minimum representable +/// number for the floating-point type in question. In this case, converting +/// to `NULL` doesn't make sense as it would be interpreted as an `INF`. +macro_rules! handle_float_intervals { + ($scalar_type:ident, $primitive_type:ident, $lower:expr, $upper:expr) => {{ + let lower = match $lower { + ScalarValue::$scalar_type(Some(l_val)) + if l_val == $primitive_type::NEG_INFINITY || l_val.is_nan() => + { + ScalarValue::$scalar_type(None) + } + ScalarValue::$scalar_type(Some(l_val)) + if l_val == $primitive_type::INFINITY => + { + ScalarValue::$scalar_type(Some($primitive_type::MAX)) + } + value @ ScalarValue::$scalar_type(Some(_)) => value, + _ => ScalarValue::$scalar_type(None), + }; + + let upper = match $upper { + ScalarValue::$scalar_type(Some(r_val)) + if r_val == $primitive_type::INFINITY || r_val.is_nan() => + { + ScalarValue::$scalar_type(None) + } + ScalarValue::$scalar_type(Some(r_val)) + if r_val == $primitive_type::NEG_INFINITY => + { + ScalarValue::$scalar_type(Some($primitive_type::MIN)) + } + value @ ScalarValue::$scalar_type(Some(_)) => value, + _ => ScalarValue::$scalar_type(None), + }; + + Interval { lower, upper } + }}; +} + +/// Ordering floating-point numbers according to their binary representations +/// contradicts with their natural ordering. Floating-point number ordering +/// after unsigned integer transmutation looks like: +/// +/// ```text +/// 0, 1, 2, 3, ..., MAX, -0, -1, -2, ..., -MAX +/// ``` +/// +/// This macro applies a one-to-one map that fixes the ordering above. +macro_rules! map_floating_point_order { + ($value:expr, $ty:ty) => {{ + let num_bits = std::mem::size_of::<$ty>() * 8; + let sign_bit = 1 << (num_bits - 1); + if $value & sign_bit == sign_bit { + // Negative numbers: + !$value + } else { + // Positive numbers: + $value | sign_bit + } + }}; +} + +impl Interval { + /// Attempts to create a new `Interval` from the given lower and upper bounds. + /// + /// # Notes + /// + /// This constructor creates intervals in a "canonical" form where: + /// - **Boolean intervals**: + /// - Unboundedness (`NULL`) for boolean endpoints is converted to `false` + /// for lower and `true` for upper bounds. + /// - **Floating-point intervals**: + /// - Floating-point endpoints with `NaN`, `INF`, or `NEG_INF` are converted + /// to `NULL`s. + pub fn try_new(lower: ScalarValue, upper: ScalarValue) -> Result { + if lower.data_type() != upper.data_type() { + return internal_err!("Endpoints of an Interval should have the same type"); + } + + let interval = Self::new(lower, upper); + + if interval.lower.is_null() + || interval.upper.is_null() + || interval.lower <= interval.upper + { + Ok(interval) + } else { + internal_err!( + "Interval's lower bound {} is greater than the upper bound {}", + interval.lower, + interval.upper + ) + } + } + + /// Only for internal usage. Responsible for standardizing booleans and + /// floating-point values, as well as fixing NaNs. It doesn't validate + /// the given bounds for ordering, or verify that they have the same data + /// type. For its user-facing counterpart and more details, see + /// [`Interval::try_new`]. + fn new(lower: ScalarValue, upper: ScalarValue) -> Self { + if let ScalarValue::Boolean(lower_bool) = lower { + let ScalarValue::Boolean(upper_bool) = upper else { + // We are sure that upper and lower bounds have the same type. + unreachable!(); + }; + // Standardize boolean interval endpoints: + Self { + lower: ScalarValue::Boolean(Some(lower_bool.unwrap_or(false))), + upper: ScalarValue::Boolean(Some(upper_bool.unwrap_or(true))), + } + } + // Standardize floating-point endpoints: + else if lower.data_type() == DataType::Float32 { + handle_float_intervals!(Float32, f32, lower, upper) + } else if lower.data_type() == DataType::Float64 { + handle_float_intervals!(Float64, f64, lower, upper) + } else { + // Other data types do not require standardization: + Self { lower, upper } + } + } + + /// Convenience function to create a new `Interval` from the given (optional) + /// bounds, for use in tests only. Absence of either endpoint indicates + /// unboundedness on that side. See [`Interval::try_new`] for more information. + pub fn make(lower: Option, upper: Option) -> Result + where + ScalarValue: From>, + { + Self::try_new(ScalarValue::from(lower), ScalarValue::from(upper)) + } + + /// Creates an unbounded interval from both sides if the datatype supported. + pub fn make_unbounded(data_type: &DataType) -> Result { + let unbounded_endpoint = ScalarValue::try_from(data_type)?; + Ok(Self::new(unbounded_endpoint.clone(), unbounded_endpoint)) + } + + /// Returns a reference to the lower bound. + pub fn lower(&self) -> &ScalarValue { + &self.lower + } + + /// Returns a reference to the upper bound. + pub fn upper(&self) -> &ScalarValue { + &self.upper + } + + /// Converts this `Interval` into its boundary scalar values. It's useful + /// when you need to work with the individual bounds directly. + pub fn into_bounds(self) -> (ScalarValue, ScalarValue) { + (self.lower, self.upper) + } + + /// This function returns the data type of this interval. + pub fn data_type(&self) -> DataType { + let lower_type = self.lower.data_type(); + let upper_type = self.upper.data_type(); + + // There must be no way to create an interval whose endpoints have + // different types. + assert!( + lower_type == upper_type, + "Interval bounds have different types: {lower_type} != {upper_type}" + ); + lower_type + } + + /// Casts this interval to `data_type` using `cast_options`. + pub fn cast_to( + &self, + data_type: &DataType, + cast_options: &CastOptions, + ) -> Result { + Self::try_new( + cast_scalar_value(&self.lower, data_type, cast_options)?, + cast_scalar_value(&self.upper, data_type, cast_options)?, + ) + } + + pub const CERTAINLY_FALSE: Self = Self { + lower: ScalarValue::Boolean(Some(false)), + upper: ScalarValue::Boolean(Some(false)), + }; + + pub const UNCERTAIN: Self = Self { + lower: ScalarValue::Boolean(Some(false)), + upper: ScalarValue::Boolean(Some(true)), + }; + + pub const CERTAINLY_TRUE: Self = Self { + lower: ScalarValue::Boolean(Some(true)), + upper: ScalarValue::Boolean(Some(true)), + }; + + /// Decide if this interval is certainly greater than, possibly greater than, + /// or can't be greater than `other` by returning `[true, true]`, + /// `[false, true]` or `[false, false]` respectively. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + pub(crate) fn gt>(&self, other: T) -> Result { + let rhs = other.borrow(); + if self.data_type().ne(&rhs.data_type()) { + internal_err!( + "Only intervals with the same data type are comparable, lhs:{}, rhs:{}", + self.data_type(), + rhs.data_type() + ) + } else if !(self.upper.is_null() || rhs.lower.is_null()) + && self.upper <= rhs.lower + { + // Values in this interval are certainly less than or equal to + // those in the given interval. + Ok(Self::CERTAINLY_FALSE) + } else if !(self.lower.is_null() || rhs.upper.is_null()) + && (self.lower > rhs.upper) + { + // Values in this interval are certainly greater than those in the + // given interval. + Ok(Self::CERTAINLY_TRUE) + } else { + // All outcomes are possible. + Ok(Self::UNCERTAIN) + } + } + + /// Decide if this interval is certainly greater than or equal to, possibly + /// greater than or equal to, or can't be greater than or equal to `other` + /// by returning `[true, true]`, `[false, true]` or `[false, false]` respectively. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + pub(crate) fn gt_eq>(&self, other: T) -> Result { + let rhs = other.borrow(); + if self.data_type().ne(&rhs.data_type()) { + internal_err!( + "Only intervals with the same data type are comparable, lhs:{}, rhs:{}", + self.data_type(), + rhs.data_type() + ) + } else if !(self.lower.is_null() || rhs.upper.is_null()) + && self.lower >= rhs.upper + { + // Values in this interval are certainly greater than or equal to + // those in the given interval. + Ok(Self::CERTAINLY_TRUE) + } else if !(self.upper.is_null() || rhs.lower.is_null()) + && (self.upper < rhs.lower) + { + // Values in this interval are certainly less than those in the + // given interval. + Ok(Self::CERTAINLY_FALSE) + } else { + // All outcomes are possible. + Ok(Self::UNCERTAIN) + } + } + + /// Decide if this interval is certainly less than, possibly less than, or + /// can't be less than `other` by returning `[true, true]`, `[false, true]` + /// or `[false, false]` respectively. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + pub(crate) fn lt>(&self, other: T) -> Result { + other.borrow().gt(self) + } + + /// Decide if this interval is certainly less than or equal to, possibly + /// less than or equal to, or can't be less than or equal to `other` by + /// returning `[true, true]`, `[false, true]` or `[false, false]` respectively. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + pub(crate) fn lt_eq>(&self, other: T) -> Result { + other.borrow().gt_eq(self) + } + + /// Decide if this interval is certainly equal to, possibly equal to, or + /// can't be equal to `other` by returning `[true, true]`, `[false, true]` + /// or `[false, false]` respectively. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + pub(crate) fn equal>(&self, other: T) -> Result { + let rhs = other.borrow(); + if get_result_type(&self.data_type(), &Operator::Eq, &rhs.data_type()).is_err() { + internal_err!( + "Interval data types must be compatible for equality checks, lhs:{}, rhs:{}", + self.data_type(), + rhs.data_type() + ) + } else if !self.lower.is_null() + && (self.lower == self.upper) + && (rhs.lower == rhs.upper) + && (self.lower == rhs.lower) + { + Ok(Self::CERTAINLY_TRUE) + } else if self.intersect(rhs)?.is_none() { + Ok(Self::CERTAINLY_FALSE) + } else { + Ok(Self::UNCERTAIN) + } + } + + /// Compute the logical conjunction of this (boolean) interval with the + /// given boolean interval. + pub(crate) fn and>(&self, other: T) -> Result { + let rhs = other.borrow(); + match (&self.lower, &self.upper, &rhs.lower, &rhs.upper) { + ( + &ScalarValue::Boolean(Some(self_lower)), + &ScalarValue::Boolean(Some(self_upper)), + &ScalarValue::Boolean(Some(other_lower)), + &ScalarValue::Boolean(Some(other_upper)), + ) => { + let lower = self_lower && other_lower; + let upper = self_upper && other_upper; + + Ok(Self { + lower: ScalarValue::Boolean(Some(lower)), + upper: ScalarValue::Boolean(Some(upper)), + }) + } + _ => internal_err!("Incompatible data types for logical conjunction"), + } + } + + /// Compute the logical negation of this (boolean) interval. + pub(crate) fn not(&self) -> Result { + if self.data_type().ne(&DataType::Boolean) { + internal_err!("Cannot apply logical negation to a non-boolean interval") + } else if self == &Self::CERTAINLY_TRUE { + Ok(Self::CERTAINLY_FALSE) + } else if self == &Self::CERTAINLY_FALSE { + Ok(Self::CERTAINLY_TRUE) + } else { + Ok(Self::UNCERTAIN) + } + } + + /// Compute the intersection of this interval with the given interval. + /// If the intersection is empty, return `None`. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + pub fn intersect>(&self, other: T) -> Result> { + let rhs = other.borrow(); + if self.data_type().ne(&rhs.data_type()) { + return internal_err!( + "Only intervals with the same data type are intersectable, lhs:{}, rhs:{}", + self.data_type(), + rhs.data_type() + ); + }; + + // If it is evident that the result is an empty interval, short-circuit + // and directly return `None`. + if (!(self.lower.is_null() || rhs.upper.is_null()) && self.lower > rhs.upper) + || (!(self.upper.is_null() || rhs.lower.is_null()) && self.upper < rhs.lower) + { + return Ok(None); + } + + let lower = max_of_bounds(&self.lower, &rhs.lower); + let upper = min_of_bounds(&self.upper, &rhs.upper); + + // New lower and upper bounds must always construct a valid interval. + assert!( + (lower.is_null() || upper.is_null() || (lower <= upper)), + "The intersection of two intervals can not be an invalid interval" + ); + + Ok(Some(Self { lower, upper })) + } + + /// Decide if this interval certainly contains, possibly contains, or can't + /// contain a [`ScalarValue`] (`other`) by returning `[true, true]`, + /// `[false, true]` or `[false, false]` respectively. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + pub fn contains_value>(&self, other: T) -> Result { + let rhs = other.borrow(); + if self.data_type().ne(&rhs.data_type()) { + return internal_err!( + "Data types must be compatible for containment checks, lhs:{}, rhs:{}", + self.data_type(), + rhs.data_type() + ); + } + + // We only check the upper bound for a `None` value because `None` + // values are less than `Some` values according to Rust. + Ok(&self.lower <= rhs && (self.upper.is_null() || rhs <= &self.upper)) + } + + /// Decide if this interval is a superset of, overlaps with, or + /// disjoint with `other` by returning `[true, true]`, `[false, true]` or + /// `[false, false]` respectively. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + pub fn contains>(&self, other: T) -> Result { + let rhs = other.borrow(); + if self.data_type().ne(&rhs.data_type()) { + return internal_err!( + "Interval data types must match for containment checks, lhs:{}, rhs:{}", + self.data_type(), + rhs.data_type() + ); + }; + + match self.intersect(rhs)? { + Some(intersection) => { + if &intersection == rhs { + Ok(Self::CERTAINLY_TRUE) + } else { + Ok(Self::UNCERTAIN) + } + } + None => Ok(Self::CERTAINLY_FALSE), + } + } + + /// Add the given interval (`other`) to this interval. Say we have intervals + /// `[a1, b1]` and `[a2, b2]`, then their sum is `[a1 + a2, b1 + b2]`. Note + /// that this represents all possible values the sum can take if one can + /// choose single values arbitrarily from each of the operands. + pub fn add>(&self, other: T) -> Result { + let rhs = other.borrow(); + let dt = get_result_type(&self.data_type(), &Operator::Plus, &rhs.data_type())?; + + Ok(Self::new( + add_bounds::(&dt, &self.lower, &rhs.lower), + add_bounds::(&dt, &self.upper, &rhs.upper), + )) + } + + /// Subtract the given interval (`other`) from this interval. Say we have + /// intervals `[a1, b1]` and `[a2, b2]`, then their difference is + /// `[a1 - b2, b1 - a2]`. Note that this represents all possible values the + /// difference can take if one can choose single values arbitrarily from + /// each of the operands. + pub fn sub>(&self, other: T) -> Result { + let rhs = other.borrow(); + let dt = get_result_type(&self.data_type(), &Operator::Minus, &rhs.data_type())?; + + Ok(Self::new( + sub_bounds::(&dt, &self.lower, &rhs.upper), + sub_bounds::(&dt, &self.upper, &rhs.lower), + )) + } + + /// Multiply the given interval (`other`) with this interval. Say we have + /// intervals `[a1, b1]` and `[a2, b2]`, then their product is `[min(a1 * a2, + /// a1 * b2, b1 * a2, b1 * b2), max(a1 * a2, a1 * b2, b1 * a2, b1 * b2)]`. + /// Note that this represents all possible values the product can take if + /// one can choose single values arbitrarily from each of the operands. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + pub fn mul>(&self, other: T) -> Result { + let rhs = other.borrow(); + let dt = if self.data_type().eq(&rhs.data_type()) { + self.data_type() + } else { + return internal_err!( + "Intervals must have the same data type for multiplication, lhs:{}, rhs:{}", + self.data_type(), + rhs.data_type() + ); + }; + + let zero = ScalarValue::new_zero(&dt)?; + + let result = match ( + self.contains_value(&zero)?, + rhs.contains_value(&zero)?, + dt.is_unsigned_integer(), + ) { + (true, true, false) => mul_helper_multi_zero_inclusive(&dt, self, rhs), + (true, false, false) => { + mul_helper_single_zero_inclusive(&dt, self, rhs, zero) + } + (false, true, false) => { + mul_helper_single_zero_inclusive(&dt, rhs, self, zero) + } + _ => mul_helper_zero_exclusive(&dt, self, rhs, zero), + }; + Ok(result) + } + + /// Divide this interval by the given interval (`other`). Say we have intervals + /// `[a1, b1]` and `[a2, b2]`, then their division is `[a1, b1] * [1 / b2, 1 / a2]` + /// if `0 ∉ [a2, b2]` and `[NEG_INF, INF]` otherwise. Note that this represents + /// all possible values the quotient can take if one can choose single values + /// arbitrarily from each of the operands. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + /// + /// **TODO**: Once interval sets are supported, cases where the divisor contains + /// zero should result in an interval set, not the universal set. + pub fn div>(&self, other: T) -> Result { + let rhs = other.borrow(); + let dt = if self.data_type().eq(&rhs.data_type()) { + self.data_type() + } else { + return internal_err!( + "Intervals must have the same data type for division, lhs:{}, rhs:{}", + self.data_type(), + rhs.data_type() + ); + }; + + let zero = ScalarValue::new_zero(&dt)?; + // We want 0 to be approachable from both negative and positive sides. + let zero_point = match &dt { + DataType::Float32 | DataType::Float64 => Self::new(zero.clone(), zero), + _ => Self::new(prev_value(zero.clone()), next_value(zero)), + }; + + // Exit early with an unbounded interval if zero is strictly inside the + // right hand side: + if rhs.contains(&zero_point)? == Self::CERTAINLY_TRUE && !dt.is_unsigned_integer() + { + Self::make_unbounded(&dt) + } + // At this point, we know that only one endpoint of the right hand side + // can be zero. + else if self.contains(&zero_point)? == Self::CERTAINLY_TRUE + && !dt.is_unsigned_integer() + { + Ok(div_helper_lhs_zero_inclusive(&dt, self, rhs, &zero_point)) + } else { + Ok(div_helper_zero_exclusive(&dt, self, rhs, &zero_point)) + } + } + + /// Returns the cardinality of this interval, which is the number of all + /// distinct points inside it. This function returns `None` if: + /// - The interval is unbounded from either side, or + /// - Cardinality calculations for the datatype in question is not + /// implemented yet, or + /// - An overflow occurs during the calculation: This case can only arise + /// when the calculated cardinality does not fit in an `u64`. + pub fn cardinality(&self) -> Option { + let data_type = self.data_type(); + if data_type.is_integer() { + self.upper.distance(&self.lower).map(|diff| diff as u64) + } else if data_type.is_floating() { + // Negative numbers are sorted in the reverse order. To + // always have a positive difference after the subtraction, + // we perform following transformation: + match (&self.lower, &self.upper) { + // Exploit IEEE 754 ordering properties to calculate the correct + // cardinality in all cases (including subnormals). + ( + ScalarValue::Float32(Some(lower)), + ScalarValue::Float32(Some(upper)), + ) => { + let lower_bits = map_floating_point_order!(lower.to_bits(), u32); + let upper_bits = map_floating_point_order!(upper.to_bits(), u32); + Some((upper_bits - lower_bits) as u64) + } + ( + ScalarValue::Float64(Some(lower)), + ScalarValue::Float64(Some(upper)), + ) => { + let lower_bits = map_floating_point_order!(lower.to_bits(), u64); + let upper_bits = map_floating_point_order!(upper.to_bits(), u64); + let count = upper_bits - lower_bits; + (count != u64::MAX).then_some(count) + } + _ => None, + } + } else { + // Cardinality calculations are not implemented for this data type yet: + None + } + .map(|result| result + 1) + } +} + +impl Display for Interval { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "[{}, {}]", self.lower, self.upper) + } +} + +/// Applies the given binary operator the `lhs` and `rhs` arguments. +pub fn apply_operator(op: &Operator, lhs: &Interval, rhs: &Interval) -> Result { + match *op { + Operator::Eq => lhs.equal(rhs), + Operator::NotEq => lhs.equal(rhs)?.not(), + Operator::Gt => lhs.gt(rhs), + Operator::GtEq => lhs.gt_eq(rhs), + Operator::Lt => lhs.lt(rhs), + Operator::LtEq => lhs.lt_eq(rhs), + Operator::And => lhs.and(rhs), + Operator::Plus => lhs.add(rhs), + Operator::Minus => lhs.sub(rhs), + Operator::Multiply => lhs.mul(rhs), + Operator::Divide => lhs.div(rhs), + _ => internal_err!("Interval arithmetic does not support the operator {op}"), + } +} + +/// Helper function used for adding the end-point values of intervals. +/// +/// **Caution:** This function contains multiple calls to `unwrap()`, and may +/// return non-standardized interval bounds. Therefore, it should be used +/// with caution. Currently, it is used in contexts where the `DataType` +/// (`dt`) is validated prior to calling this function, and the following +/// interval creation is standardized with `Interval::new`. +fn add_bounds( + dt: &DataType, + lhs: &ScalarValue, + rhs: &ScalarValue, +) -> ScalarValue { + if lhs.is_null() || rhs.is_null() { + return ScalarValue::try_from(dt).unwrap(); + } + + match dt { + DataType::Float64 | DataType::Float32 => { + alter_fp_rounding_mode::(lhs, rhs, |lhs, rhs| lhs.add_checked(rhs)) + } + _ => lhs.add_checked(rhs), + } + .unwrap_or_else(|_| handle_overflow::(dt, Operator::Plus, lhs, rhs)) +} + +/// Helper function used for subtracting the end-point values of intervals. +/// +/// **Caution:** This function contains multiple calls to `unwrap()`, and may +/// return non-standardized interval bounds. Therefore, it should be used +/// with caution. Currently, it is used in contexts where the `DataType` +/// (`dt`) is validated prior to calling this function, and the following +/// interval creation is standardized with `Interval::new`. +fn sub_bounds( + dt: &DataType, + lhs: &ScalarValue, + rhs: &ScalarValue, +) -> ScalarValue { + if lhs.is_null() || rhs.is_null() { + return ScalarValue::try_from(dt).unwrap(); + } + + match dt { + DataType::Float64 | DataType::Float32 => { + alter_fp_rounding_mode::(lhs, rhs, |lhs, rhs| lhs.sub_checked(rhs)) + } + _ => lhs.sub_checked(rhs), + } + .unwrap_or_else(|_| handle_overflow::(dt, Operator::Minus, lhs, rhs)) +} + +/// Helper function used for multiplying the end-point values of intervals. +/// +/// **Caution:** This function contains multiple calls to `unwrap()`, and may +/// return non-standardized interval bounds. Therefore, it should be used +/// with caution. Currently, it is used in contexts where the `DataType` +/// (`dt`) is validated prior to calling this function, and the following +/// interval creation is standardized with `Interval::new`. +fn mul_bounds( + dt: &DataType, + lhs: &ScalarValue, + rhs: &ScalarValue, +) -> ScalarValue { + if lhs.is_null() || rhs.is_null() { + return ScalarValue::try_from(dt).unwrap(); + } + + match dt { + DataType::Float64 | DataType::Float32 => { + alter_fp_rounding_mode::(lhs, rhs, |lhs, rhs| lhs.mul_checked(rhs)) + } + _ => lhs.mul_checked(rhs), + } + .unwrap_or_else(|_| handle_overflow::(dt, Operator::Multiply, lhs, rhs)) +} + +/// Helper function used for dividing the end-point values of intervals. +/// +/// **Caution:** This function contains multiple calls to `unwrap()`, and may +/// return non-standardized interval bounds. Therefore, it should be used +/// with caution. Currently, it is used in contexts where the `DataType` +/// (`dt`) is validated prior to calling this function, and the following +/// interval creation is standardized with `Interval::new`. +fn div_bounds( + dt: &DataType, + lhs: &ScalarValue, + rhs: &ScalarValue, +) -> ScalarValue { + let zero = ScalarValue::new_zero(dt).unwrap(); + + if (lhs.is_null() || rhs.eq(&zero)) || (dt.is_unsigned_integer() && rhs.is_null()) { + return ScalarValue::try_from(dt).unwrap(); + } else if rhs.is_null() { + return zero; + } + + match dt { + DataType::Float64 | DataType::Float32 => { + alter_fp_rounding_mode::(lhs, rhs, |lhs, rhs| lhs.div(rhs)) + } + _ => lhs.div(rhs), + } + .unwrap_or_else(|_| handle_overflow::(dt, Operator::Divide, lhs, rhs)) +} + +/// This function handles cases where an operation results in an overflow. Such +/// results are converted to an *unbounded endpoint* if: +/// - We are calculating an upper bound and we have a positive overflow. +/// - We are calculating a lower bound and we have a negative overflow. +/// Otherwise; the function sets the endpoint as: +/// - The minimum representable number with the given datatype (`dt`) if +/// we are calculating an upper bound and we have a negative overflow. +/// - The maximum representable number with the given datatype (`dt`) if +/// we are calculating a lower bound and we have a positive overflow. +/// +/// **Caution:** This function contains multiple calls to `unwrap()`, and may +/// return non-standardized interval bounds. Therefore, it should be used +/// with caution. Currently, it is used in contexts where the `DataType` +/// (`dt`) is validated prior to calling this function, `op` is supported by +/// interval library, and the following interval creation is standardized with +/// `Interval::new`. +fn handle_overflow( + dt: &DataType, + op: Operator, + lhs: &ScalarValue, + rhs: &ScalarValue, +) -> ScalarValue { + let zero = ScalarValue::new_zero(dt).unwrap(); + let positive_sign = match op { + Operator::Multiply | Operator::Divide => { + lhs.lt(&zero) && rhs.lt(&zero) || lhs.gt(&zero) && rhs.gt(&zero) + } + Operator::Plus => lhs.ge(&zero), + Operator::Minus => lhs.ge(rhs), + _ => { + unreachable!() + } + }; + match (UPPER, positive_sign) { + (true, true) | (false, false) => ScalarValue::try_from(dt).unwrap(), + (true, false) => { + get_extreme_value!(MIN, dt) + } + (false, true) => { + get_extreme_value!(MAX, dt) + } + } +} + +// This function should remain private since it may corrupt the an interval if +// used without caution. +fn next_value(value: ScalarValue) -> ScalarValue { + use ScalarValue::*; + value_transition!(MAX, true, value) +} + +// This function should remain private since it may corrupt the an interval if +// used without caution. +fn prev_value(value: ScalarValue) -> ScalarValue { + use ScalarValue::*; + value_transition!(MIN, false, value) +} + +trait OneTrait: Sized + std::ops::Add + std::ops::Sub { + fn one() -> Self; +} +macro_rules! impl_OneTrait{ + ($($m:ty),*) => {$( impl OneTrait for $m { fn one() -> Self { 1 as $m } })*} +} +impl_OneTrait! {u8, u16, u32, u64, i8, i16, i32, i64, i128} + +/// This function either increments or decrements its argument, depending on +/// the `INC` value (where a `true` value corresponds to the increment). +fn increment_decrement( + mut value: T, +) -> T { + if INC { + value.add_assign(T::one()); + } else { + value.sub_assign(T::one()); + } + value +} + +/// This function returns the next/previous value depending on the `INC` value. +/// If `true`, it returns the next value; otherwise it returns the previous value. +fn next_value_helper(value: ScalarValue) -> ScalarValue { + use ScalarValue::*; + match value { + // f32/f64::NEG_INF/INF and f32/f64::NaN values should not emerge at this point. + Float32(Some(val)) => { + assert!(val.is_finite(), "Non-standardized floating point usage"); + Float32(Some(if INC { next_up(val) } else { next_down(val) })) + } + Float64(Some(val)) => { + assert!(val.is_finite(), "Non-standardized floating point usage"); + Float64(Some(if INC { next_up(val) } else { next_down(val) })) + } + Int8(Some(val)) => Int8(Some(increment_decrement::(val))), + Int16(Some(val)) => Int16(Some(increment_decrement::(val))), + Int32(Some(val)) => Int32(Some(increment_decrement::(val))), + Int64(Some(val)) => Int64(Some(increment_decrement::(val))), + UInt8(Some(val)) => UInt8(Some(increment_decrement::(val))), + UInt16(Some(val)) => UInt16(Some(increment_decrement::(val))), + UInt32(Some(val)) => UInt32(Some(increment_decrement::(val))), + UInt64(Some(val)) => UInt64(Some(increment_decrement::(val))), + DurationSecond(Some(val)) => { + DurationSecond(Some(increment_decrement::(val))) + } + DurationMillisecond(Some(val)) => { + DurationMillisecond(Some(increment_decrement::(val))) + } + DurationMicrosecond(Some(val)) => { + DurationMicrosecond(Some(increment_decrement::(val))) + } + DurationNanosecond(Some(val)) => { + DurationNanosecond(Some(increment_decrement::(val))) + } + TimestampSecond(Some(val), tz) => { + TimestampSecond(Some(increment_decrement::(val)), tz) + } + TimestampMillisecond(Some(val), tz) => { + TimestampMillisecond(Some(increment_decrement::(val)), tz) + } + TimestampMicrosecond(Some(val), tz) => { + TimestampMicrosecond(Some(increment_decrement::(val)), tz) + } + TimestampNanosecond(Some(val), tz) => { + TimestampNanosecond(Some(increment_decrement::(val)), tz) + } + IntervalYearMonth(Some(val)) => { + IntervalYearMonth(Some(increment_decrement::(val))) + } + IntervalDayTime(Some(val)) => { + IntervalDayTime(Some(increment_decrement::(val))) + } + IntervalMonthDayNano(Some(val)) => { + IntervalMonthDayNano(Some(increment_decrement::(val))) + } + _ => value, // Unbounded values return without change. + } +} + +/// Returns the greater of the given interval bounds. Assumes that a `NULL` +/// value represents `NEG_INF`. +fn max_of_bounds(first: &ScalarValue, second: &ScalarValue) -> ScalarValue { + if !first.is_null() && (second.is_null() || first >= second) { + first.clone() + } else { + second.clone() + } +} + +/// Returns the lesser of the given interval bounds. Assumes that a `NULL` +/// value represents `INF`. +fn min_of_bounds(first: &ScalarValue, second: &ScalarValue) -> ScalarValue { + if !first.is_null() && (second.is_null() || first <= second) { + first.clone() + } else { + second.clone() + } +} + +/// This function updates the given intervals by enforcing (i.e. propagating) +/// the inequality `left > right` (or the `left >= right` inequality, if `strict` +/// is `true`). +/// +/// Returns a `Result` wrapping an `Option` containing the tuple of resulting +/// intervals. If the comparison is infeasible, returns `None`. +/// +/// Example usage: +/// ``` +/// use datafusion_common::DataFusionError; +/// use datafusion_expr::interval_arithmetic::{satisfy_greater, Interval}; +/// +/// let left = Interval::make(Some(-1000.0_f32), Some(1000.0_f32))?; +/// let right = Interval::make(Some(500.0_f32), Some(2000.0_f32))?; +/// let strict = false; +/// assert_eq!( +/// satisfy_greater(&left, &right, strict)?, +/// Some(( +/// Interval::make(Some(500.0_f32), Some(1000.0_f32))?, +/// Interval::make(Some(500.0_f32), Some(1000.0_f32))? +/// )) +/// ); +/// Ok::<(), DataFusionError>(()) +/// ``` +/// +/// NOTE: This function only works with intervals of the same data type. +/// Attempting to compare intervals of different data types will lead +/// to an error. +pub fn satisfy_greater( + left: &Interval, + right: &Interval, + strict: bool, +) -> Result> { + if left.data_type().ne(&right.data_type()) { + return internal_err!( + "Intervals must have the same data type, lhs:{}, rhs:{}", + left.data_type(), + right.data_type() + ); + } + + if !left.upper.is_null() && left.upper <= right.lower { + if !strict && left.upper == right.lower { + // Singleton intervals: + return Ok(Some(( + Interval::new(left.upper.clone(), left.upper.clone()), + Interval::new(left.upper.clone(), left.upper.clone()), + ))); + } else { + // Left-hand side: <--======----0------------> + // Right-hand side: <------------0--======----> + // No intersection, infeasible to propagate: + return Ok(None); + } + } + + // Only the lower bound of left hand side and the upper bound of the right + // hand side can change after propagating the greater-than operation. + let new_left_lower = if left.lower.is_null() || left.lower <= right.lower { + if strict { + next_value(right.lower.clone()) + } else { + right.lower.clone() + } + } else { + left.lower.clone() + }; + // Below code is asymmetric relative to the above if statement, because + // `None` compares less than `Some` in Rust. + let new_right_upper = if right.upper.is_null() + || (!left.upper.is_null() && left.upper <= right.upper) + { + if strict { + prev_value(left.upper.clone()) + } else { + left.upper.clone() + } + } else { + right.upper.clone() + }; + + Ok(Some(( + Interval::new(new_left_lower, left.upper.clone()), + Interval::new(right.lower.clone(), new_right_upper), + ))) +} + +/// Multiplies two intervals that both contain zero. +/// +/// This function takes in two intervals (`lhs` and `rhs`) as arguments and +/// returns their product (whose data type is known to be `dt`). It is +/// specifically designed to handle intervals that contain zero within their +/// ranges. Returns an error if the multiplication of bounds fails. +/// +/// ```text +/// Left-hand side: <-------=====0=====-------> +/// Right-hand side: <-------=====0=====-------> +/// ``` +/// +/// **Caution:** This function contains multiple calls to `unwrap()`. Therefore, +/// it should be used with caution. Currently, it is used in contexts where the +/// `DataType` (`dt`) is validated prior to calling this function. +fn mul_helper_multi_zero_inclusive( + dt: &DataType, + lhs: &Interval, + rhs: &Interval, +) -> Interval { + if lhs.lower.is_null() + || lhs.upper.is_null() + || rhs.lower.is_null() + || rhs.upper.is_null() + { + return Interval::make_unbounded(dt).unwrap(); + } + // Since unbounded cases are handled above, we can safely + // use the utility functions here to eliminate code duplication. + let lower = min_of_bounds( + &mul_bounds::(dt, &lhs.lower, &rhs.upper), + &mul_bounds::(dt, &rhs.lower, &lhs.upper), + ); + let upper = max_of_bounds( + &mul_bounds::(dt, &lhs.upper, &rhs.upper), + &mul_bounds::(dt, &lhs.lower, &rhs.lower), + ); + // There is no possibility to create an invalid interval. + Interval::new(lower, upper) +} + +/// Multiplies two intervals when only left-hand side interval contains zero. +/// +/// This function takes in two intervals (`lhs` and `rhs`) as arguments and +/// returns their product (whose data type is known to be `dt`). This function +/// serves as a subroutine that handles the specific case when only `lhs` contains +/// zero within its range. The interval not containing zero, i.e. rhs, can lie +/// on either side of zero. Returns an error if the multiplication of bounds fails. +/// +/// ``` text +/// Left-hand side: <-------=====0=====-------> +/// Right-hand side: <--======----0------------> +/// +/// or +/// +/// Left-hand side: <-------=====0=====-------> +/// Right-hand side: <------------0--======----> +/// ``` +/// +/// **Caution:** This function contains multiple calls to `unwrap()`. Therefore, +/// it should be used with caution. Currently, it is used in contexts where the +/// `DataType` (`dt`) is validated prior to calling this function. +fn mul_helper_single_zero_inclusive( + dt: &DataType, + lhs: &Interval, + rhs: &Interval, + zero: ScalarValue, +) -> Interval { + // With the following interval bounds, there is no possibility to create an invalid interval. + if rhs.upper <= zero && !rhs.upper.is_null() { + // <-------=====0=====-------> + // <--======----0------------> + let lower = mul_bounds::(dt, &lhs.upper, &rhs.lower); + let upper = mul_bounds::(dt, &lhs.lower, &rhs.lower); + Interval::new(lower, upper) + } else { + // <-------=====0=====-------> + // <------------0--======----> + let lower = mul_bounds::(dt, &lhs.lower, &rhs.upper); + let upper = mul_bounds::(dt, &lhs.upper, &rhs.upper); + Interval::new(lower, upper) + } +} + +/// Multiplies two intervals when neither of them contains zero. +/// +/// This function takes in two intervals (`lhs` and `rhs`) as arguments and +/// returns their product (whose data type is known to be `dt`). It is +/// specifically designed to handle intervals that do not contain zero within +/// their ranges. Returns an error if the multiplication of bounds fails. +/// +/// ``` text +/// Left-hand side: <--======----0------------> +/// Right-hand side: <--======----0------------> +/// +/// or +/// +/// Left-hand side: <--======----0------------> +/// Right-hand side: <------------0--======----> +/// +/// or +/// +/// Left-hand side: <------------0--======----> +/// Right-hand side: <--======----0------------> +/// +/// or +/// +/// Left-hand side: <------------0--======----> +/// Right-hand side: <------------0--======----> +/// ``` +/// +/// **Caution:** This function contains multiple calls to `unwrap()`. Therefore, +/// it should be used with caution. Currently, it is used in contexts where the +/// `DataType` (`dt`) is validated prior to calling this function. +fn mul_helper_zero_exclusive( + dt: &DataType, + lhs: &Interval, + rhs: &Interval, + zero: ScalarValue, +) -> Interval { + let (lower, upper) = match ( + lhs.upper <= zero && !lhs.upper.is_null(), + rhs.upper <= zero && !rhs.upper.is_null(), + ) { + // With the following interval bounds, there is no possibility to create an invalid interval. + (true, true) => ( + // <--======----0------------> + // <--======----0------------> + mul_bounds::(dt, &lhs.upper, &rhs.upper), + mul_bounds::(dt, &lhs.lower, &rhs.lower), + ), + (true, false) => ( + // <--======----0------------> + // <------------0--======----> + mul_bounds::(dt, &lhs.lower, &rhs.upper), + mul_bounds::(dt, &lhs.upper, &rhs.lower), + ), + (false, true) => ( + // <------------0--======----> + // <--======----0------------> + mul_bounds::(dt, &rhs.lower, &lhs.upper), + mul_bounds::(dt, &rhs.upper, &lhs.lower), + ), + (false, false) => ( + // <------------0--======----> + // <------------0--======----> + mul_bounds::(dt, &lhs.lower, &rhs.lower), + mul_bounds::(dt, &lhs.upper, &rhs.upper), + ), + }; + Interval::new(lower, upper) +} + +/// Divides the left-hand side interval by the right-hand side interval when +/// the former contains zero. +/// +/// This function takes in two intervals (`lhs` and `rhs`) as arguments and +/// returns their quotient (whose data type is known to be `dt`). This function +/// serves as a subroutine that handles the specific case when only `lhs` contains +/// zero within its range. Returns an error if the division of bounds fails. +/// +/// ``` text +/// Left-hand side: <-------=====0=====-------> +/// Right-hand side: <--======----0------------> +/// +/// or +/// +/// Left-hand side: <-------=====0=====-------> +/// Right-hand side: <------------0--======----> +/// ``` +/// +/// **Caution:** This function contains multiple calls to `unwrap()`. Therefore, +/// it should be used with caution. Currently, it is used in contexts where the +/// `DataType` (`dt`) is validated prior to calling this function. +fn div_helper_lhs_zero_inclusive( + dt: &DataType, + lhs: &Interval, + rhs: &Interval, + zero_point: &Interval, +) -> Interval { + // With the following interval bounds, there is no possibility to create an invalid interval. + if rhs.upper <= zero_point.lower && !rhs.upper.is_null() { + // <-------=====0=====-------> + // <--======----0------------> + let lower = div_bounds::(dt, &lhs.upper, &rhs.upper); + let upper = div_bounds::(dt, &lhs.lower, &rhs.upper); + Interval::new(lower, upper) + } else { + // <-------=====0=====-------> + // <------------0--======----> + let lower = div_bounds::(dt, &lhs.lower, &rhs.lower); + let upper = div_bounds::(dt, &lhs.upper, &rhs.lower); + Interval::new(lower, upper) + } +} + +/// Divides the left-hand side interval by the right-hand side interval when +/// neither interval contains zero. +/// +/// This function takes in two intervals (`lhs` and `rhs`) as arguments and +/// returns their quotient (whose data type is known to be `dt`). It is +/// specifically designed to handle intervals that do not contain zero within +/// their ranges. Returns an error if the division of bounds fails. +/// +/// ``` text +/// Left-hand side: <--======----0------------> +/// Right-hand side: <--======----0------------> +/// +/// or +/// +/// Left-hand side: <--======----0------------> +/// Right-hand side: <------------0--======----> +/// +/// or +/// +/// Left-hand side: <------------0--======----> +/// Right-hand side: <--======----0------------> +/// +/// or +/// +/// Left-hand side: <------------0--======----> +/// Right-hand side: <------------0--======----> +/// ``` +/// +/// **Caution:** This function contains multiple calls to `unwrap()`. Therefore, +/// it should be used with caution. Currently, it is used in contexts where the +/// `DataType` (`dt`) is validated prior to calling this function. +fn div_helper_zero_exclusive( + dt: &DataType, + lhs: &Interval, + rhs: &Interval, + zero_point: &Interval, +) -> Interval { + let (lower, upper) = match ( + lhs.upper <= zero_point.lower && !lhs.upper.is_null(), + rhs.upper <= zero_point.lower && !rhs.upper.is_null(), + ) { + // With the following interval bounds, there is no possibility to create an invalid interval. + (true, true) => ( + // <--======----0------------> + // <--======----0------------> + div_bounds::(dt, &lhs.upper, &rhs.lower), + div_bounds::(dt, &lhs.lower, &rhs.upper), + ), + (true, false) => ( + // <--======----0------------> + // <------------0--======----> + div_bounds::(dt, &lhs.lower, &rhs.lower), + div_bounds::(dt, &lhs.upper, &rhs.upper), + ), + (false, true) => ( + // <------------0--======----> + // <--======----0------------> + div_bounds::(dt, &lhs.upper, &rhs.upper), + div_bounds::(dt, &lhs.lower, &rhs.lower), + ), + (false, false) => ( + // <------------0--======----> + // <------------0--======----> + div_bounds::(dt, &lhs.lower, &rhs.upper), + div_bounds::(dt, &lhs.upper, &rhs.lower), + ), + }; + Interval::new(lower, upper) +} + +/// This function computes the selectivity of an operation by computing the +/// cardinality ratio of the given input/output intervals. If this can not be +/// calculated for some reason, it returns `1.0` meaning fully selective (no +/// filtering). +pub fn cardinality_ratio(initial_interval: &Interval, final_interval: &Interval) -> f64 { + match (final_interval.cardinality(), initial_interval.cardinality()) { + (Some(final_interval), Some(initial_interval)) => { + (final_interval as f64) / (initial_interval as f64) + } + _ => 1.0, + } +} + +/// Cast scalar value to the given data type using an arrow kernel. +fn cast_scalar_value( + value: &ScalarValue, + data_type: &DataType, + cast_options: &CastOptions, +) -> Result { + let cast_array = cast_with_options(&value.to_array()?, data_type, cast_options)?; + ScalarValue::try_from_array(&cast_array, 0) +} + +/// An [Interval] that also tracks null status using a boolean interval. +/// +/// This represents values that may be in a particular range or be null. +/// +/// # Examples +/// +/// ``` +/// use arrow::datatypes::DataType; +/// use datafusion_common::ScalarValue; +/// use datafusion_expr::interval_arithmetic::Interval; +/// use datafusion_expr::interval_arithmetic::NullableInterval; +/// +/// // [1, 2) U {NULL} +/// let maybe_null = NullableInterval::MaybeNull { +/// values: Interval::try_new( +/// ScalarValue::Int32(Some(1)), +/// ScalarValue::Int32(Some(2)), +/// ).unwrap(), +/// }; +/// +/// // (0, ∞) +/// let not_null = NullableInterval::NotNull { +/// values: Interval::try_new( +/// ScalarValue::Int32(Some(0)), +/// ScalarValue::Int32(None), +/// ).unwrap(), +/// }; +/// +/// // {NULL} +/// let null_interval = NullableInterval::Null { datatype: DataType::Int32 }; +/// +/// // {4} +/// let single_value = NullableInterval::from(ScalarValue::Int32(Some(4))); +/// ``` +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum NullableInterval { + /// The value is always null. This is typed so it can be used in physical + /// expressions, which don't do type coercion. + Null { datatype: DataType }, + /// The value may or may not be null. If it is non-null, its is within the + /// specified range. + MaybeNull { values: Interval }, + /// The value is definitely not null, and is within the specified range. + NotNull { values: Interval }, +} + +impl Display for NullableInterval { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::Null { .. } => write!(f, "NullableInterval: {{NULL}}"), + Self::MaybeNull { values } => { + write!(f, "NullableInterval: {} U {{NULL}}", values) + } + Self::NotNull { values } => write!(f, "NullableInterval: {}", values), + } + } +} + +impl From for NullableInterval { + /// Create an interval that represents a single value. + fn from(value: ScalarValue) -> Self { + if value.is_null() { + Self::Null { + datatype: value.data_type(), + } + } else { + Self::NotNull { + values: Interval { + lower: value.clone(), + upper: value, + }, + } + } + } +} + +impl NullableInterval { + /// Get the values interval, or None if this interval is definitely null. + pub fn values(&self) -> Option<&Interval> { + match self { + Self::Null { .. } => None, + Self::MaybeNull { values } | Self::NotNull { values } => Some(values), + } + } + + /// Get the data type + pub fn data_type(&self) -> DataType { + match self { + Self::Null { datatype } => datatype.clone(), + Self::MaybeNull { values } | Self::NotNull { values } => values.data_type(), + } + } + + /// Return true if the value is definitely true (and not null). + pub fn is_certainly_true(&self) -> bool { + match self { + Self::Null { .. } | Self::MaybeNull { .. } => false, + Self::NotNull { values } => values == &Interval::CERTAINLY_TRUE, + } + } + + /// Return true if the value is definitely false (and not null). + pub fn is_certainly_false(&self) -> bool { + match self { + Self::Null { .. } => false, + Self::MaybeNull { .. } => false, + Self::NotNull { values } => values == &Interval::CERTAINLY_FALSE, + } + } + + /// Perform logical negation on a boolean nullable interval. + fn not(&self) -> Result { + match self { + Self::Null { datatype } => Ok(Self::Null { + datatype: datatype.clone(), + }), + Self::MaybeNull { values } => Ok(Self::MaybeNull { + values: values.not()?, + }), + Self::NotNull { values } => Ok(Self::NotNull { + values: values.not()?, + }), + } + } + + /// Apply the given operator to this interval and the given interval. + /// + /// # Examples + /// + /// ``` + /// use datafusion_common::ScalarValue; + /// use datafusion_expr::Operator; + /// use datafusion_expr::interval_arithmetic::Interval; + /// use datafusion_expr::interval_arithmetic::NullableInterval; + /// + /// // 4 > 3 -> true + /// let lhs = NullableInterval::from(ScalarValue::Int32(Some(4))); + /// let rhs = NullableInterval::from(ScalarValue::Int32(Some(3))); + /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap(); + /// assert_eq!(result, NullableInterval::from(ScalarValue::Boolean(Some(true)))); + /// + /// // [1, 3) > NULL -> NULL + /// let lhs = NullableInterval::NotNull { + /// values: Interval::try_new( + /// ScalarValue::Int32(Some(1)), + /// ScalarValue::Int32(Some(3)), + /// ).unwrap(), + /// }; + /// let rhs = NullableInterval::from(ScalarValue::Int32(None)); + /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap(); + /// assert_eq!(result.single_value(), Some(ScalarValue::Boolean(None))); + /// + /// // [1, 3] > [2, 4] -> [false, true] + /// let lhs = NullableInterval::NotNull { + /// values: Interval::try_new( + /// ScalarValue::Int32(Some(1)), + /// ScalarValue::Int32(Some(3)), + /// ).unwrap(), + /// }; + /// let rhs = NullableInterval::NotNull { + /// values: Interval::try_new( + /// ScalarValue::Int32(Some(2)), + /// ScalarValue::Int32(Some(4)), + /// ).unwrap(), + /// }; + /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap(); + /// // Both inputs are valid (non-null), so result must be non-null + /// assert_eq!(result, NullableInterval::NotNull { + /// // Uncertain whether inequality is true or false + /// values: Interval::UNCERTAIN, + /// }); + /// ``` + pub fn apply_operator(&self, op: &Operator, rhs: &Self) -> Result { + match op { + Operator::IsDistinctFrom => { + let values = match (self, rhs) { + // NULL is distinct from NULL -> False + (Self::Null { .. }, Self::Null { .. }) => Interval::CERTAINLY_FALSE, + // x is distinct from y -> x != y, + // if at least one of them is never null. + (Self::NotNull { .. }, _) | (_, Self::NotNull { .. }) => { + let lhs_values = self.values(); + let rhs_values = rhs.values(); + match (lhs_values, rhs_values) { + (Some(lhs_values), Some(rhs_values)) => { + lhs_values.equal(rhs_values)?.not()? + } + (Some(_), None) | (None, Some(_)) => Interval::CERTAINLY_TRUE, + (None, None) => unreachable!("Null case handled above"), + } + } + _ => Interval::UNCERTAIN, + }; + // IsDistinctFrom never returns null. + Ok(Self::NotNull { values }) + } + Operator::IsNotDistinctFrom => self + .apply_operator(&Operator::IsDistinctFrom, rhs) + .map(|i| i.not())?, + _ => { + if let (Some(left_values), Some(right_values)) = + (self.values(), rhs.values()) + { + let values = apply_operator(op, left_values, right_values)?; + match (self, rhs) { + (Self::NotNull { .. }, Self::NotNull { .. }) => { + Ok(Self::NotNull { values }) + } + _ => Ok(Self::MaybeNull { values }), + } + } else if op.is_comparison_operator() { + Ok(Self::Null { + datatype: DataType::Boolean, + }) + } else { + Ok(Self::Null { + datatype: self.data_type(), + }) + } + } + } + } + + /// Decide if this interval is a superset of, overlaps with, or + /// disjoint with `other` by returning `[true, true]`, `[false, true]` or + /// `[false, false]` respectively. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + pub fn contains>(&self, other: T) -> Result { + let rhs = other.borrow(); + if let (Some(left_values), Some(right_values)) = (self.values(), rhs.values()) { + left_values + .contains(right_values) + .map(|values| match (self, rhs) { + (Self::NotNull { .. }, Self::NotNull { .. }) => { + Self::NotNull { values } + } + _ => Self::MaybeNull { values }, + }) + } else { + Ok(Self::Null { + datatype: DataType::Boolean, + }) + } + } + + /// If the interval has collapsed to a single value, return that value. + /// Otherwise, returns `None`. + /// + /// # Examples + /// + /// ``` + /// use datafusion_common::ScalarValue; + /// use datafusion_expr::interval_arithmetic::Interval; + /// use datafusion_expr::interval_arithmetic::NullableInterval; + /// + /// let interval = NullableInterval::from(ScalarValue::Int32(Some(4))); + /// assert_eq!(interval.single_value(), Some(ScalarValue::Int32(Some(4)))); + /// + /// let interval = NullableInterval::from(ScalarValue::Int32(None)); + /// assert_eq!(interval.single_value(), Some(ScalarValue::Int32(None))); + /// + /// let interval = NullableInterval::MaybeNull { + /// values: Interval::try_new( + /// ScalarValue::Int32(Some(1)), + /// ScalarValue::Int32(Some(4)), + /// ).unwrap(), + /// }; + /// assert_eq!(interval.single_value(), None); + /// ``` + pub fn single_value(&self) -> Option { + match self { + Self::Null { datatype } => { + Some(ScalarValue::try_from(datatype).unwrap_or(ScalarValue::Null)) + } + Self::MaybeNull { values } | Self::NotNull { values } + if values.lower == values.upper && !values.lower.is_null() => + { + Some(values.lower.clone()) + } + _ => None, + } + } +} + +#[cfg(test)] +mod tests { + use crate::interval_arithmetic::{next_value, prev_value, satisfy_greater, Interval}; + + use arrow::datatypes::DataType; + use datafusion_common::{Result, ScalarValue}; + + #[test] + fn test_next_prev_value() -> Result<()> { + let zeros = vec![ + ScalarValue::new_zero(&DataType::UInt8)?, + ScalarValue::new_zero(&DataType::UInt16)?, + ScalarValue::new_zero(&DataType::UInt32)?, + ScalarValue::new_zero(&DataType::UInt64)?, + ScalarValue::new_zero(&DataType::Int8)?, + ScalarValue::new_zero(&DataType::Int16)?, + ScalarValue::new_zero(&DataType::Int32)?, + ScalarValue::new_zero(&DataType::Int64)?, + ]; + let ones = vec![ + ScalarValue::new_one(&DataType::UInt8)?, + ScalarValue::new_one(&DataType::UInt16)?, + ScalarValue::new_one(&DataType::UInt32)?, + ScalarValue::new_one(&DataType::UInt64)?, + ScalarValue::new_one(&DataType::Int8)?, + ScalarValue::new_one(&DataType::Int16)?, + ScalarValue::new_one(&DataType::Int32)?, + ScalarValue::new_one(&DataType::Int64)?, + ]; + zeros.into_iter().zip(ones).for_each(|(z, o)| { + assert_eq!(next_value(z.clone()), o); + assert_eq!(prev_value(o), z); + }); + + let values = vec![ + ScalarValue::new_zero(&DataType::Float32)?, + ScalarValue::new_zero(&DataType::Float64)?, + ]; + let eps = vec![ + ScalarValue::Float32(Some(1e-6)), + ScalarValue::Float64(Some(1e-6)), + ]; + values.into_iter().zip(eps).for_each(|(value, eps)| { + assert!(next_value(value.clone()) + .sub(value.clone()) + .unwrap() + .lt(&eps)); + assert!(value + .clone() + .sub(prev_value(value.clone())) + .unwrap() + .lt(&eps)); + assert_ne!(next_value(value.clone()), value); + assert_ne!(prev_value(value.clone()), value); + }); + + let min_max = vec![ + ( + ScalarValue::UInt64(Some(u64::MIN)), + ScalarValue::UInt64(Some(u64::MAX)), + ), + ( + ScalarValue::Int8(Some(i8::MIN)), + ScalarValue::Int8(Some(i8::MAX)), + ), + ( + ScalarValue::Float32(Some(f32::MIN)), + ScalarValue::Float32(Some(f32::MAX)), + ), + ( + ScalarValue::Float64(Some(f64::MIN)), + ScalarValue::Float64(Some(f64::MAX)), + ), + ]; + let inf = vec![ + ScalarValue::UInt64(None), + ScalarValue::Int8(None), + ScalarValue::Float32(None), + ScalarValue::Float64(None), + ]; + min_max.into_iter().zip(inf).for_each(|((min, max), inf)| { + assert_eq!(next_value(max.clone()), inf); + assert_ne!(prev_value(max.clone()), max); + assert_ne!(prev_value(max.clone()), inf); + + assert_eq!(prev_value(min.clone()), inf); + assert_ne!(next_value(min.clone()), min); + assert_ne!(next_value(min.clone()), inf); + + assert_eq!(next_value(inf.clone()), inf); + assert_eq!(prev_value(inf.clone()), inf); + }); + + Ok(()) + } + + #[test] + fn test_new_interval() -> Result<()> { + use ScalarValue::*; + + let cases = vec![ + ( + (Boolean(None), Boolean(Some(false))), + Boolean(Some(false)), + Boolean(Some(false)), + ), + ( + (Boolean(Some(false)), Boolean(None)), + Boolean(Some(false)), + Boolean(Some(true)), + ), + ( + (Boolean(Some(false)), Boolean(Some(true))), + Boolean(Some(false)), + Boolean(Some(true)), + ), + ( + (UInt16(Some(u16::MAX)), UInt16(None)), + UInt16(Some(u16::MAX)), + UInt16(None), + ), + ( + (Int16(None), Int16(Some(-1000))), + Int16(None), + Int16(Some(-1000)), + ), + ( + (Float32(Some(f32::MAX)), Float32(Some(f32::MAX))), + Float32(Some(f32::MAX)), + Float32(Some(f32::MAX)), + ), + ( + (Float32(Some(f32::NAN)), Float32(Some(f32::MIN))), + Float32(None), + Float32(Some(f32::MIN)), + ), + ( + ( + Float64(Some(f64::NEG_INFINITY)), + Float64(Some(f64::INFINITY)), + ), + Float64(None), + Float64(None), + ), + ]; + for (inputs, lower, upper) in cases { + let result = Interval::try_new(inputs.0, inputs.1)?; + assert_eq!(result.clone().lower(), &lower); + assert_eq!(result.upper(), &upper); + } + + let invalid_intervals = vec![ + (Float32(Some(f32::INFINITY)), Float32(Some(100_f32))), + (Float64(Some(0_f64)), Float64(Some(f64::NEG_INFINITY))), + (Boolean(Some(true)), Boolean(Some(false))), + (Int32(Some(1000)), Int32(Some(-2000))), + (UInt64(Some(1)), UInt64(Some(0))), + ]; + for (lower, upper) in invalid_intervals { + Interval::try_new(lower, upper).expect_err( + "Given parameters should have given an invalid interval error", + ); + } + + Ok(()) + } + + #[test] + fn test_make_unbounded() -> Result<()> { + use ScalarValue::*; + + let unbounded_cases = vec![ + (DataType::Boolean, Boolean(Some(false)), Boolean(Some(true))), + (DataType::UInt8, UInt8(None), UInt8(None)), + (DataType::UInt16, UInt16(None), UInt16(None)), + (DataType::UInt32, UInt32(None), UInt32(None)), + (DataType::UInt64, UInt64(None), UInt64(None)), + (DataType::Int8, Int8(None), Int8(None)), + (DataType::Int16, Int16(None), Int16(None)), + (DataType::Int32, Int32(None), Int32(None)), + (DataType::Int64, Int64(None), Int64(None)), + (DataType::Float32, Float32(None), Float32(None)), + (DataType::Float64, Float64(None), Float64(None)), + ]; + for (dt, lower, upper) in unbounded_cases { + let inf = Interval::make_unbounded(&dt)?; + assert_eq!(inf.clone().lower(), &lower); + assert_eq!(inf.upper(), &upper); + } + + Ok(()) + } + + #[test] + fn gt_lt_test() -> Result<()> { + let exactly_gt_cases = vec![ + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(999_i64))?, + ), + ( + Interval::make(Some(1000_i64), Some(1000_i64))?, + Interval::make(None, Some(999_i64))?, + ), + ( + Interval::make(Some(501_i64), Some(1000_i64))?, + Interval::make(Some(500_i64), Some(500_i64))?, + ), + ( + Interval::make(Some(-1000_i64), Some(1000_i64))?, + Interval::make(None, Some(-1500_i64))?, + ), + ( + Interval::try_new( + next_value(ScalarValue::Float32(Some(0.0))), + next_value(ScalarValue::Float32(Some(0.0))), + )?, + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(-1.0_f32), Some(-1.0_f32))?, + Interval::try_new( + prev_value(ScalarValue::Float32(Some(-1.0))), + prev_value(ScalarValue::Float32(Some(-1.0))), + )?, + ), + ]; + for (first, second) in exactly_gt_cases { + assert_eq!(first.gt(second.clone())?, Interval::CERTAINLY_TRUE); + assert_eq!(second.lt(first)?, Interval::CERTAINLY_TRUE); + } + + let possibly_gt_cases = vec![ + ( + Interval::make(Some(1000_i64), Some(2000_i64))?, + Interval::make(Some(1000_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(500_i64), Some(1000_i64))?, + Interval::make(Some(500_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(Some(1000_i64), None)?, + ), + ( + Interval::make::(None, None)?, + Interval::make::(None, None)?, + ), + ( + Interval::try_new( + ScalarValue::Float32(Some(0.0_f32)), + next_value(ScalarValue::Float32(Some(0.0_f32))), + )?, + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(-1.0_f32), Some(-1.0_f32))?, + Interval::try_new( + prev_value(ScalarValue::Float32(Some(-1.0_f32))), + ScalarValue::Float32(Some(-1.0_f32)), + )?, + ), + ]; + for (first, second) in possibly_gt_cases { + assert_eq!(first.gt(second.clone())?, Interval::UNCERTAIN); + assert_eq!(second.lt(first)?, Interval::UNCERTAIN); + } + + let not_gt_cases = vec![ + ( + Interval::make(Some(1000_i64), Some(1000_i64))?, + Interval::make(Some(1000_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(500_i64), Some(1000_i64))?, + Interval::make(Some(1000_i64), None)?, + ), + ( + Interval::make(None, Some(1000_i64))?, + Interval::make(Some(1000_i64), Some(1500_i64))?, + ), + ( + Interval::try_new( + prev_value(ScalarValue::Float32(Some(0.0_f32))), + ScalarValue::Float32(Some(0.0_f32)), + )?, + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(-1.0_f32), Some(-1.0_f32))?, + Interval::try_new( + ScalarValue::Float32(Some(-1.0_f32)), + next_value(ScalarValue::Float32(Some(-1.0_f32))), + )?, + ), + ]; + for (first, second) in not_gt_cases { + assert_eq!(first.gt(second.clone())?, Interval::CERTAINLY_FALSE); + assert_eq!(second.lt(first)?, Interval::CERTAINLY_FALSE); + } + + Ok(()) + } + + #[test] + fn gteq_lteq_test() -> Result<()> { + let exactly_gteq_cases = vec![ + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(1000_i64))?, + ), + ( + Interval::make(Some(1000_i64), Some(1000_i64))?, + Interval::make(None, Some(1000_i64))?, + ), + ( + Interval::make(Some(500_i64), Some(1000_i64))?, + Interval::make(Some(500_i64), Some(500_i64))?, + ), + ( + Interval::make(Some(-1000_i64), Some(1000_i64))?, + Interval::make(None, Some(-1500_i64))?, + ), + ( + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::try_new( + ScalarValue::Float32(Some(-1.0)), + next_value(ScalarValue::Float32(Some(-1.0))), + )?, + Interval::try_new( + prev_value(ScalarValue::Float32(Some(-1.0))), + ScalarValue::Float32(Some(-1.0)), + )?, + ), + ]; + for (first, second) in exactly_gteq_cases { + assert_eq!(first.gt_eq(second.clone())?, Interval::CERTAINLY_TRUE); + assert_eq!(second.lt_eq(first)?, Interval::CERTAINLY_TRUE); + } + + let possibly_gteq_cases = vec![ + ( + Interval::make(Some(999_i64), Some(2000_i64))?, + Interval::make(Some(1000_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(500_i64), Some(1000_i64))?, + Interval::make(Some(500_i64), Some(1001_i64))?, + ), + ( + Interval::make(Some(0_i64), None)?, + Interval::make(Some(1000_i64), None)?, + ), + ( + Interval::make::(None, None)?, + Interval::make::(None, None)?, + ), + ( + Interval::try_new( + prev_value(ScalarValue::Float32(Some(0.0))), + ScalarValue::Float32(Some(0.0)), + )?, + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(-1.0_f32), Some(-1.0_f32))?, + Interval::try_new( + prev_value(ScalarValue::Float32(Some(-1.0_f32))), + next_value(ScalarValue::Float32(Some(-1.0_f32))), + )?, + ), + ]; + for (first, second) in possibly_gteq_cases { + assert_eq!(first.gt_eq(second.clone())?, Interval::UNCERTAIN); + assert_eq!(second.lt_eq(first)?, Interval::UNCERTAIN); + } + + let not_gteq_cases = vec![ + ( + Interval::make(Some(1000_i64), Some(1000_i64))?, + Interval::make(Some(2000_i64), Some(2000_i64))?, + ), + ( + Interval::make(Some(500_i64), Some(999_i64))?, + Interval::make(Some(1000_i64), None)?, + ), + ( + Interval::make(None, Some(1000_i64))?, + Interval::make(Some(1001_i64), Some(1500_i64))?, + ), + ( + Interval::try_new( + prev_value(ScalarValue::Float32(Some(0.0_f32))), + prev_value(ScalarValue::Float32(Some(0.0_f32))), + )?, + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(-1.0_f32), Some(-1.0_f32))?, + Interval::try_new( + next_value(ScalarValue::Float32(Some(-1.0))), + next_value(ScalarValue::Float32(Some(-1.0))), + )?, + ), + ]; + for (first, second) in not_gteq_cases { + assert_eq!(first.gt_eq(second.clone())?, Interval::CERTAINLY_FALSE); + assert_eq!(second.lt_eq(first)?, Interval::CERTAINLY_FALSE); + } + + Ok(()) + } + + #[test] + fn equal_test() -> Result<()> { + let exactly_eq_cases = vec![ + ( + Interval::make(Some(1000_i64), Some(1000_i64))?, + Interval::make(Some(1000_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(0_u64), Some(0_u64))?, + Interval::make(Some(0_u64), Some(0_u64))?, + ), + ( + Interval::make(Some(f32::MAX), Some(f32::MAX))?, + Interval::make(Some(f32::MAX), Some(f32::MAX))?, + ), + ( + Interval::make(Some(f64::MIN), Some(f64::MIN))?, + Interval::make(Some(f64::MIN), Some(f64::MIN))?, + ), + ]; + for (first, second) in exactly_eq_cases { + assert_eq!(first.equal(second.clone())?, Interval::CERTAINLY_TRUE); + assert_eq!(second.equal(first)?, Interval::CERTAINLY_TRUE); + } + + let possibly_eq_cases = vec![ + ( + Interval::make::(None, None)?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(0_i64), Some(0_i64))?, + Interval::make(Some(0_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(0_i64), Some(0_i64))?, + Interval::make(Some(0_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(100.0_f32), Some(200.0_f32))?, + Interval::make(Some(0.0_f32), Some(1000.0_f32))?, + ), + ( + Interval::try_new( + prev_value(ScalarValue::Float32(Some(0.0))), + ScalarValue::Float32(Some(0.0)), + )?, + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(-1.0_f32), Some(-1.0_f32))?, + Interval::try_new( + prev_value(ScalarValue::Float32(Some(-1.0))), + next_value(ScalarValue::Float32(Some(-1.0))), + )?, + ), + ]; + for (first, second) in possibly_eq_cases { + assert_eq!(first.equal(second.clone())?, Interval::UNCERTAIN); + assert_eq!(second.equal(first)?, Interval::UNCERTAIN); + } + + let not_eq_cases = vec![ + ( + Interval::make(Some(1000_i64), Some(1000_i64))?, + Interval::make(Some(2000_i64), Some(2000_i64))?, + ), + ( + Interval::make(Some(500_i64), Some(999_i64))?, + Interval::make(Some(1000_i64), None)?, + ), + ( + Interval::make(None, Some(1000_i64))?, + Interval::make(Some(1001_i64), Some(1500_i64))?, + ), + ( + Interval::try_new( + prev_value(ScalarValue::Float32(Some(0.0))), + prev_value(ScalarValue::Float32(Some(0.0))), + )?, + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(-1.0_f32), Some(-1.0_f32))?, + Interval::try_new( + next_value(ScalarValue::Float32(Some(-1.0))), + next_value(ScalarValue::Float32(Some(-1.0))), + )?, + ), + ]; + for (first, second) in not_eq_cases { + assert_eq!(first.equal(second.clone())?, Interval::CERTAINLY_FALSE); + assert_eq!(second.equal(first)?, Interval::CERTAINLY_FALSE); + } + + Ok(()) + } + + #[test] + fn and_test() -> Result<()> { + let cases = vec![ + (false, true, false, false, false, false), + (false, false, false, true, false, false), + (false, true, false, true, false, true), + (false, true, true, true, false, true), + (false, false, false, false, false, false), + (true, true, true, true, true, true), + ]; + + for case in cases { + assert_eq!( + Interval::make(Some(case.0), Some(case.1))? + .and(Interval::make(Some(case.2), Some(case.3))?)?, + Interval::make(Some(case.4), Some(case.5))? + ); + } + Ok(()) + } + + #[test] + fn not_test() -> Result<()> { + let cases = vec![ + (false, true, false, true), + (false, false, true, true), + (true, true, false, false), + ]; + + for case in cases { + assert_eq!( + Interval::make(Some(case.0), Some(case.1))?.not()?, + Interval::make(Some(case.2), Some(case.3))? + ); + } + Ok(()) + } + + #[test] + fn intersect_test() -> Result<()> { + let possible_cases = vec![ + ( + Interval::make(Some(1000_i64), None)?, + Interval::make::(None, None)?, + Interval::make(Some(1000_i64), None)?, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(1000_i64))?, + Interval::make(Some(1000_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(2000_i64))?, + Interval::make(Some(1000_i64), Some(2000_i64))?, + ), + ( + Interval::make(Some(1000_i64), Some(2000_i64))?, + Interval::make(Some(1000_i64), None)?, + Interval::make(Some(1000_i64), Some(2000_i64))?, + ), + ( + Interval::make(Some(1000_i64), Some(2000_i64))?, + Interval::make(Some(1000_i64), Some(1500_i64))?, + Interval::make(Some(1000_i64), Some(1500_i64))?, + ), + ( + Interval::make(Some(1000_i64), Some(2000_i64))?, + Interval::make(Some(500_i64), Some(1500_i64))?, + Interval::make(Some(1000_i64), Some(1500_i64))?, + ), + ( + Interval::make::(None, None)?, + Interval::make::(None, None)?, + Interval::make::(None, None)?, + ), + ( + Interval::make(None, Some(2000_u64))?, + Interval::make(Some(500_u64), None)?, + Interval::make(Some(500_u64), Some(2000_u64))?, + ), + ( + Interval::make(Some(0_u64), Some(0_u64))?, + Interval::make(Some(0_u64), None)?, + Interval::make(Some(0_u64), Some(0_u64))?, + ), + ( + Interval::make(Some(1000.0_f32), None)?, + Interval::make(None, Some(1000.0_f32))?, + Interval::make(Some(1000.0_f32), Some(1000.0_f32))?, + ), + ( + Interval::make(Some(1000.0_f32), Some(1500.0_f32))?, + Interval::make(Some(0.0_f32), Some(1500.0_f32))?, + Interval::make(Some(1000.0_f32), Some(1500.0_f32))?, + ), + ( + Interval::make(Some(-1000.0_f64), Some(1500.0_f64))?, + Interval::make(Some(-1500.0_f64), Some(2000.0_f64))?, + Interval::make(Some(-1000.0_f64), Some(1500.0_f64))?, + ), + ( + Interval::make(Some(16.0_f64), Some(32.0_f64))?, + Interval::make(Some(32.0_f64), Some(64.0_f64))?, + Interval::make(Some(32.0_f64), Some(32.0_f64))?, + ), + ]; + for (first, second, expected) in possible_cases { + assert_eq!(first.intersect(second)?.unwrap(), expected) + } + + let empty_cases = vec![ + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(0_i64))?, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(999_i64))?, + ), + ( + Interval::make(Some(1500_i64), Some(2000_i64))?, + Interval::make(Some(1000_i64), Some(1499_i64))?, + ), + ( + Interval::make(Some(0_i64), Some(1000_i64))?, + Interval::make(Some(2000_i64), Some(3000_i64))?, + ), + ( + Interval::try_new( + prev_value(ScalarValue::Float32(Some(1.0))), + prev_value(ScalarValue::Float32(Some(1.0))), + )?, + Interval::make(Some(1.0_f32), Some(1.0_f32))?, + ), + ( + Interval::try_new( + next_value(ScalarValue::Float32(Some(1.0))), + next_value(ScalarValue::Float32(Some(1.0))), + )?, + Interval::make(Some(1.0_f32), Some(1.0_f32))?, + ), + ]; + for (first, second) in empty_cases { + assert_eq!(first.intersect(second)?, None) + } + + Ok(()) + } + + #[test] + fn test_contains() -> Result<()> { + let possible_cases = vec![ + ( + Interval::make::(None, None)?, + Interval::make::(None, None)?, + Interval::CERTAINLY_TRUE, + ), + ( + Interval::make(Some(1500_i64), Some(2000_i64))?, + Interval::make(Some(1501_i64), Some(1999_i64))?, + Interval::CERTAINLY_TRUE, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make::(None, None)?, + Interval::UNCERTAIN, + ), + ( + Interval::make(Some(1000_i64), Some(2000_i64))?, + Interval::make(Some(500), Some(1500_i64))?, + Interval::UNCERTAIN, + ), + ( + Interval::make(Some(16.0), Some(32.0))?, + Interval::make(Some(32.0), Some(64.0))?, + Interval::UNCERTAIN, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(0_i64))?, + Interval::CERTAINLY_FALSE, + ), + ( + Interval::make(Some(1500_i64), Some(2000_i64))?, + Interval::make(Some(1000_i64), Some(1499_i64))?, + Interval::CERTAINLY_FALSE, + ), + ( + Interval::try_new( + prev_value(ScalarValue::Float32(Some(1.0))), + prev_value(ScalarValue::Float32(Some(1.0))), + )?, + Interval::make(Some(1.0_f32), Some(1.0_f32))?, + Interval::CERTAINLY_FALSE, + ), + ( + Interval::try_new( + next_value(ScalarValue::Float32(Some(1.0))), + next_value(ScalarValue::Float32(Some(1.0))), + )?, + Interval::make(Some(1.0_f32), Some(1.0_f32))?, + Interval::CERTAINLY_FALSE, + ), + ]; + for (first, second, expected) in possible_cases { + assert_eq!(first.contains(second)?, expected) + } + + Ok(()) + } + + #[test] + fn test_add() -> Result<()> { + let cases = vec![ + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(None, Some(200_i64))?, + Interval::make(None, Some(400_i64))?, + ), + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(200_i64), None)?, + Interval::make(Some(300_i64), None)?, + ), + ( + Interval::make(None, Some(200_i64))?, + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(None, Some(400_i64))?, + ), + ( + Interval::make(Some(200_i64), None)?, + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(300_i64), None)?, + ), + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(-300_i64), Some(150_i64))?, + Interval::make(Some(-200_i64), Some(350_i64))?, + ), + ( + Interval::make(Some(f32::MAX), Some(f32::MAX))?, + Interval::make(Some(11_f32), Some(11_f32))?, + Interval::make(Some(f32::MAX), None)?, + ), + ( + Interval::make(Some(f32::MIN), Some(f32::MIN))?, + Interval::make(Some(-10_f32), Some(10_f32))?, + // Since rounding mode is up, the result would be much greater than f32::MIN + // (f32::MIN = -3.4_028_235e38, the result is -3.4_028_233e38) + Interval::make( + None, + Some(-340282330000000000000000000000000000000.0_f32), + )?, + ), + ( + Interval::make(Some(f32::MIN), Some(f32::MIN))?, + Interval::make(Some(-10_f32), Some(-10_f32))?, + Interval::make(None, Some(f32::MIN))?, + ), + ( + Interval::make(Some(1.0), Some(f32::MAX))?, + Interval::make(Some(f32::MAX), Some(f32::MAX))?, + Interval::make(Some(f32::MAX), None)?, + ), + ( + Interval::make(Some(f32::MIN), Some(f32::MIN))?, + Interval::make(Some(f32::MAX), Some(f32::MAX))?, + Interval::make(Some(-0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(100_f64), None)?, + Interval::make(None, Some(200_f64))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(None, Some(100_f64))?, + Interval::make(None, Some(200_f64))?, + Interval::make(None, Some(300_f64))?, + ), + ]; + for case in cases { + let result = case.0.add(case.1)?; + if case.0.data_type().is_floating() { + assert!( + result.lower().is_null() && case.2.lower().is_null() + || result.lower().le(case.2.lower()) + ); + assert!( + result.upper().is_null() && case.2.upper().is_null() + || result.upper().ge(case.2.upper()) + ); + } else { + assert_eq!(result, case.2); + } + } + + Ok(()) + } + + #[test] + fn test_sub() -> Result<()> { + let cases = vec![ + ( + Interval::make(Some(i32::MAX), Some(i32::MAX))?, + Interval::make(Some(11_i32), Some(11_i32))?, + Interval::make(Some(i32::MAX - 11), Some(i32::MAX - 11))?, + ), + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(None, Some(200_i64))?, + Interval::make(Some(-100_i64), None)?, + ), + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(200_i64), None)?, + Interval::make(None, Some(0_i64))?, + ), + ( + Interval::make(None, Some(200_i64))?, + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(None, Some(100_i64))?, + ), + ( + Interval::make(Some(200_i64), None)?, + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(0_i64), None)?, + ), + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(-300_i64), Some(150_i64))?, + Interval::make(Some(-50_i64), Some(500_i64))?, + ), + ( + Interval::make(Some(i64::MIN), Some(i64::MIN))?, + Interval::make(Some(-10_i64), Some(-10_i64))?, + Interval::make(Some(i64::MIN + 10), Some(i64::MIN + 10))?, + ), + ( + Interval::make(Some(1), Some(i64::MAX))?, + Interval::make(Some(i64::MAX), Some(i64::MAX))?, + Interval::make(Some(1 - i64::MAX), Some(0))?, + ), + ( + Interval::make(Some(i64::MIN), Some(i64::MIN))?, + Interval::make(Some(i64::MAX), Some(i64::MAX))?, + Interval::make(None, Some(i64::MIN))?, + ), + ( + Interval::make(Some(2_u32), Some(10_u32))?, + Interval::make(Some(4_u32), Some(6_u32))?, + Interval::make(None, Some(6_u32))?, + ), + ( + Interval::make(Some(2_u32), Some(10_u32))?, + Interval::make(Some(20_u32), Some(30_u32))?, + Interval::make(None, Some(0_u32))?, + ), + ( + Interval::make(Some(f32::MIN), Some(f32::MIN))?, + Interval::make(Some(-10_f32), Some(10_f32))?, + // Since rounding mode is up, the result would be much larger than f32::MIN + // (f32::MIN = -3.4_028_235e38, the result is -3.4_028_233e38) + Interval::make( + None, + Some(-340282330000000000000000000000000000000.0_f32), + )?, + ), + ( + Interval::make(Some(100_f64), None)?, + Interval::make(None, Some(200_f64))?, + Interval::make(Some(-100_f64), None)?, + ), + ( + Interval::make(None, Some(100_f64))?, + Interval::make(None, Some(200_f64))?, + Interval::make::(None, None)?, + ), + ]; + for case in cases { + let result = case.0.sub(case.1)?; + if case.0.data_type().is_floating() { + assert!( + result.lower().is_null() && case.2.lower().is_null() + || result.lower().le(case.2.lower()) + ); + assert!( + result.upper().is_null() && case.2.upper().is_null() + || result.upper().ge(case.2.upper(),) + ); + } else { + assert_eq!(result, case.2); + } + } + + Ok(()) + } + + #[test] + fn test_mul() -> Result<()> { + let cases = vec![ + ( + Interval::make(Some(1_i64), Some(2_i64))?, + Interval::make(None, Some(2_i64))?, + Interval::make(None, Some(4_i64))?, + ), + ( + Interval::make(Some(1_i64), Some(2_i64))?, + Interval::make(Some(2_i64), None)?, + Interval::make(Some(2_i64), None)?, + ), + ( + Interval::make(None, Some(2_i64))?, + Interval::make(Some(1_i64), Some(2_i64))?, + Interval::make(None, Some(4_i64))?, + ), + ( + Interval::make(Some(2_i64), None)?, + Interval::make(Some(1_i64), Some(2_i64))?, + Interval::make(Some(2_i64), None)?, + ), + ( + Interval::make(Some(1_i64), Some(2_i64))?, + Interval::make(Some(-3_i64), Some(15_i64))?, + Interval::make(Some(-6_i64), Some(30_i64))?, + ), + ( + Interval::make(Some(-0.0), Some(0.0))?, + Interval::make(None, Some(0.0))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(f32::MIN), Some(f32::MIN))?, + Interval::make(Some(-10_f32), Some(10_f32))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(1_u32), Some(2_u32))?, + Interval::make(Some(0_u32), Some(1_u32))?, + Interval::make(Some(0_u32), Some(2_u32))?, + ), + ( + Interval::make(None, Some(2_u32))?, + Interval::make(Some(0_u32), Some(1_u32))?, + Interval::make(None, Some(2_u32))?, + ), + ( + Interval::make(None, Some(2_u32))?, + Interval::make(Some(1_u32), Some(2_u32))?, + Interval::make(None, Some(4_u32))?, + ), + ( + Interval::make(None, Some(2_u32))?, + Interval::make(Some(1_u32), None)?, + Interval::make::(None, None)?, + ), + ( + Interval::make::(None, None)?, + Interval::make(Some(0_u32), None)?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(f32::MAX), Some(f32::MAX))?, + Interval::make(Some(11_f32), Some(11_f32))?, + Interval::make(Some(f32::MAX), None)?, + ), + ( + Interval::make(Some(f32::MIN), Some(f32::MIN))?, + Interval::make(Some(-10_f32), Some(-10_f32))?, + Interval::make(Some(f32::MAX), None)?, + ), + ( + Interval::make(Some(1.0), Some(f32::MAX))?, + Interval::make(Some(f32::MAX), Some(f32::MAX))?, + Interval::make(Some(f32::MAX), None)?, + ), + ( + Interval::make(Some(f32::MIN), Some(f32::MIN))?, + Interval::make(Some(f32::MAX), Some(f32::MAX))?, + Interval::make(None, Some(f32::MIN))?, + ), + ( + Interval::make(Some(-0.0_f32), Some(0.0_f32))?, + Interval::make(Some(f32::MAX), None)?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + Interval::make(Some(f32::MAX), None)?, + Interval::make(Some(0.0_f32), None)?, + ), + ( + Interval::make(Some(1_f64), None)?, + Interval::make(None, Some(2_f64))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(None, Some(1_f64))?, + Interval::make(None, Some(2_f64))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(-0.0_f64), Some(-0.0_f64))?, + Interval::make(Some(1_f64), Some(2_f64))?, + Interval::make(Some(-0.0_f64), Some(-0.0_f64))?, + ), + ( + Interval::make(Some(0.0_f64), Some(0.0_f64))?, + Interval::make(Some(1_f64), Some(2_f64))?, + Interval::make(Some(0.0_f64), Some(0.0_f64))?, + ), + ( + Interval::make(Some(-0.0_f64), Some(0.0_f64))?, + Interval::make(Some(1_f64), Some(2_f64))?, + Interval::make(Some(-0.0_f64), Some(0.0_f64))?, + ), + ( + Interval::make(Some(-0.0_f64), Some(1.0_f64))?, + Interval::make(Some(1_f64), Some(2_f64))?, + Interval::make(Some(-0.0_f64), Some(2.0_f64))?, + ), + ( + Interval::make(Some(0.0_f64), Some(1.0_f64))?, + Interval::make(Some(1_f64), Some(2_f64))?, + Interval::make(Some(0.0_f64), Some(2.0_f64))?, + ), + ( + Interval::make(Some(-0.0_f64), Some(1.0_f64))?, + Interval::make(Some(-1_f64), Some(2_f64))?, + Interval::make(Some(-1.0_f64), Some(2.0_f64))?, + ), + ( + Interval::make::(None, None)?, + Interval::make(Some(-0.0_f64), Some(0.0_f64))?, + Interval::make::(None, None)?, + ), + ( + Interval::make::(None, Some(10.0_f64))?, + Interval::make(Some(-0.0_f64), Some(0.0_f64))?, + Interval::make::(None, None)?, + ), + ]; + for case in cases { + let result = case.0.mul(case.1)?; + if case.0.data_type().is_floating() { + assert!( + result.lower().is_null() && case.2.lower().is_null() + || result.lower().le(case.2.lower()) + ); + assert!( + result.upper().is_null() && case.2.upper().is_null() + || result.upper().ge(case.2.upper()) + ); + } else { + assert_eq!(result, case.2); + } + } + + Ok(()) + } + + #[test] + fn test_div() -> Result<()> { + let cases = vec![ + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(1_i64), Some(2_i64))?, + Interval::make(Some(50_i64), Some(200_i64))?, + ), + ( + Interval::make(Some(-200_i64), Some(-100_i64))?, + Interval::make(Some(-2_i64), Some(-1_i64))?, + Interval::make(Some(50_i64), Some(200_i64))?, + ), + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(-2_i64), Some(-1_i64))?, + Interval::make(Some(-200_i64), Some(-50_i64))?, + ), + ( + Interval::make(Some(-200_i64), Some(-100_i64))?, + Interval::make(Some(1_i64), Some(2_i64))?, + Interval::make(Some(-200_i64), Some(-50_i64))?, + ), + ( + Interval::make(Some(-200_i64), Some(100_i64))?, + Interval::make(Some(1_i64), Some(2_i64))?, + Interval::make(Some(-200_i64), Some(100_i64))?, + ), + ( + Interval::make(Some(-100_i64), Some(200_i64))?, + Interval::make(Some(1_i64), Some(2_i64))?, + Interval::make(Some(-100_i64), Some(200_i64))?, + ), + ( + Interval::make(Some(10_i64), Some(20_i64))?, + Interval::make::(None, None)?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(-100_i64), Some(200_i64))?, + Interval::make(Some(-1_i64), Some(2_i64))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(-100_i64), Some(200_i64))?, + Interval::make(Some(-2_i64), Some(1_i64))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(0_i64), Some(1_i64))?, + Interval::make(Some(100_i64), None)?, + ), + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(None, Some(0_i64))?, + Interval::make(None, Some(0_i64))?, + ), + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(0_i64), Some(0_i64))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(0_i64), Some(1_i64))?, + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(0_i64), Some(0_i64))?, + ), + ( + Interval::make(Some(0_i64), Some(1_i64))?, + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(0_i64), Some(0_i64))?, + ), + ( + Interval::make(Some(1_u32), Some(2_u32))?, + Interval::make(Some(0_u32), Some(0_u32))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(10_u32), Some(20_u32))?, + Interval::make(None, Some(2_u32))?, + Interval::make(Some(5_u32), None)?, + ), + ( + Interval::make(Some(10_u32), Some(20_u32))?, + Interval::make(Some(0_u32), Some(2_u32))?, + Interval::make(Some(5_u32), None)?, + ), + ( + Interval::make(Some(10_u32), Some(20_u32))?, + Interval::make(Some(0_u32), Some(0_u32))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(12_u64), Some(48_u64))?, + Interval::make(Some(10_u64), Some(20_u64))?, + Interval::make(Some(0_u64), Some(4_u64))?, + ), + ( + Interval::make(Some(12_u64), Some(48_u64))?, + Interval::make(None, Some(2_u64))?, + Interval::make(Some(6_u64), None)?, + ), + ( + Interval::make(Some(12_u64), Some(48_u64))?, + Interval::make(Some(0_u64), Some(2_u64))?, + Interval::make(Some(6_u64), None)?, + ), + ( + Interval::make(None, Some(48_u64))?, + Interval::make(Some(0_u64), Some(2_u64))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(f32::MAX), Some(f32::MAX))?, + Interval::make(Some(-0.1_f32), Some(0.1_f32))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(f32::MIN), None)?, + Interval::make(Some(0.1_f32), Some(0.1_f32))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(-10.0_f32), Some(10.0_f32))?, + Interval::make(Some(-0.1_f32), Some(-0.1_f32))?, + Interval::make(Some(-100.0_f32), Some(100.0_f32))?, + ), + ( + Interval::make(Some(-10.0_f32), Some(f32::MAX))?, + Interval::make::(None, None)?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(f32::MIN), Some(10.0_f32))?, + Interval::make(Some(1.0_f32), None)?, + Interval::make(Some(f32::MIN), Some(10.0_f32))?, + ), + ( + Interval::make(Some(-0.0_f32), Some(0.0_f32))?, + Interval::make(Some(f32::MAX), None)?, + Interval::make(Some(-0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(-0.0_f32), Some(0.0_f32))?, + Interval::make(None, Some(-0.0_f32))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + Interval::make(Some(f32::MAX), None)?, + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(1.0_f32), Some(2.0_f32))?, + Interval::make(Some(0.0_f32), Some(4.0_f32))?, + Interval::make(Some(0.25_f32), None)?, + ), + ( + Interval::make(Some(1.0_f32), Some(2.0_f32))?, + Interval::make(Some(-4.0_f32), Some(-0.0_f32))?, + Interval::make(None, Some(-0.25_f32))?, + ), + ( + Interval::make(Some(-4.0_f64), Some(2.0_f64))?, + Interval::make(Some(10.0_f64), Some(20.0_f64))?, + Interval::make(Some(-0.4_f64), Some(0.2_f64))?, + ), + ( + Interval::make(Some(-0.0_f64), Some(-0.0_f64))?, + Interval::make(None, Some(-0.0_f64))?, + Interval::make(Some(0.0_f64), None)?, + ), + ( + Interval::make(Some(1.0_f64), Some(2.0_f64))?, + Interval::make::(None, None)?, + Interval::make(Some(0.0_f64), None)?, + ), + ]; + for case in cases { + let result = case.0.div(case.1)?; + if case.0.data_type().is_floating() { + assert!( + result.lower().is_null() && case.2.lower().is_null() + || result.lower().le(case.2.lower()) + ); + assert!( + result.upper().is_null() && case.2.upper().is_null() + || result.upper().ge(case.2.upper()) + ); + } else { + assert_eq!(result, case.2); + } + } + + Ok(()) + } + + #[test] + fn test_cardinality_of_intervals() -> Result<()> { + // In IEEE 754 standard for floating-point arithmetic, if we keep the sign and exponent fields same, + // we can represent 4503599627370496+1 different numbers by changing the mantissa + // (4503599627370496 = 2^52, since there are 52 bits in mantissa, and 2^23 = 8388608 for f32). + // TODO: Add tests for non-exponential boundary aligned intervals too. + let distinct_f64 = 4503599627370497; + let distinct_f32 = 8388609; + let intervals = [ + Interval::make(Some(0.25_f64), Some(0.50_f64))?, + Interval::make(Some(0.5_f64), Some(1.0_f64))?, + Interval::make(Some(1.0_f64), Some(2.0_f64))?, + Interval::make(Some(32.0_f64), Some(64.0_f64))?, + Interval::make(Some(-0.50_f64), Some(-0.25_f64))?, + Interval::make(Some(-32.0_f64), Some(-16.0_f64))?, + ]; + for interval in intervals { + assert_eq!(interval.cardinality().unwrap(), distinct_f64); + } + + let intervals = [ + Interval::make(Some(0.25_f32), Some(0.50_f32))?, + Interval::make(Some(-1_f32), Some(-0.5_f32))?, + ]; + for interval in intervals { + assert_eq!(interval.cardinality().unwrap(), distinct_f32); + } + + // The regular logarithmic distribution of floating-point numbers are + // only applicable outside of the `(-phi, phi)` interval where `phi` + // denotes the largest positive subnormal floating-point number. Since + // the following intervals include such subnormal points, we cannot use + // a simple powers-of-two type formula for our expectations. Therefore, + // we manually supply the actual expected cardinality. + let interval = Interval::make(Some(-0.0625), Some(0.0625))?; + assert_eq!(interval.cardinality().unwrap(), 9178336040581070850); + + let interval = Interval::try_new( + ScalarValue::UInt64(Some(u64::MIN + 1)), + ScalarValue::UInt64(Some(u64::MAX)), + )?; + assert_eq!(interval.cardinality().unwrap(), u64::MAX); + + let interval = Interval::try_new( + ScalarValue::Int64(Some(i64::MIN + 1)), + ScalarValue::Int64(Some(i64::MAX)), + )?; + assert_eq!(interval.cardinality().unwrap(), u64::MAX); + + let interval = Interval::try_new( + ScalarValue::Float32(Some(-0.0_f32)), + ScalarValue::Float32(Some(0.0_f32)), + )?; + assert_eq!(interval.cardinality().unwrap(), 2); + + Ok(()) + } + + #[test] + fn test_satisfy_comparison() -> Result<()> { + let cases = vec![ + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(1000_i64))?, + true, + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(1000_i64))?, + ), + ( + Interval::make(None, Some(1000_i64))?, + Interval::make(Some(1000_i64), None)?, + true, + Interval::make(Some(1000_i64), Some(1000_i64))?, + Interval::make(Some(1000_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(1000_i64))?, + false, + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(1000_i64))?, + ), + ( + Interval::make(Some(0_i64), Some(1000_i64))?, + Interval::make(Some(500_i64), Some(1500_i64))?, + true, + Interval::make(Some(500_i64), Some(1000_i64))?, + Interval::make(Some(500_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(500_i64), Some(1500_i64))?, + Interval::make(Some(0_i64), Some(1000_i64))?, + true, + Interval::make(Some(500_i64), Some(1500_i64))?, + Interval::make(Some(0_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(0_i64), Some(1000_i64))?, + Interval::make(Some(500_i64), Some(1500_i64))?, + false, + Interval::make(Some(501_i64), Some(1000_i64))?, + Interval::make(Some(500_i64), Some(999_i64))?, + ), + ( + Interval::make(Some(500_i64), Some(1500_i64))?, + Interval::make(Some(0_i64), Some(1000_i64))?, + false, + Interval::make(Some(500_i64), Some(1500_i64))?, + Interval::make(Some(0_i64), Some(1000_i64))?, + ), + ( + Interval::make::(None, None)?, + Interval::make(Some(1_i64), Some(1_i64))?, + false, + Interval::make(Some(2_i64), None)?, + Interval::make(Some(1_i64), Some(1_i64))?, + ), + ( + Interval::make::(None, None)?, + Interval::make(Some(1_i64), Some(1_i64))?, + true, + Interval::make(Some(1_i64), None)?, + Interval::make(Some(1_i64), Some(1_i64))?, + ), + ( + Interval::make(Some(1_i64), Some(1_i64))?, + Interval::make::(None, None)?, + false, + Interval::make(Some(1_i64), Some(1_i64))?, + Interval::make(None, Some(0_i64))?, + ), + ( + Interval::make(Some(1_i64), Some(1_i64))?, + Interval::make::(None, None)?, + true, + Interval::make(Some(1_i64), Some(1_i64))?, + Interval::make(None, Some(1_i64))?, + ), + ( + Interval::make(Some(1_i64), Some(1_i64))?, + Interval::make::(None, None)?, + false, + Interval::make(Some(1_i64), Some(1_i64))?, + Interval::make(None, Some(0_i64))?, + ), + ( + Interval::make(Some(1_i64), Some(1_i64))?, + Interval::make::(None, None)?, + true, + Interval::make(Some(1_i64), Some(1_i64))?, + Interval::make(None, Some(1_i64))?, + ), + ( + Interval::make::(None, None)?, + Interval::make(Some(1_i64), Some(1_i64))?, + false, + Interval::make(Some(2_i64), None)?, + Interval::make(Some(1_i64), Some(1_i64))?, + ), + ( + Interval::make::(None, None)?, + Interval::make(Some(1_i64), Some(1_i64))?, + true, + Interval::make(Some(1_i64), None)?, + Interval::make(Some(1_i64), Some(1_i64))?, + ), + ( + Interval::make(Some(-1000.0_f32), Some(1000.0_f32))?, + Interval::make(Some(-500.0_f32), Some(500.0_f32))?, + false, + Interval::try_new( + next_value(ScalarValue::Float32(Some(-500.0))), + ScalarValue::Float32(Some(1000.0)), + )?, + Interval::make(Some(-500_f32), Some(500.0_f32))?, + ), + ( + Interval::make(Some(-500.0_f32), Some(500.0_f32))?, + Interval::make(Some(-1000.0_f32), Some(1000.0_f32))?, + true, + Interval::make(Some(-500.0_f32), Some(500.0_f32))?, + Interval::make(Some(-1000.0_f32), Some(500.0_f32))?, + ), + ( + Interval::make(Some(-500.0_f32), Some(500.0_f32))?, + Interval::make(Some(-1000.0_f32), Some(1000.0_f32))?, + false, + Interval::make(Some(-500.0_f32), Some(500.0_f32))?, + Interval::try_new( + ScalarValue::Float32(Some(-1000.0_f32)), + prev_value(ScalarValue::Float32(Some(500.0_f32))), + )?, + ), + ( + Interval::make(Some(-1000.0_f64), Some(1000.0_f64))?, + Interval::make(Some(-500.0_f64), Some(500.0_f64))?, + true, + Interval::make(Some(-500.0_f64), Some(1000.0_f64))?, + Interval::make(Some(-500.0_f64), Some(500.0_f64))?, + ), + ]; + for (first, second, includes_endpoints, left_modified, right_modified) in cases { + assert_eq!( + satisfy_greater(&first, &second, !includes_endpoints)?.unwrap(), + (left_modified, right_modified) + ); + } + + let infeasible_cases = vec![ + ( + Interval::make(None, Some(1000_i64))?, + Interval::make(Some(1000_i64), None)?, + false, + ), + ( + Interval::make(Some(-1000.0_f32), Some(1000.0_f32))?, + Interval::make(Some(1500.0_f32), Some(2000.0_f32))?, + false, + ), + ]; + for (first, second, includes_endpoints) in infeasible_cases { + assert_eq!(satisfy_greater(&first, &second, !includes_endpoints)?, None); + } + + Ok(()) + } + + #[test] + fn test_interval_display() { + let interval = Interval::make(Some(0.25_f32), Some(0.50_f32)).unwrap(); + assert_eq!(format!("{}", interval), "[0.25, 0.5]"); + + let interval = Interval::try_new( + ScalarValue::Float32(Some(f32::NEG_INFINITY)), + ScalarValue::Float32(Some(f32::INFINITY)), + ) + .unwrap(); + assert_eq!(format!("{}", interval), "[NULL, NULL]"); + } + + macro_rules! capture_mode_change { + ($TYPE:ty) => { + paste::item! { + capture_mode_change_helper!([], + [], + $TYPE); + } + }; + } + + macro_rules! capture_mode_change_helper { + ($TEST_FN_NAME:ident, $CREATE_FN_NAME:ident, $TYPE:ty) => { + fn $CREATE_FN_NAME(lower: $TYPE, upper: $TYPE) -> Interval { + Interval::try_new( + ScalarValue::try_from(Some(lower as $TYPE)).unwrap(), + ScalarValue::try_from(Some(upper as $TYPE)).unwrap(), + ) + .unwrap() + } + + fn $TEST_FN_NAME(input: ($TYPE, $TYPE), expect_low: bool, expect_high: bool) { + assert!(expect_low || expect_high); + let interval1 = $CREATE_FN_NAME(input.0, input.0); + let interval2 = $CREATE_FN_NAME(input.1, input.1); + let result = interval1.add(&interval2).unwrap(); + let without_fe = $CREATE_FN_NAME(input.0 + input.1, input.0 + input.1); + assert!( + (!expect_low || result.lower < without_fe.lower) + && (!expect_high || result.upper > without_fe.upper) + ); + } + }; + } + + capture_mode_change!(f32); + capture_mode_change!(f64); + + #[cfg(all( + any(target_arch = "x86_64", target_arch = "aarch64"), + not(target_os = "windows") + ))] + #[test] + fn test_add_intervals_lower_affected_f32() { + // Lower is affected + let lower = f32::from_bits(1073741887); //1000000000000000000000000111111 + let upper = f32::from_bits(1098907651); //1000001100000000000000000000011 + capture_mode_change_f32((lower, upper), true, false); + + // Upper is affected + let lower = f32::from_bits(1072693248); //111111111100000000000000000000 + let upper = f32::from_bits(715827883); //101010101010101010101010101011 + capture_mode_change_f32((lower, upper), false, true); + + // Lower is affected + let lower = 1.0; // 0x3FF0000000000000 + let upper = 0.3; // 0x3FD3333333333333 + capture_mode_change_f64((lower, upper), true, false); + + // Upper is affected + let lower = 1.4999999999999998; // 0x3FF7FFFFFFFFFFFF + let upper = 0.000_000_000_000_000_022_044_604_925_031_31; // 0x3C796A6B413BB21F + capture_mode_change_f64((lower, upper), false, true); + } + + #[cfg(any( + not(any(target_arch = "x86_64", target_arch = "aarch64")), + target_os = "windows" + ))] + #[test] + fn test_next_impl_add_intervals_f64() { + let lower = 1.5; + let upper = 1.5; + capture_mode_change_f64((lower, upper), true, true); + + let lower = 1.5; + let upper = 1.5; + capture_mode_change_f32((lower, upper), true, true); + } +} diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index d35233bc39d2..077681d21725 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -26,10 +26,21 @@ //! The [expr_fn] module contains functions for creating expressions. mod accumulator; -pub mod aggregate_function; -pub mod array_expressions; mod built_in_function; +mod built_in_window_function; mod columnar_value; +mod literal; +mod nullif; +mod operator; +mod partition_evaluator; +mod signature; +mod table_source; +mod udaf; +mod udf; +mod udwf; + +pub mod aggregate_function; +pub mod array_expressions; pub mod conditional_expressions; pub mod expr; pub mod expr_fn; @@ -37,31 +48,22 @@ pub mod expr_rewriter; pub mod expr_schema; pub mod field_util; pub mod function; -mod literal; +pub mod interval_arithmetic; pub mod logical_plan; -mod nullif; -mod operator; -mod partition_evaluator; -mod signature; -pub mod struct_expressions; -mod table_source; pub mod tree_node; pub mod type_coercion; -mod udaf; -mod udf; -mod udwf; pub mod utils; pub mod window_frame; -pub mod window_function; pub mod window_state; pub use accumulator::Accumulator; pub use aggregate_function::AggregateFunction; pub use built_in_function::BuiltinScalarFunction; +pub use built_in_window_function::BuiltInWindowFunction; pub use columnar_value::ColumnarValue; pub use expr::{ Between, BinaryExpr, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, - Like, TryCast, + Like, ScalarFunctionDefinition, TryCast, WindowFunctionDefinition, }; pub use expr_fn::*; pub use expr_schema::ExprSchemable; @@ -74,13 +76,14 @@ pub use logical_plan::*; pub use nullif::SUPPORTED_NULLIF_TYPES; pub use operator::Operator; pub use partition_evaluator::PartitionEvaluator; -pub use signature::{Signature, TypeSignature, Volatility}; +pub use signature::{ + FuncMonotonicity, Signature, TypeSignature, Volatility, TIMEZONE_WILDCARD, +}; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; pub use udaf::AggregateUDF; -pub use udf::ScalarUDF; -pub use udwf::WindowUDF; +pub use udf::{ScalarUDF, ScalarUDFImpl}; +pub use udwf::{WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; -pub use window_function::{BuiltInWindowFunction, WindowFunction}; #[cfg(test)] #[ctor::ctor] diff --git a/datafusion/expr/src/literal.rs b/datafusion/expr/src/literal.rs index effc31553819..2f04729af2ed 100644 --- a/datafusion/expr/src/literal.rs +++ b/datafusion/expr/src/literal.rs @@ -43,19 +43,19 @@ pub trait TimestampLiteral { impl Literal for &str { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Utf8(Some((*self).to_owned()))) + Expr::Literal(ScalarValue::from(*self)) } } impl Literal for String { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Utf8(Some((*self).to_owned()))) + Expr::Literal(ScalarValue::from(self.as_ref())) } } impl Literal for &String { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Utf8(Some((*self).to_owned()))) + Expr::Literal(ScalarValue::from(self.as_ref())) } } diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 6171d43b37f5..847fbbbf61c7 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -17,6 +17,13 @@ //! This module provides a builder for creating LogicalPlans +use std::any::Any; +use std::cmp::Ordering; +use std::collections::{HashMap, HashSet}; +use std::convert::TryFrom; +use std::iter::zip; +use std::sync::Arc; + use crate::dml::{CopyOptions, CopyTo}; use crate::expr::Alias; use crate::expr_rewriter::{ @@ -24,37 +31,29 @@ use crate::expr_rewriter::{ normalize_col_with_schemas_and_ambiguity_check, normalize_cols, rewrite_sort_cols_by_aggs, }; +use crate::logical_plan::{ + Aggregate, Analyze, CrossJoin, Distinct, DistinctOn, EmptyRelation, Explain, Filter, + Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, + Projection, Repartition, Sort, SubqueryAlias, TableScan, Union, Unnest, Values, + Window, +}; use crate::type_coercion::binary::comparison_coercion; -use crate::utils::{columnize_expr, compare_sort_expr}; -use crate::{ - and, binary_expr, DmlStatement, Operator, TableProviderFilterPushDown, WriteOp, +use crate::utils::{ + can_hash, columnize_expr, compare_sort_expr, expand_qualified_wildcard, + expand_wildcard, find_valid_equijoin_key_pair, group_window_expr_by_sort_keys, }; use crate::{ - logical_plan::{ - Aggregate, Analyze, CrossJoin, Distinct, EmptyRelation, Explain, Filter, Join, - JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, - Projection, Repartition, Sort, SubqueryAlias, TableScan, Union, Unnest, Values, - Window, - }, - utils::{ - can_hash, expand_qualified_wildcard, expand_wildcard, - find_valid_equijoin_key_pair, group_window_expr_by_sort_keys, - }, - Expr, ExprSchemable, TableSource, + and, binary_expr, DmlStatement, Expr, ExprSchemable, Operator, + TableProviderFilterPushDown, TableSource, WriteOp, }; + use arrow::datatypes::{DataType, Schema, SchemaRef}; -use datafusion_common::plan_err; -use datafusion_common::UnnestOptions; +use datafusion_common::display::ToStringifiedPlan; use datafusion_common::{ - display::ToStringifiedPlan, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, - FileType, FunctionalDependencies, OwnedTableReference, Result, ScalarValue, - TableReference, ToDFSchema, + get_target_functional_dependencies, plan_datafusion_err, plan_err, Column, DFField, + DFSchema, DFSchemaRef, DataFusionError, FileType, OwnedTableReference, Result, + ScalarValue, TableReference, ToDFSchema, UnnestOptions, }; -use std::any::Any; -use std::cmp::Ordering; -use std::collections::{HashMap, HashSet}; -use std::convert::TryFrom; -use std::sync::Arc; /// Default table name for unnamed table pub const UNNAMED_TABLE: &str = "?table?"; @@ -166,7 +165,7 @@ impl LogicalPlanBuilder { let data_type = expr.get_type(&empty_schema)?; if let Some(prev_data_type) = &field_types[j] { if prev_data_type != &data_type { - return plan_err!("Inconsistent data type across values list at row {i} column {j}"); + return plan_err!("Inconsistent data type across values list at row {i} column {j}. Was {prev_data_type} but found {data_type}") } } Ok(Some(data_type)) @@ -282,53 +281,9 @@ impl LogicalPlanBuilder { projection: Option>, filters: Vec, ) -> Result { - let table_name = table_name.into(); - - if table_name.table().is_empty() { - return plan_err!("table_name cannot be empty"); - } - - let schema = table_source.schema(); - let func_dependencies = FunctionalDependencies::new_from_constraints( - table_source.constraints(), - schema.fields.len(), - ); - - let projected_schema = projection - .as_ref() - .map(|p| { - let projected_func_dependencies = - func_dependencies.project_functional_dependencies(p, p.len()); - DFSchema::new_with_metadata( - p.iter() - .map(|i| { - DFField::from_qualified( - table_name.clone(), - schema.field(*i).clone(), - ) - }) - .collect(), - schema.metadata().clone(), - ) - .map(|df_schema| { - df_schema.with_functional_dependencies(projected_func_dependencies) - }) - }) - .unwrap_or_else(|| { - DFSchema::try_from_qualified_schema(table_name.clone(), &schema).map( - |df_schema| df_schema.with_functional_dependencies(func_dependencies), - ) - })?; - - let table_scan = LogicalPlan::TableScan(TableScan { - table_name, - source: table_source, - projected_schema: Arc::new(projected_schema), - projection, - filters, - fetch: None, - }); - Ok(Self::from(table_scan)) + TableScan::try_new(table_name, table_source, projection, filters, None) + .map(LogicalPlan::TableScan) + .map(Self::from) } /// Wrap a plan in a window @@ -337,7 +292,7 @@ impl LogicalPlanBuilder { window_exprs: Vec, ) -> Result { let mut plan = input; - let mut groups = group_window_expr_by_sort_keys(&window_exprs)?; + let mut groups = group_window_expr_by_sort_keys(window_exprs)?; // To align with the behavior of PostgreSQL, we want the sort_keys sorted as same rule as PostgreSQL that first // we compare the sort key themselves and if one window's sort keys are a prefix of another // put the window with more sort keys first. so more deeply sorted plans gets nested further down as children. @@ -359,7 +314,7 @@ impl LogicalPlanBuilder { key_b.len().cmp(&key_a.len()) }); for (_, exprs) in groups { - let window_exprs = exprs.into_iter().cloned().collect::>(); + let window_exprs = exprs.into_iter().collect::>(); // Partition and sorting is done at physical level, see the EnforceDistribution // and EnforceSorting rules. plan = LogicalPlanBuilder::from(plan) @@ -373,7 +328,7 @@ impl LogicalPlanBuilder { self, expr: impl IntoIterator>, ) -> Result { - Ok(Self::from(project(self.plan, expr)?)) + project(self.plan, expr).map(Self::from) } /// Select the given column indices @@ -389,10 +344,9 @@ impl LogicalPlanBuilder { /// Apply a filter pub fn filter(self, expr: impl Into) -> Result { let expr = normalize_col(expr.into(), &self.plan)?; - Ok(Self::from(LogicalPlan::Filter(Filter::try_new( - expr, - Arc::new(self.plan), - )?))) + Filter::try_new(expr, Arc::new(self.plan)) + .map(LogicalPlan::Filter) + .map(Self::from) } /// Make a builder for a prepare logical plan from the builder's plan @@ -420,7 +374,7 @@ impl LogicalPlanBuilder { /// Apply an alias pub fn alias(self, alias: impl Into) -> Result { - Ok(Self::from(subquery_alias(self.plan, alias)?)) + subquery_alias(self.plan, alias).map(Self::from) } /// Add missing sort columns to all downstream projection @@ -475,7 +429,7 @@ impl LogicalPlanBuilder { Self::ambiguous_distinct_check(&missing_exprs, missing_cols, &expr)?; } expr.extend(missing_exprs); - Ok(project((*input).clone(), expr)?) + project((*input).clone(), expr) } _ => { let is_distinct = @@ -491,7 +445,7 @@ impl LogicalPlanBuilder { ) }) .collect::>>()?; - curr_plan.with_new_inputs(&new_inputs) + curr_plan.with_new_exprs(curr_plan.expressions(), &new_inputs) } } } @@ -582,15 +536,14 @@ impl LogicalPlanBuilder { fetch: None, }); - Ok(Self::from(LogicalPlan::Projection(Projection::try_new( - new_expr, - Arc::new(sort_plan), - )?))) + Projection::try_new(new_expr, Arc::new(sort_plan)) + .map(LogicalPlan::Projection) + .map(Self::from) } /// Apply a union, preserving duplicate rows pub fn union(self, plan: LogicalPlan) -> Result { - Ok(Self::from(union(self.plan, plan)?)) + union(self.plan, plan).map(Self::from) } /// Apply a union, removing duplicate rows @@ -598,23 +551,44 @@ impl LogicalPlanBuilder { let left_plan: LogicalPlan = self.plan; let right_plan: LogicalPlan = plan; - Ok(Self::from(LogicalPlan::Distinct(Distinct { - input: Arc::new(union(left_plan, right_plan)?), - }))) + Ok(Self::from(LogicalPlan::Distinct(Distinct::All(Arc::new( + union(left_plan, right_plan)?, + ))))) } /// Apply deduplication: Only distinct (different) values are returned) pub fn distinct(self) -> Result { - Ok(Self::from(LogicalPlan::Distinct(Distinct { - input: Arc::new(self.plan), - }))) + Ok(Self::from(LogicalPlan::Distinct(Distinct::All(Arc::new( + self.plan, + ))))) + } + + /// Project first values of the specified expression list according to the provided + /// sorting expressions grouped by the `DISTINCT ON` clause expressions. + pub fn distinct_on( + self, + on_expr: Vec, + select_expr: Vec, + sort_expr: Option>, + ) -> Result { + Ok(Self::from(LogicalPlan::Distinct(Distinct::On( + DistinctOn::try_new(on_expr, select_expr, sort_expr, Arc::new(self.plan))?, + )))) } - /// Apply a join with on constraint. + /// Apply a join to `right` using explicitly specified columns and an + /// optional filter expression. /// - /// Filter expression expected to contain non-equality predicates that can not be pushed - /// down to any of join inputs. - /// In case of outer join, filter applied to only matched rows. + /// See [`join_on`](Self::join_on) for a more concise way to specify the + /// join condition. Since DataFusion will automatically identify and + /// optimize equality predicates there is no performance difference between + /// this function and `join_on` + /// + /// `left_cols` and `right_cols` are used to form "equijoin" predicates (see + /// example below), which are then combined with the optional `filter` + /// expression. + /// + /// Note that in case of outer join, the `filter` is applied to only matched rows. pub fn join( self, right: LogicalPlan, @@ -625,6 +599,63 @@ impl LogicalPlanBuilder { self.join_detailed(right, join_type, join_keys, filter, false) } + /// Apply a join with using the specified expressions. + /// + /// Note that DataFusion automatically optimizes joins, including + /// identifying and optimizing equality predicates. + /// + /// # Example + /// + /// ``` + /// # use datafusion_expr::{Expr, col, LogicalPlanBuilder, + /// # logical_plan::builder::LogicalTableSource, logical_plan::JoinType,}; + /// # use std::sync::Arc; + /// # use arrow::datatypes::{Schema, DataType, Field}; + /// # use datafusion_common::Result; + /// # fn main() -> Result<()> { + /// let example_schema = Arc::new(Schema::new(vec![ + /// Field::new("a", DataType::Int32, false), + /// Field::new("b", DataType::Int32, false), + /// Field::new("c", DataType::Int32, false), + /// ])); + /// let table_source = Arc::new(LogicalTableSource::new(example_schema)); + /// let left_table = table_source.clone(); + /// let right_table = table_source.clone(); + /// + /// let right_plan = LogicalPlanBuilder::scan("right", right_table, None)?.build()?; + /// + /// // Form the expression `(left.a != right.a)` AND `(left.b != right.b)` + /// let exprs = vec![ + /// col("left.a").eq(col("right.a")), + /// col("left.b").not_eq(col("right.b")) + /// ]; + /// + /// // Perform the equivalent of `left INNER JOIN right ON (a != a2 AND b != b2)` + /// // finding all pairs of rows from `left` and `right` where + /// // where `a = a2` and `b != b2`. + /// let plan = LogicalPlanBuilder::scan("left", left_table, None)? + /// .join_on(right_plan, JoinType::Inner, exprs)? + /// .build()?; + /// # Ok(()) + /// # } + /// ``` + pub fn join_on( + self, + right: LogicalPlan, + join_type: JoinType, + on_exprs: impl IntoIterator, + ) -> Result { + let filter = on_exprs.into_iter().reduce(Expr::and); + + self.join_detailed( + right, + join_type, + (Vec::::new(), Vec::::new()), + filter, + false, + ) + } + pub(crate) fn normalize( plan: &LogicalPlan, column: impl Into + Clone, @@ -638,8 +669,14 @@ impl LogicalPlanBuilder { ) } - /// Apply a join with on constraint and specified null equality - /// If null_equals_null is true then null == null, else null != null + /// Apply a join with on constraint and specified null equality. + /// + /// The behavior is the same as [`join`](Self::join) except that it allows + /// specifying the null equality behavior. + /// + /// If `null_equals_null=true`, rows where both join keys are `null` will be + /// emitted. Otherwise rows where either or both join keys are `null` will be + /// omitted. pub fn join_detailed( self, right: LogicalPlan, @@ -869,11 +906,12 @@ impl LogicalPlanBuilder { ) -> Result { let group_expr = normalize_cols(group_expr, &self.plan)?; let aggr_expr = normalize_cols(aggr_expr, &self.plan)?; - Ok(Self::from(LogicalPlan::Aggregate(Aggregate::try_new( - Arc::new(self.plan), - group_expr, - aggr_expr, - )?))) + + let group_expr = + add_group_by_exprs_from_dependencies(group_expr, self.plan.schema())?; + Aggregate::try_new(Arc::new(self.plan), group_expr, aggr_expr) + .map(LogicalPlan::Aggregate) + .map(Self::from) } /// Create an expression to represent the explanation of the plan @@ -1025,9 +1063,9 @@ impl LogicalPlanBuilder { self.plan.schema().clone(), right.schema().clone(), )?.ok_or_else(|| - DataFusionError::Plan(format!( + plan_datafusion_err!( "can't create join plan, join key should belong to one input, error key: ({normalized_left_key},{normalized_right_key})" - ))) + )) }) .collect::>>()?; @@ -1131,10 +1169,46 @@ pub fn build_join_schema( ); let mut metadata = left.metadata().clone(); metadata.extend(right.metadata().clone()); - Ok(DFSchema::new_with_metadata(fields, metadata)? - .with_functional_dependencies(func_dependencies)) + let schema = DFSchema::new_with_metadata(fields, metadata)?; + schema.with_functional_dependencies(func_dependencies) } +/// Add additional "synthetic" group by expressions based on functional +/// dependencies. +/// +/// For example, if we are grouping on `[c1]`, and we know from +/// functional dependencies that column `c1` determines `c2`, this function +/// adds `c2` to the group by list. +/// +/// This allows MySQL style selects like +/// `SELECT col FROM t WHERE pk = 5` if col is unique +fn add_group_by_exprs_from_dependencies( + mut group_expr: Vec, + schema: &DFSchemaRef, +) -> Result> { + // Names of the fields produced by the GROUP BY exprs for example, `GROUP BY + // c1 + 1` produces an output field named `"c1 + 1"` + let mut group_by_field_names = group_expr + .iter() + .map(|e| e.display_name()) + .collect::>>()?; + + if let Some(target_indices) = + get_target_functional_dependencies(schema, &group_by_field_names) + { + for idx in target_indices { + let field = schema.field(idx); + let expr = + Expr::Column(Column::new(field.qualifier().cloned(), field.name())); + let expr_name = expr.display_name()?; + if !group_by_field_names.contains(&expr_name) { + group_by_field_names.push(expr_name); + group_expr.push(expr); + } + } + } + Ok(group_expr) +} /// Errors if one or more expressions have equal names. pub(crate) fn validate_unique_names<'a>( node_name: &str, @@ -1179,9 +1253,8 @@ pub fn project_with_column_index( }) .collect::>(); - Ok(LogicalPlan::Projection(Projection::try_new_with_schema( - alias_expr, input, schema, - )?)) + Projection::try_new_with_schema(alias_expr, input, schema) + .map(LogicalPlan::Projection) } /// Union two logical plans. @@ -1196,39 +1269,36 @@ pub fn union(left_plan: LogicalPlan, right_plan: LogicalPlan) -> Result>>()? - .to_dfschema()?; + left_field.data_type() + ) + })?; + + Ok(DFField::new( + left_field.qualifier().cloned(), + left_field.name(), + data_type, + nullable, + )) + }) + .collect::>>()? + .to_dfschema()?; let inputs = vec![left_plan, right_plan] .into_iter() - .flat_map(|p| match p { - LogicalPlan::Union(Union { inputs, .. }) => inputs, - other_plan => vec![Arc::new(other_plan)], - }) .map(|p| { let plan = coerce_plan_expr_for_schema(&p, &union_schema)?; match plan { @@ -1269,21 +1339,23 @@ pub fn project( for e in expr { let e = e.into(); match e { - Expr::Wildcard => { + Expr::Wildcard { qualifier: None } => { projected_expr.extend(expand_wildcard(input_schema, &plan, None)?) } - Expr::QualifiedWildcard { ref qualifier } => projected_expr - .extend(expand_qualified_wildcard(qualifier, input_schema, None)?), + Expr::Wildcard { + qualifier: Some(qualifier), + } => projected_expr.extend(expand_qualified_wildcard( + &qualifier, + input_schema, + None, + )?), _ => projected_expr .push(columnize_expr(normalize_col(e, &plan)?, input_schema)), } } validate_unique_names("Projections", projected_expr.iter())?; - Ok(LogicalPlan::Projection(Projection::try_new( - projected_expr, - Arc::new(plan.clone()), - )?)) + Projection::try_new(projected_expr, Arc::new(plan)).map(LogicalPlan::Projection) } /// Create a SubqueryAlias to wrap a LogicalPlan. @@ -1291,9 +1363,7 @@ pub fn subquery_alias( plan: LogicalPlan, alias: impl Into, ) -> Result { - Ok(LogicalPlan::SubqueryAlias(SubqueryAlias::try_new( - plan, alias, - )?)) + SubqueryAlias::try_new(Arc::new(plan), alias).map(LogicalPlan::SubqueryAlias) } /// Create a LogicalPlanBuilder representing a scan of a table with the provided name and schema. @@ -1456,11 +1526,11 @@ pub fn unnest_with_options( }) .collect::>(); - let schema = Arc::new( - DFSchema::new_with_metadata(fields, input_schema.metadata().clone())? - // We can use the existing functional dependencies: - .with_functional_dependencies(input_schema.functional_dependencies().clone()), - ); + let metadata = input_schema.metadata().clone(); + let df_schema = DFSchema::new_with_metadata(fields, metadata)?; + // We can use the existing functional dependencies: + let deps = input_schema.functional_dependencies().clone(); + let schema = Arc::new(df_schema.with_functional_dependencies(deps)?); Ok(LogicalPlan::Unnest(Unnest { input: Arc::new(input), @@ -1472,16 +1542,12 @@ pub fn unnest_with_options( #[cfg(test)] mod tests { - use crate::logical_plan::StringifiedPlan; - use crate::{col, in_subquery, lit, scalar_subquery, sum}; - use crate::{expr, expr_fn::exists}; - use super::*; + use crate::logical_plan::StringifiedPlan; + use crate::{col, expr, expr_fn::exists, in_subquery, lit, scalar_subquery, sum}; use arrow::datatypes::{DataType, Field}; - use datafusion_common::{ - FunctionalDependence, OwnedTableReference, SchemaError, TableReference, - }; + use datafusion_common::{OwnedTableReference, SchemaError, TableReference}; #[test] fn plan_builder_simple() -> Result<()> { @@ -1581,7 +1647,7 @@ mod tests { let plan = table_scan(Some("t1"), &employee_schema(), None)? .join_using(t2, JoinType::Inner, vec!["id"])? - .project(vec![Expr::Wildcard])? + .project(vec![Expr::Wildcard { qualifier: None }])? .build()?; // id column should only show up once in projection @@ -1596,7 +1662,7 @@ mod tests { } #[test] - fn plan_builder_union_combined_single_union() -> Result<()> { + fn plan_builder_union() -> Result<()> { let plan = table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3, 4]))?; @@ -1607,11 +1673,12 @@ mod tests { .union(plan.build()?)? .build()?; - // output has only one union let expected = "Union\ - \n TableScan: employee_csv projection=[state, salary]\ - \n TableScan: employee_csv projection=[state, salary]\ - \n TableScan: employee_csv projection=[state, salary]\ + \n Union\ + \n Union\ + \n TableScan: employee_csv projection=[state, salary]\ + \n TableScan: employee_csv projection=[state, salary]\ + \n TableScan: employee_csv projection=[state, salary]\ \n TableScan: employee_csv projection=[state, salary]"; assert_eq!(expected, format!("{plan:?}")); @@ -1620,7 +1687,7 @@ mod tests { } #[test] - fn plan_builder_union_distinct_combined_single_union() -> Result<()> { + fn plan_builder_union_distinct() -> Result<()> { let plan = table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3, 4]))?; @@ -1778,13 +1845,16 @@ mod tests { .project(vec![col("id"), col("first_name").alias("id")]); match plan { - Err(DataFusionError::SchemaError(SchemaError::AmbiguousReference { - field: - Column { - relation: Some(OwnedTableReference::Bare { table }), - name, - }, - })) => { + Err(DataFusionError::SchemaError( + SchemaError::AmbiguousReference { + field: + Column { + relation: Some(OwnedTableReference::Bare { table }), + name, + }, + }, + _, + )) => { assert_eq!("employee_csv", table); assert_eq!("id", &name); Ok(()) @@ -1805,13 +1875,16 @@ mod tests { .aggregate(vec![col("state")], vec![sum(col("salary")).alias("state")]); match plan { - Err(DataFusionError::SchemaError(SchemaError::AmbiguousReference { - field: - Column { - relation: Some(OwnedTableReference::Bare { table }), - name, - }, - })) => { + Err(DataFusionError::SchemaError( + SchemaError::AmbiguousReference { + field: + Column { + relation: Some(OwnedTableReference::Bare { table }), + name, + }, + }, + _, + )) => { assert_eq!("employee_csv", table); assert_eq!("state", &name); Ok(()) @@ -1981,21 +2054,4 @@ mod tests { Ok(()) } - - #[test] - fn test_get_updated_id_keys() { - let fund_dependencies = - FunctionalDependencies::new(vec![FunctionalDependence::new( - vec![1], - vec![0, 1, 2], - true, - )]); - let res = fund_dependencies.project_functional_dependencies(&[1, 2], 2); - let expected = FunctionalDependencies::new(vec![FunctionalDependence::new( - vec![0], - vec![0, 1], - true, - )]); - assert_eq!(res, expected); - } } diff --git a/datafusion/expr/src/logical_plan/ddl.rs b/datafusion/expr/src/logical_plan/ddl.rs index dc247da3642c..e74992d99373 100644 --- a/datafusion/expr/src/logical_plan/ddl.rs +++ b/datafusion/expr/src/logical_plan/ddl.rs @@ -112,9 +112,10 @@ impl DdlStatement { match self.0 { DdlStatement::CreateExternalTable(CreateExternalTable { ref name, + constraints, .. }) => { - write!(f, "CreateExternalTable: {name:?}") + write!(f, "CreateExternalTable: {name:?}{constraints}") } DdlStatement::CreateMemoryTable(CreateMemoryTable { name, @@ -191,6 +192,10 @@ pub struct CreateExternalTable { pub unbounded: bool, /// Table(provider) specific options pub options: HashMap, + /// The list of constraints in the schema, such as primary key, unique, etc. + pub constraints: Constraints, + /// Default values for columns + pub column_defaults: HashMap, } // Hashing refers to a subset of fields considered in PartialEq. @@ -225,6 +230,8 @@ pub struct CreateMemoryTable { pub if_not_exists: bool, /// Option to replace table content if table already exists pub or_replace: bool, + /// Default values for columns + pub column_defaults: Vec<(String, Expr)>, } /// Creates a view. diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index 8316417138bd..bc722dd69ace 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -33,10 +33,11 @@ pub use ddl::{ }; pub use dml::{DmlStatement, WriteOp}; pub use plan::{ - Aggregate, Analyze, CrossJoin, DescribeTable, Distinct, EmptyRelation, Explain, - Extension, Filter, Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, - PlanType, Prepare, Projection, Repartition, Sort, StringifiedPlan, Subquery, - SubqueryAlias, TableScan, ToStringifiedPlan, Union, Unnest, Values, Window, + projection_schema, Aggregate, Analyze, CrossJoin, DescribeTable, Distinct, + DistinctOn, EmptyRelation, Explain, Extension, Filter, Join, JoinConstraint, + JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, Projection, + Repartition, Sort, StringifiedPlan, Subquery, SubqueryAlias, TableScan, + ToStringifiedPlan, Union, Unnest, Values, Window, }; pub use statement::{ SetVariable, Statement, TransactionAccessMode, TransactionConclusion, TransactionEnd, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index dfc83f9eec76..93a38fb40df5 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -17,25 +17,31 @@ //! Logical plan types +use std::collections::{HashMap, HashSet}; +use std::fmt::{self, Debug, Display, Formatter}; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +use super::dml::CopyTo; +use super::DdlStatement; use crate::dml::CopyOptions; -use crate::expr::{Alias, Exists, InSubquery, Placeholder}; -use crate::expr_rewriter::create_col_from_scalar_expr; +use crate::expr::{ + Alias, Exists, InSubquery, Placeholder, Sort as SortExpr, WindowFunction, +}; +use crate::expr_rewriter::{create_col_from_scalar_expr, normalize_cols}; use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor}; use crate::logical_plan::extension::UserDefinedLogicalNode; use crate::logical_plan::{DmlStatement, Statement}; use crate::utils::{ enumerate_grouping_sets, exprlist_to_fields, find_out_reference_exprs, grouping_set_expr_count, grouping_set_to_exprlist, inspect_expr_pre, + split_conjunction, }; use crate::{ - build_join_schema, Expr, ExprSchemable, TableProviderFilterPushDown, TableSource, + build_join_schema, expr_vec_fmt, BinaryExpr, BuiltInWindowFunction, + CreateMemoryTable, CreateView, Expr, ExprSchemable, LogicalPlanBuilder, Operator, + TableProviderFilterPushDown, TableSource, WindowFunctionDefinition, }; -use crate::{ - expr_vec_fmt, BinaryExpr, CreateMemoryTable, CreateView, LogicalPlanBuilder, Operator, -}; - -use super::dml::CopyTo; -use super::DdlStatement; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::tree_node::{ @@ -44,18 +50,14 @@ use datafusion_common::tree_node::{ }; use datafusion_common::{ aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints, - DFField, DFSchema, DFSchemaRef, DataFusionError, FunctionalDependencies, - OwnedTableReference, Result, ScalarValue, UnnestOptions, + DFField, DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependence, + FunctionalDependencies, OwnedTableReference, ParamValues, Result, UnnestOptions, }; + // backwards compatibility pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; pub use datafusion_common::{JoinConstraint, JoinType}; -use std::collections::{HashMap, HashSet}; -use std::fmt::{self, Debug, Display, Formatter}; -use std::hash::{Hash, Hasher}; -use std::sync::Arc; - /// A LogicalPlan represents the different types of relational /// operators (such as Projection, Filter, etc) and can be created by /// the SQL query planner and the DataFrame API. @@ -142,7 +144,7 @@ pub enum LogicalPlan { Prepare(Prepare), /// Data Manipulaton Language (DML): Insert / Update / Delete Dml(DmlStatement), - /// Data Definition Language (DDL): CREATE / DROP TABLES / VIEWS / SCHEMAs + /// Data Definition Language (DDL): CREATE / DROP TABLES / VIEWS / SCHEMAS Ddl(DdlStatement), /// `COPY TO` for writing plan results to files Copy(CopyTo), @@ -165,7 +167,8 @@ impl LogicalPlan { }) => projected_schema, LogicalPlan::Projection(Projection { schema, .. }) => schema, LogicalPlan::Filter(Filter { input, .. }) => input.schema(), - LogicalPlan::Distinct(Distinct { input }) => input.schema(), + LogicalPlan::Distinct(Distinct::All(input)) => input.schema(), + LogicalPlan::Distinct(Distinct::On(DistinctOn { schema, .. })) => schema, LogicalPlan::Window(Window { schema, .. }) => schema, LogicalPlan::Aggregate(Aggregate { schema, .. }) => schema, LogicalPlan::Sort(Sort { input, .. }) => input.schema(), @@ -369,6 +372,16 @@ impl LogicalPlan { LogicalPlan::Unnest(Unnest { column, .. }) => { f(&Expr::Column(column.clone())) } + LogicalPlan::Distinct(Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + .. + })) => on_expr + .iter() + .chain(select_expr.iter()) + .chain(sort_expr.clone().unwrap_or(vec![]).iter()) + .try_for_each(f), // plans without expressions LogicalPlan::EmptyRelation(_) | LogicalPlan::Subquery(_) @@ -379,7 +392,7 @@ impl LogicalPlan { | LogicalPlan::Analyze(_) | LogicalPlan::Explain(_) | LogicalPlan::Union(_) - | LogicalPlan::Distinct(_) + | LogicalPlan::Distinct(Distinct::All(_)) | LogicalPlan::Dml(_) | LogicalPlan::Ddl(_) | LogicalPlan::Copy(_) @@ -407,7 +420,9 @@ impl LogicalPlan { LogicalPlan::Union(Union { inputs, .. }) => { inputs.iter().map(|arc| arc.as_ref()).collect() } - LogicalPlan::Distinct(Distinct { input }) => vec![input], + LogicalPlan::Distinct( + Distinct::All(input) | Distinct::On(DistinctOn { input, .. }), + ) => vec![input], LogicalPlan::Explain(explain) => vec![&explain.plan], LogicalPlan::Analyze(analyze) => vec![&analyze.input], LogicalPlan::Dml(write) => vec![&write.input], @@ -463,8 +478,11 @@ impl LogicalPlan { Ok(Some(agg.group_expr.as_slice()[0].clone())) } } + LogicalPlan::Distinct(Distinct::On(DistinctOn { select_expr, .. })) => { + Ok(Some(select_expr[0].clone())) + } LogicalPlan::Filter(Filter { input, .. }) - | LogicalPlan::Distinct(Distinct { input, .. }) + | LogicalPlan::Distinct(Distinct::All(input)) | LogicalPlan::Sort(Sort { input, .. }) | LogicalPlan::Limit(Limit { input, .. }) | LogicalPlan::Repartition(Repartition { input, .. }) @@ -526,39 +544,9 @@ impl LogicalPlan { } /// Returns a copy of this `LogicalPlan` with the new inputs + #[deprecated(since = "35.0.0", note = "please use `with_new_exprs` instead")] pub fn with_new_inputs(&self, inputs: &[LogicalPlan]) -> Result { - // with_new_inputs use original expression, - // so we don't need to recompute Schema. - match &self { - LogicalPlan::Projection(projection) => { - Ok(LogicalPlan::Projection(Projection::try_new_with_schema( - projection.expr.to_vec(), - Arc::new(inputs[0].clone()), - projection.schema.clone(), - )?)) - } - LogicalPlan::Window(Window { - window_expr, - schema, - .. - }) => Ok(LogicalPlan::Window(Window { - input: Arc::new(inputs[0].clone()), - window_expr: window_expr.to_vec(), - schema: schema.clone(), - })), - LogicalPlan::Aggregate(Aggregate { - group_expr, - aggr_expr, - schema, - .. - }) => Ok(LogicalPlan::Aggregate(Aggregate::try_new_with_schema( - Arc::new(inputs[0].clone()), - group_expr.to_vec(), - aggr_expr.to_vec(), - schema.clone(), - )?)), - _ => self.with_new_exprs(self.expressions(), inputs), - } + self.with_new_exprs(self.expressions(), inputs) } /// Returns a new `LogicalPlan` based on `self` with inputs and @@ -580,22 +568,17 @@ impl LogicalPlan { /// // create new plan using rewritten_exprs in same position /// let new_plan = plan.new_with_exprs(rewritten_exprs, new_inputs); /// ``` - /// - /// Note: sometimes [`Self::with_new_exprs`] will use schema of - /// original plan, it will not change the scheam. Such as - /// `Projection/Aggregate/Window` pub fn with_new_exprs( &self, mut expr: Vec, inputs: &[LogicalPlan], ) -> Result { match self { - LogicalPlan::Projection(Projection { schema, .. }) => { - Ok(LogicalPlan::Projection(Projection::try_new_with_schema( - expr, - Arc::new(inputs[0].clone()), - schema.clone(), - )?)) + // Since expr may be different than the previous expr, schema of the projection + // may change. We need to use try_new method instead of try_new_with_schema method. + LogicalPlan::Projection(Projection { .. }) => { + Projection::try_new(expr, Arc::new(inputs[0].clone())) + .map(LogicalPlan::Projection) } LogicalPlan::Dml(DmlStatement { table_name, @@ -672,10 +655,8 @@ impl LogicalPlan { let mut remove_aliases = RemoveAliases {}; let predicate = predicate.rewrite(&mut remove_aliases)?; - Ok(LogicalPlan::Filter(Filter::try_new( - predicate, - Arc::new(inputs[0].clone()), - )?)) + Filter::try_new(predicate, Arc::new(inputs[0].clone())) + .map(LogicalPlan::Filter) } LogicalPlan::Repartition(Repartition { partitioning_scheme, @@ -698,30 +679,17 @@ impl LogicalPlan { })) } }, - LogicalPlan::Window(Window { - window_expr, - schema, - .. - }) => { + LogicalPlan::Window(Window { window_expr, .. }) => { assert_eq!(window_expr.len(), expr.len()); - Ok(LogicalPlan::Window(Window { - input: Arc::new(inputs[0].clone()), - window_expr: expr, - schema: schema.clone(), - })) + Window::try_new(expr, Arc::new(inputs[0].clone())) + .map(LogicalPlan::Window) } - LogicalPlan::Aggregate(Aggregate { - group_expr, schema, .. - }) => { + LogicalPlan::Aggregate(Aggregate { group_expr, .. }) => { // group exprs are the first expressions let agg_expr = expr.split_off(group_expr.len()); - Ok(LogicalPlan::Aggregate(Aggregate::try_new_with_schema( - Arc::new(inputs[0].clone()), - expr, - agg_expr, - schema.clone(), - )?)) + Aggregate::try_new(Arc::new(inputs[0].clone()), expr, agg_expr) + .map(LogicalPlan::Aggregate) } LogicalPlan::Sort(Sort { fetch, .. }) => Ok(LogicalPlan::Sort(Sort { expr, @@ -790,10 +758,8 @@ impl LogicalPlan { })) } LogicalPlan::SubqueryAlias(SubqueryAlias { alias, .. }) => { - Ok(LogicalPlan::SubqueryAlias(SubqueryAlias::try_new( - inputs[0].clone(), - alias.clone(), - )?)) + SubqueryAlias::try_new(Arc::new(inputs[0].clone()), alias.clone()) + .map(LogicalPlan::SubqueryAlias) } LogicalPlan::Limit(Limit { skip, fetch, .. }) => { Ok(LogicalPlan::Limit(Limit { @@ -806,6 +772,7 @@ impl LogicalPlan { name, if_not_exists, or_replace, + column_defaults, .. })) => Ok(LogicalPlan::Ddl(DdlStatement::CreateMemoryTable( CreateMemoryTable { @@ -814,6 +781,7 @@ impl LogicalPlan { name: name.clone(), if_not_exists: *if_not_exists, or_replace: *or_replace, + column_defaults: column_defaults.clone(), }, ))), LogicalPlan::Ddl(DdlStatement::CreateView(CreateView { @@ -830,15 +798,43 @@ impl LogicalPlan { LogicalPlan::Extension(e) => Ok(LogicalPlan::Extension(Extension { node: e.node.from_template(&expr, inputs), })), - LogicalPlan::Union(Union { schema, .. }) => Ok(LogicalPlan::Union(Union { - inputs: inputs.iter().cloned().map(Arc::new).collect(), - schema: schema.clone(), - })), - LogicalPlan::Distinct(Distinct { .. }) => { - Ok(LogicalPlan::Distinct(Distinct { - input: Arc::new(inputs[0].clone()), + LogicalPlan::Union(Union { schema, .. }) => { + let input_schema = inputs[0].schema(); + // If inputs are not pruned do not change schema. + let schema = if schema.fields().len() == input_schema.fields().len() { + schema + } else { + input_schema + }; + Ok(LogicalPlan::Union(Union { + inputs: inputs.iter().cloned().map(Arc::new).collect(), + schema: schema.clone(), })) } + LogicalPlan::Distinct(distinct) => { + let distinct = match distinct { + Distinct::All(_) => Distinct::All(Arc::new(inputs[0].clone())), + Distinct::On(DistinctOn { + on_expr, + select_expr, + .. + }) => { + let sort_expr = expr.split_off(on_expr.len() + select_expr.len()); + let select_expr = expr.split_off(on_expr.len()); + Distinct::On(DistinctOn::try_new( + expr, + select_expr, + if !sort_expr.is_empty() { + Some(sort_expr) + } else { + None + }, + Arc::new(inputs[0].clone()), + )?) + } + }; + Ok(LogicalPlan::Distinct(distinct)) + } LogicalPlan::Analyze(a) => { assert!(expr.is_empty()); assert_eq!(inputs.len(), 1); @@ -848,19 +844,19 @@ impl LogicalPlan { input: Arc::new(inputs[0].clone()), })) } - LogicalPlan::Explain(_) => { - // Explain should be handled specially in the optimizers; - // If this check cannot pass it means some optimizer pass is - // trying to optimize Explain directly - if expr.is_empty() { - return plan_err!("Invalid EXPLAIN command. Expression is empty"); - } - - if inputs.is_empty() { - return plan_err!("Invalid EXPLAIN command. Inputs are empty"); - } - - Ok(self.clone()) + LogicalPlan::Explain(e) => { + assert!( + expr.is_empty(), + "Invalid EXPLAIN command. Expression should empty" + ); + assert_eq!(inputs.len(), 1, "Invalid EXPLAIN command. Inputs are empty"); + Ok(LogicalPlan::Explain(Explain { + verbose: e.verbose, + plan: Arc::new(inputs[0].clone()), + stringified_plans: e.stringified_plans.clone(), + schema: e.schema.clone(), + logical_optimization_succeeded: e.logical_optimization_succeeded, + })) } LogicalPlan::Prepare(Prepare { name, data_types, .. @@ -916,7 +912,7 @@ impl LogicalPlan { // We can use the existing functional dependencies as is: .with_functional_dependencies( input.schema().functional_dependencies().clone(), - ), + )?, ); Ok(LogicalPlan::Unnest(Unnest { @@ -928,40 +924,71 @@ impl LogicalPlan { } } } - /// Convert a prepared [`LogicalPlan`] into its inner logical plan - /// with all params replaced with their corresponding values + /// Replaces placeholder param values (like `$1`, `$2`) in [`LogicalPlan`] + /// with the specified `param_values`. + /// + /// [`LogicalPlan::Prepare`] are + /// converted to their inner logical plan for execution. + /// + /// # Example + /// ``` + /// # use arrow::datatypes::{Field, Schema, DataType}; + /// use datafusion_common::ScalarValue; + /// # use datafusion_expr::{lit, col, LogicalPlanBuilder, logical_plan::table_scan, placeholder}; + /// # let schema = Schema::new(vec![ + /// # Field::new("id", DataType::Int32, false), + /// # ]); + /// // Build SELECT * FROM t1 WHRERE id = $1 + /// let plan = table_scan(Some("t1"), &schema, None).unwrap() + /// .filter(col("id").eq(placeholder("$1"))).unwrap() + /// .build().unwrap(); + /// + /// assert_eq!( + /// "Filter: t1.id = $1\ + /// \n TableScan: t1", + /// plan.display_indent().to_string() + /// ); + /// + /// // Fill in the parameter $1 with a literal 3 + /// let plan = plan.with_param_values(vec![ + /// ScalarValue::from(3i32) // value at index 0 --> $1 + /// ]).unwrap(); + /// + /// assert_eq!( + /// "Filter: t1.id = Int32(3)\ + /// \n TableScan: t1", + /// plan.display_indent().to_string() + /// ); + /// + /// // Note you can also used named parameters + /// // Build SELECT * FROM t1 WHRERE id = $my_param + /// let plan = table_scan(Some("t1"), &schema, None).unwrap() + /// .filter(col("id").eq(placeholder("$my_param"))).unwrap() + /// .build().unwrap() + /// // Fill in the parameter $my_param with a literal 3 + /// .with_param_values(vec![ + /// ("my_param", ScalarValue::from(3i32)), + /// ]).unwrap(); + /// + /// assert_eq!( + /// "Filter: t1.id = Int32(3)\ + /// \n TableScan: t1", + /// plan.display_indent().to_string() + /// ); + /// + /// ``` pub fn with_param_values( self, - param_values: Vec, + param_values: impl Into, ) -> Result { + let param_values = param_values.into(); match self { LogicalPlan::Prepare(prepare_lp) => { - // Verify if the number of params matches the number of values - if prepare_lp.data_types.len() != param_values.len() { - return plan_err!( - "Expected {} parameters, got {}", - prepare_lp.data_types.len(), - param_values.len() - ); - } - - // Verify if the types of the params matches the types of the values - let iter = prepare_lp.data_types.iter().zip(param_values.iter()); - for (i, (param_type, value)) in iter.enumerate() { - if *param_type != value.data_type() { - return plan_err!( - "Expected parameter of type {:?}, got {:?} at index {}", - param_type, - value.data_type(), - i - ); - } - } - + param_values.verify(&prepare_lp.data_types)?; let input_plan = prepare_lp.input; input_plan.replace_params_with_values(¶m_values) } - _ => Ok(self), + _ => self.replace_params_with_values(¶m_values), } } @@ -972,7 +999,13 @@ impl LogicalPlan { pub fn max_rows(self: &LogicalPlan) -> Option { match self { LogicalPlan::Projection(Projection { input, .. }) => input.max_rows(), - LogicalPlan::Filter(Filter { input, .. }) => input.max_rows(), + LogicalPlan::Filter(filter) => { + if filter.is_scalar() { + Some(1) + } else { + filter.input.max_rows() + } + } LogicalPlan::Window(Window { input, .. }) => input.max_rows(), LogicalPlan::Aggregate(Aggregate { input, group_expr, .. @@ -1043,7 +1076,9 @@ impl LogicalPlan { LogicalPlan::Subquery(_) => None, LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => input.max_rows(), LogicalPlan::Limit(Limit { fetch, .. }) => *fetch, - LogicalPlan::Distinct(Distinct { input }) => input.max_rows(), + LogicalPlan::Distinct( + Distinct::All(input) | Distinct::On(DistinctOn { input, .. }), + ) => input.max_rows(), LogicalPlan::Values(v) => Some(v.values.len()), LogicalPlan::Unnest(_) => None, LogicalPlan::Ddl(_) @@ -1060,7 +1095,7 @@ impl LogicalPlan { } impl LogicalPlan { - /// applies collect to any subqueries in the plan + /// applies `op` to any subqueries in the plan pub(crate) fn apply_subqueries(&self, op: &mut F) -> datafusion_common::Result<()> where F: FnMut(&Self) -> datafusion_common::Result, @@ -1112,17 +1147,22 @@ impl LogicalPlan { Ok(()) } - /// Return a logical plan with all placeholders/params (e.g $1 $2, - /// ...) replaced with corresponding values provided in the - /// params_values + /// Return a `LogicalPlan` with all placeholders (e.g $1 $2, + /// ...) replaced with corresponding values provided in + /// `params_values` + /// + /// See [`Self::with_param_values`] for examples and usage pub fn replace_params_with_values( &self, - param_values: &[ScalarValue], + param_values: &ParamValues, ) -> Result { let new_exprs = self .expressions() .into_iter() - .map(|e| Self::replace_placeholders_with_values(e, param_values)) + .map(|e| { + let e = e.infer_placeholder_types(self.schema())?; + Self::replace_placeholders_with_values(e, param_values) + }) .collect::>>()?; let new_inputs_with_values = self @@ -1134,7 +1174,7 @@ impl LogicalPlan { self.with_new_exprs(new_exprs, &new_inputs_with_values) } - /// Walk the logical plan, find any `PlaceHolder` tokens, and return a map of their IDs and DataTypes + /// Walk the logical plan, find any `Placeholder` tokens, and return a map of their IDs and DataTypes pub fn get_parameter_types( &self, ) -> Result>, DataFusionError> { @@ -1171,36 +1211,15 @@ impl LogicalPlan { /// corresponding values provided in the params_values fn replace_placeholders_with_values( expr: Expr, - param_values: &[ScalarValue], + param_values: &ParamValues, ) -> Result { expr.transform(&|expr| { match &expr { Expr::Placeholder(Placeholder { id, data_type }) => { - if id.is_empty() || id == "$0" { - return plan_err!("Empty placeholder id"); - } - // convert id (in format $1, $2, ..) to idx (0, 1, ..) - let idx = id[1..].parse::().map_err(|e| { - DataFusionError::Internal(format!( - "Failed to parse placeholder id: {e}" - )) - })? - 1; - // value at the idx-th position in param_values should be the value for the placeholder - let value = param_values.get(idx).ok_or_else(|| { - DataFusionError::Internal(format!( - "No value found for placeholder with id {id}" - )) - })?; - // check if the data type of the value matches the data type of the placeholder - if Some(value.data_type()) != *data_type { - return internal_err!( - "Placeholder value type mismatch: expected {:?}, got {:?}", - data_type, - value.data_type() - ); - } + let value = param_values + .get_placeholders_with_values(id, data_type.as_ref())?; // Replace the placeholder with the value - Ok(Transformed::Yes(Expr::Literal(value.clone()))) + Ok(Transformed::Yes(Expr::Literal(value))) } Expr::ScalarSubquery(qry) => { let subquery = @@ -1219,7 +1238,9 @@ impl LogicalPlan { // Various implementations for printing out LogicalPlans impl LogicalPlan { /// Return a `format`able structure that produces a single line - /// per node. For example: + /// per node. + /// + /// # Example /// /// ```text /// Projection: employee.id @@ -1639,9 +1660,21 @@ impl LogicalPlan { LogicalPlan::Statement(statement) => { write!(f, "{}", statement.display()) } - LogicalPlan::Distinct(Distinct { .. }) => { - write!(f, "Distinct:") - } + LogicalPlan::Distinct(distinct) => match distinct { + Distinct::All(_) => write!(f, "Distinct:"), + Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + .. + }) => write!( + f, + "DistinctOn: on_expr=[[{}]], select_expr=[[{}]], sort_expr=[[{}]]", + expr_vec_fmt!(on_expr), + expr_vec_fmt!(select_expr), + if let Some(sort_expr) = sort_expr { expr_vec_fmt!(sort_expr) } else { "".to_string() }, + ), + }, LogicalPlan::Explain { .. } => write!(f, "Explain"), LogicalPlan::Analyze { .. } => write!(f, "Analyze"), LogicalPlan::Union(_) => write!(f, "Union"), @@ -1713,11 +1746,8 @@ pub struct Projection { impl Projection { /// Create a new Projection pub fn try_new(expr: Vec, input: Arc) -> Result { - let schema = Arc::new(DFSchema::new_with_metadata( - exprlist_to_fields(&expr, &input)?, - input.schema().metadata().clone(), - )?); - Self::try_new_with_schema(expr, input, schema) + let projection_schema = projection_schema(&input, &expr)?; + Self::try_new_with_schema(expr, input, projection_schema) } /// Create a new Projection using the specified output schema @@ -1729,11 +1759,6 @@ impl Projection { if expr.len() != schema.fields().len() { return plan_err!("Projection has mismatch between number of expressions ({}) and number of fields in schema ({})", expr.len(), schema.fields().len()); } - // Update functional dependencies of `input` according to projection - // expressions: - let id_key_groups = calc_func_dependencies_for_project(&expr, &input)?; - let schema = schema.as_ref().clone(); - let schema = Arc::new(schema.with_functional_dependencies(id_key_groups)); Ok(Self { expr, input, @@ -1757,6 +1782,30 @@ impl Projection { } } +/// Computes the schema of the result produced by applying a projection to the input logical plan. +/// +/// # Arguments +/// +/// * `input`: A reference to the input `LogicalPlan` for which the projection schema +/// will be computed. +/// * `exprs`: A slice of `Expr` expressions representing the projection operation to apply. +/// +/// # Returns +/// +/// A `Result` containing an `Arc` representing the schema of the result +/// produced by the projection operation. If the schema computation is successful, +/// the `Result` will contain the schema; otherwise, it will contain an error. +pub fn projection_schema(input: &LogicalPlan, exprs: &[Expr]) -> Result> { + let mut schema = DFSchema::new_with_metadata( + exprlist_to_fields(exprs, input)?, + input.schema().metadata().clone(), + )?; + schema = schema.with_functional_dependencies(calc_func_dependencies_for_project( + exprs, input, + )?)?; + Ok(Arc::new(schema)) +} + /// Aliased subquery #[derive(Clone, PartialEq, Eq, Hash)] // mark non_exhaustive to encourage use of try_new/new() @@ -1772,7 +1821,7 @@ pub struct SubqueryAlias { impl SubqueryAlias { pub fn try_new( - plan: LogicalPlan, + plan: Arc, alias: impl Into, ) -> Result { let alias = alias.into(); @@ -1782,10 +1831,10 @@ impl SubqueryAlias { let func_dependencies = plan.schema().functional_dependencies().clone(); let schema = DFSchemaRef::new( DFSchema::try_from_qualified_schema(&alias, &schema)? - .with_functional_dependencies(func_dependencies), + .with_functional_dependencies(func_dependencies)?, ); Ok(SubqueryAlias { - input: Arc::new(plan), + input: plan, alias, schema, }) @@ -1838,6 +1887,73 @@ impl Filter { Ok(Self { predicate, input }) } + + /// Is this filter guaranteed to return 0 or 1 row in a given instantiation? + /// + /// This function will return `true` if its predicate contains a conjunction of + /// `col(a) = `, where its schema has a unique filter that is covered + /// by this conjunction. + /// + /// For example, for the table: + /// ```sql + /// CREATE TABLE t (a INTEGER PRIMARY KEY, b INTEGER); + /// ``` + /// `Filter(a = 2).is_scalar() == true` + /// , whereas + /// `Filter(b = 2).is_scalar() == false` + /// and + /// `Filter(a = 2 OR b = 2).is_scalar() == false` + fn is_scalar(&self) -> bool { + let schema = self.input.schema(); + + let functional_dependencies = self.input.schema().functional_dependencies(); + let unique_keys = functional_dependencies.iter().filter(|dep| { + let nullable = dep.nullable + && dep + .source_indices + .iter() + .any(|&source| schema.field(source).is_nullable()); + !nullable + && dep.mode == Dependency::Single + && dep.target_indices.len() == schema.fields().len() + }); + + let exprs = split_conjunction(&self.predicate); + let eq_pred_cols: HashSet<_> = exprs + .iter() + .filter_map(|expr| { + let Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Eq, + right, + }) = expr + else { + return None; + }; + // This is a no-op filter expression + if left == right { + return None; + } + + match (left.as_ref(), right.as_ref()) { + (Expr::Column(_), Expr::Column(_)) => None, + (Expr::Column(c), _) | (_, Expr::Column(c)) => { + Some(schema.index_of_column(c).unwrap()) + } + _ => None, + } + }) + .collect(); + + // If we have a functional dependence that is a subset of our predicate, + // this filter is scalar + for key in unique_keys { + if key.source_indices.iter().all(|c| eq_pred_cols.contains(c)) { + return true; + } + } + false + } } /// Window its input based on a set of window spec and window function (e.g. SUM or RANK) @@ -1854,9 +1970,10 @@ pub struct Window { impl Window { /// Create a new window operator. pub fn try_new(window_expr: Vec, input: Arc) -> Result { - let mut window_fields: Vec = input.schema().fields().clone(); - window_fields - .extend_from_slice(&exprlist_to_fields(window_expr.iter(), input.as_ref())?); + let fields = input.schema().fields(); + let input_len = fields.len(); + let mut window_fields = fields.clone(); + window_fields.extend_from_slice(&exprlist_to_fields(window_expr.iter(), &input)?); let metadata = input.schema().metadata().clone(); // Update functional dependencies for window: @@ -1864,12 +1981,52 @@ impl Window { input.schema().functional_dependencies().clone(); window_func_dependencies.extend_target_indices(window_fields.len()); + // Since we know that ROW_NUMBER outputs will be unique (i.e. it consists + // of consecutive numbers per partition), we can represent this fact with + // functional dependencies. + let mut new_dependencies = window_expr + .iter() + .enumerate() + .filter_map(|(idx, expr)| { + if let Expr::WindowFunction(WindowFunction { + // Function is ROW_NUMBER + fun: + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::RowNumber, + ), + partition_by, + .. + }) = expr + { + // When there is no PARTITION BY, row number will be unique + // across the entire table. + if partition_by.is_empty() { + return Some(idx + input_len); + } + } + None + }) + .map(|idx| { + FunctionalDependence::new(vec![idx], vec![], false) + .with_mode(Dependency::Single) + }) + .collect::>(); + + if !new_dependencies.is_empty() { + for dependence in new_dependencies.iter_mut() { + dependence.target_indices = (0..window_fields.len()).collect(); + } + // Add the dependency introduced because of ROW_NUMBER window function to the functional dependency + let new_deps = FunctionalDependencies::new(new_dependencies); + window_func_dependencies.extend(new_deps); + } + Ok(Window { input, window_expr, schema: Arc::new( DFSchema::new_with_metadata(window_fields, metadata)? - .with_functional_dependencies(window_func_dependencies), + .with_functional_dependencies(window_func_dependencies)?, ), }) } @@ -1914,6 +2071,61 @@ impl Hash for TableScan { } } +impl TableScan { + /// Initialize TableScan with appropriate schema from the given + /// arguments. + pub fn try_new( + table_name: impl Into, + table_source: Arc, + projection: Option>, + filters: Vec, + fetch: Option, + ) -> Result { + let table_name = table_name.into(); + + if table_name.table().is_empty() { + return plan_err!("table_name cannot be empty"); + } + let schema = table_source.schema(); + let func_dependencies = FunctionalDependencies::new_from_constraints( + table_source.constraints(), + schema.fields.len(), + ); + let projected_schema = projection + .as_ref() + .map(|p| { + let projected_func_dependencies = + func_dependencies.project_functional_dependencies(p, p.len()); + let df_schema = DFSchema::new_with_metadata( + p.iter() + .map(|i| { + DFField::from_qualified( + table_name.clone(), + schema.field(*i).clone(), + ) + }) + .collect(), + schema.metadata().clone(), + )?; + df_schema.with_functional_dependencies(projected_func_dependencies) + }) + .unwrap_or_else(|| { + let df_schema = + DFSchema::try_from_qualified_schema(table_name.clone(), &schema)?; + df_schema.with_functional_dependencies(func_dependencies) + })?; + let projected_schema = Arc::new(projected_schema); + Ok(Self { + table_name, + source: table_source, + projection, + projected_schema, + filters, + fetch, + }) + } +} + /// Apply Cross Join to two logical plans #[derive(Clone, PartialEq, Eq, Hash)] pub struct CrossJoin { @@ -2047,9 +2259,93 @@ pub struct Limit { /// Removes duplicate rows from the input #[derive(Clone, PartialEq, Eq, Hash)] -pub struct Distinct { +pub enum Distinct { + /// Plain `DISTINCT` referencing all selection expressions + All(Arc), + /// The `Postgres` addition, allowing separate control over DISTINCT'd and selected columns + On(DistinctOn), +} + +/// Removes duplicate rows from the input +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct DistinctOn { + /// The `DISTINCT ON` clause expression list + pub on_expr: Vec, + /// The selected projection expression list + pub select_expr: Vec, + /// The `ORDER BY` clause, whose initial expressions must match those of the `ON` clause when + /// present. Note that those matching expressions actually wrap the `ON` expressions with + /// additional info pertaining to the sorting procedure (i.e. ASC/DESC, and NULLS FIRST/LAST). + pub sort_expr: Option>, /// The logical plan that is being DISTINCT'd pub input: Arc, + /// The schema description of the DISTINCT ON output + pub schema: DFSchemaRef, +} + +impl DistinctOn { + /// Create a new `DistinctOn` struct. + pub fn try_new( + on_expr: Vec, + select_expr: Vec, + sort_expr: Option>, + input: Arc, + ) -> Result { + if on_expr.is_empty() { + return plan_err!("No `ON` expressions provided"); + } + + let on_expr = normalize_cols(on_expr, input.as_ref())?; + + let schema = DFSchema::new_with_metadata( + exprlist_to_fields(&select_expr, &input)?, + input.schema().metadata().clone(), + )?; + + let mut distinct_on = DistinctOn { + on_expr, + select_expr, + sort_expr: None, + input, + schema: Arc::new(schema), + }; + + if let Some(sort_expr) = sort_expr { + distinct_on = distinct_on.with_sort_expr(sort_expr)?; + } + + Ok(distinct_on) + } + + /// Try to update `self` with a new sort expressions. + /// + /// Validates that the sort expressions are a super-set of the `ON` expressions. + pub fn with_sort_expr(mut self, sort_expr: Vec) -> Result { + let sort_expr = normalize_cols(sort_expr, self.input.as_ref())?; + + // Check that the left-most sort expressions are the same as the `ON` expressions. + let mut matched = true; + for (on, sort) in self.on_expr.iter().zip(sort_expr.iter()) { + match sort { + Expr::Sort(SortExpr { expr, .. }) => { + if on != &**expr { + matched = false; + break; + } + } + _ => return plan_err!("Not a sort expression: {sort}"), + } + } + + if self.on_expr.len() > sort_expr.len() || !matched { + return plan_err!( + "SELECT DISTINCT ON expressions must match initial ORDER BY expressions" + ); + } + + self.sort_expr = Some(sort_expr); + Ok(self) + } } /// Aggregates its input based on a set of grouping and aggregate @@ -2076,13 +2372,25 @@ impl Aggregate { aggr_expr: Vec, ) -> Result { let group_expr = enumerate_grouping_sets(group_expr)?; + + let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]); + let grouping_expr: Vec = grouping_set_to_exprlist(group_expr.as_slice())?; - let all_expr = grouping_expr.iter().chain(aggr_expr.iter()); - let schema = DFSchema::new_with_metadata( - exprlist_to_fields(all_expr, &input)?, - input.schema().metadata().clone(), - )?; + let mut fields = exprlist_to_fields(grouping_expr.iter(), &input)?; + + // Even columns that cannot be null will become nullable when used in a grouping set. + if is_grouping_set { + fields = fields + .into_iter() + .map(|field| field.with_nullable(true)) + .collect::>(); + } + + fields.extend(exprlist_to_fields(aggr_expr.iter(), &input)?); + + let schema = + DFSchema::new_with_metadata(fields, input.schema().metadata().clone())?; Self::try_new_with_schema(input, group_expr, aggr_expr, Arc::new(schema)) } @@ -2116,7 +2424,7 @@ impl Aggregate { calc_func_dependencies_for_aggregate(&group_expr, &input, &schema)?; let new_schema = schema.as_ref().clone(); let schema = Arc::new( - new_schema.with_functional_dependencies(aggregate_func_dependencies), + new_schema.with_functional_dependencies(aggregate_func_dependencies)?, ); Ok(Self { input, @@ -2125,6 +2433,13 @@ impl Aggregate { schema, }) } + + /// Get the length of the group by expression in the output schema + /// This is not simply group by expression length. Expression may be + /// GroupingSet, etc. In these case we need to get inner expression lengths. + pub fn group_expr_len(&self) -> Result { + grouping_set_expr_count(&self.group_expr) + } } /// Checks whether any expression in `group_expr` contains `Expr::GroupingSet`. @@ -2319,13 +2634,19 @@ pub struct Unnest { #[cfg(test)] mod tests { + use std::collections::HashMap; + use std::sync::Arc; + use super::*; + use crate::builder::LogicalTableSource; use crate::logical_plan::table_scan; - use crate::{col, exists, in_subquery, lit}; + use crate::{col, count, exists, in_subquery, lit, placeholder, GroupingSet}; + use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::tree_node::TreeNodeVisitor; - use datafusion_common::{not_impl_err, DFSchema, TableReference}; - use std::collections::HashMap; + use datafusion_common::{ + not_impl_err, Constraint, DFSchema, ScalarValue, TableReference, + }; fn employee_schema() -> Schema { Schema::new(vec![ @@ -2767,15 +3088,13 @@ digraph { let plan = table_scan(TableReference::none(), &schema, None) .unwrap() - .filter(col("id").eq(Expr::Placeholder(Placeholder::new( - "".into(), - Some(DataType::Int32), - )))) + .filter(col("id").eq(placeholder(""))) .unwrap() .build() .unwrap(); - plan.replace_params_with_values(&[42i32.into()]) + let param_values = vec![ScalarValue::Int32(Some(42))]; + plan.replace_params_with_values(¶m_values.clone().into()) .expect_err("unexpectedly succeeded to replace an invalid placeholder"); // test $0 placeholder @@ -2783,15 +3102,159 @@ digraph { let plan = table_scan(TableReference::none(), &schema, None) .unwrap() - .filter(col("id").eq(Expr::Placeholder(Placeholder::new( - "$0".into(), - Some(DataType::Int32), - )))) + .filter(col("id").eq(placeholder("$0"))) .unwrap() .build() .unwrap(); - plan.replace_params_with_values(&[42i32.into()]) + plan.replace_params_with_values(¶m_values.clone().into()) .expect_err("unexpectedly succeeded to replace an invalid placeholder"); + + // test $00 placeholder + let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]); + + let plan = table_scan(TableReference::none(), &schema, None) + .unwrap() + .filter(col("id").eq(placeholder("$00"))) + .unwrap() + .build() + .unwrap(); + + plan.replace_params_with_values(¶m_values.into()) + .expect_err("unexpectedly succeeded to replace an invalid placeholder"); + } + + #[test] + fn test_nullable_schema_after_grouping_set() { + let schema = Schema::new(vec![ + Field::new("foo", DataType::Int32, false), + Field::new("bar", DataType::Int32, false), + ]); + + let plan = table_scan(TableReference::none(), &schema, None) + .unwrap() + .aggregate( + vec![Expr::GroupingSet(GroupingSet::GroupingSets(vec![ + vec![col("foo")], + vec![col("bar")], + ]))], + vec![count(lit(true))], + ) + .unwrap() + .build() + .unwrap(); + + let output_schema = plan.schema(); + + assert!(output_schema + .field_with_name(None, "foo") + .unwrap() + .is_nullable(),); + assert!(output_schema + .field_with_name(None, "bar") + .unwrap() + .is_nullable()); + } + + #[test] + fn test_filter_is_scalar() { + // test empty placeholder + let schema = + Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); + + let source = Arc::new(LogicalTableSource::new(schema)); + let schema = Arc::new( + DFSchema::try_from_qualified_schema( + TableReference::bare("tab"), + &source.schema(), + ) + .unwrap(), + ); + let scan = Arc::new(LogicalPlan::TableScan(TableScan { + table_name: TableReference::bare("tab"), + source: source.clone(), + projection: None, + projected_schema: schema.clone(), + filters: vec![], + fetch: None, + })); + let col = schema.field(0).qualified_column(); + + let filter = Filter::try_new( + Expr::Column(col).eq(Expr::Literal(ScalarValue::Int32(Some(1)))), + scan, + ) + .unwrap(); + assert!(!filter.is_scalar()); + let unique_schema = Arc::new( + schema + .as_ref() + .clone() + .with_functional_dependencies( + FunctionalDependencies::new_from_constraints( + Some(&Constraints::new_unverified(vec![Constraint::Unique( + vec![0], + )])), + 1, + ), + ) + .unwrap(), + ); + let scan = Arc::new(LogicalPlan::TableScan(TableScan { + table_name: TableReference::bare("tab"), + source, + projection: None, + projected_schema: unique_schema.clone(), + filters: vec![], + fetch: None, + })); + let col = schema.field(0).qualified_column(); + + let filter = Filter::try_new( + Expr::Column(col).eq(Expr::Literal(ScalarValue::Int32(Some(1)))), + scan, + ) + .unwrap(); + assert!(filter.is_scalar()); + } + + #[test] + fn test_transform_explain() { + let schema = Schema::new(vec![ + Field::new("foo", DataType::Int32, false), + Field::new("bar", DataType::Int32, false), + ]); + + let plan = table_scan(TableReference::none(), &schema, None) + .unwrap() + .explain(false, false) + .unwrap() + .build() + .unwrap(); + + let external_filter = + col("foo").eq(Expr::Literal(ScalarValue::Boolean(Some(true)))); + + // after transformation, because plan is not the same anymore, + // the parent plan is built again with call to LogicalPlan::with_new_inputs -> with_new_exprs + let plan = plan + .transform(&|plan| match plan { + LogicalPlan::TableScan(table) => { + let filter = Filter::try_new( + external_filter.clone(), + Arc::new(LogicalPlan::TableScan(table)), + ) + .unwrap(); + Ok(Transformed::Yes(LogicalPlan::Filter(filter))) + } + x => Ok(Transformed::No(x)), + }) + .unwrap(); + + let expected = "Explain\ + \n Filter: foo = Boolean(true)\ + \n TableScan: ?table?"; + let actual = format!("{}", plan.display_indent()); + assert_eq!(expected.to_string(), actual) } } diff --git a/datafusion/expr/src/operator.rs b/datafusion/expr/src/operator.rs index 112e29082dba..57888a11d426 100644 --- a/datafusion/expr/src/operator.rs +++ b/datafusion/expr/src/operator.rs @@ -53,9 +53,13 @@ pub enum Operator { And, /// Logical OR, like `||` Or, - /// IS DISTINCT FROM + /// `IS DISTINCT FROM` (see [`distinct`]) + /// + /// [`distinct`]: arrow::compute::kernels::cmp::distinct IsDistinctFrom, - /// IS NOT DISTINCT FROM + /// `IS NOT DISTINCT FROM` (see [`not_distinct`]) + /// + /// [`not_distinct`]: arrow::compute::kernels::cmp::not_distinct IsNotDistinctFrom, /// Case sensitive regex match RegexMatch, @@ -363,6 +367,7 @@ impl ops::Neg for Expr { } } +/// Support `NOT ` fluent style impl Not for Expr { type Output = Self; diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr/src/signature.rs index 988fe7c91d4f..729131bd95e1 100644 --- a/datafusion/expr/src/signature.rs +++ b/datafusion/expr/src/signature.rs @@ -20,46 +20,112 @@ use arrow::datatypes::DataType; +/// Constant that is used as a placeholder for any valid timezone. +/// This is used where a function can accept a timestamp type with any +/// valid timezone, it exists to avoid the need to enumerate all possible +/// timezones. See [`TypeSignature`] for more details. +/// +/// Type coercion always ensures that functions will be executed using +/// timestamp arrays that have a valid time zone. Functions must never +/// return results with this timezone. +pub const TIMEZONE_WILDCARD: &str = "+TZ"; + ///A function's volatility, which defines the functions eligibility for certain optimizations #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] pub enum Volatility { - /// Immutable - An immutable function will always return the same output when given the same - /// input. An example of this is [super::BuiltinScalarFunction::Cos]. + /// An immutable function will always return the same output when given the same + /// input. An example of this is [super::BuiltinScalarFunction::Cos]. DataFusion + /// will attempt to inline immutable functions during planning. Immutable, - /// Stable - A stable function may return different values given the same input across different + /// A stable function may return different values given the same input across different /// queries but must return the same value for a given input within a query. An example of - /// this is [super::BuiltinScalarFunction::Now]. + /// this is [super::BuiltinScalarFunction::Now]. DataFusion + /// will attempt to inline `Stable` functions during planning, when possible. + /// For query `select col1, now() from t1`, it might take a while to execute but + /// `now()` column will be the same for each output row, which is evaluated + /// during planning. Stable, - /// Volatile - A volatile function may change the return value from evaluation to evaluation. + /// A volatile function may change the return value from evaluation to evaluation. /// Multiple invocations of a volatile function may return different results when used in the - /// same query. An example of this is [super::BuiltinScalarFunction::Random]. + /// same query. An example of this is [super::BuiltinScalarFunction::Random]. DataFusion + /// can not evaluate such functions during planning. + /// In the query `select col1, random() from t1`, `random()` function will be evaluated + /// for each output row, resulting in a unique random value for each row. Volatile, } -/// A function's type signature, which defines the function's supported argument types. +/// A function's type signature defines the types of arguments the function supports. +/// +/// Functions typically support only a few different types of arguments compared to the +/// different datatypes in Arrow. To make functions easy to use, when possible DataFusion +/// automatically coerces (add casts to) function arguments so they match the type signature. +/// +/// For example, a function like `cos` may only be implemented for `Float64` arguments. To support a query +/// that calles `cos` with a different argument type, such as `cos(int_column)`, type coercion automatically +/// adds a cast such as `cos(CAST int_column AS DOUBLE)` during planning. +/// +/// # Data Types +/// Types to match are represented using Arrow's [`DataType`]. [`DataType::Timestamp`] has an optional variable +/// timezone specification. To specify a function can handle a timestamp with *ANY* timezone, use +/// the [`TIMEZONE_WILDCARD`]. For example: +/// +/// ``` +/// # use arrow::datatypes::{DataType, TimeUnit}; +/// # use datafusion_expr::{TIMEZONE_WILDCARD, TypeSignature}; +/// let type_signature = TypeSignature::Exact(vec![ +/// // A nanosecond precision timestamp with ANY timezone +/// // matches Timestamp(Nanosecond, Some("+0:00")) +/// // matches Timestamp(Nanosecond, Some("+5:00")) +/// // does not match Timestamp(Nanosecond, None) +/// DataType::Timestamp(TimeUnit::Nanosecond, Some(TIMEZONE_WILDCARD.into())), +/// ]); +/// ``` #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum TypeSignature { - /// arbitrary number of arguments of an common type out of a list of valid types - // A function such as `concat` is `Variadic(vec![DataType::Utf8, DataType::LargeUtf8])` + /// One or more arguments of an common type out of a list of valid types. + /// + /// # Examples + /// A function such as `concat` is `Variadic(vec![DataType::Utf8, DataType::LargeUtf8])` Variadic(Vec), - /// arbitrary number of arguments of an arbitrary but equal type - // A function such as `array` is `VariadicEqual` - // The first argument decides the type used for coercion + /// One or more arguments of an arbitrary but equal type. + /// DataFusion attempts to coerce all argument types to match the first argument's type + /// + /// # Examples + /// Given types in signature should be coericible to the same final type. + /// A function such as `make_array` is `VariadicEqual`. + /// + /// `make_array(i32, i64) -> make_array(i64, i64)` VariadicEqual, - /// arbitrary number of arguments with arbitrary types + /// One or more arguments with arbitrary types VariadicAny, - /// fixed number of arguments of an arbitrary but equal type out of a list of valid types + /// Fixed number of arguments of an arbitrary but equal type out of a list of valid types. /// /// # Examples /// 1. A function of one argument of f64 is `Uniform(1, vec![DataType::Float64])` /// 2. A function of one argument of f64 or f32 is `Uniform(1, vec![DataType::Float32, DataType::Float64])` Uniform(usize, Vec), - /// exact number of arguments of an exact type + /// Exact number of arguments of an exact type Exact(Vec), - /// fixed number of arguments of arbitrary types + /// Fixed number of arguments of arbitrary types + /// If a function takes 0 argument, its `TypeSignature` should be `Any(0)` Any(usize), - /// One of a list of signatures + /// Matches exactly one of a list of [`TypeSignature`]s. Coercion is attempted to match + /// the signatures in order, and stops after the first success, if any. + /// + /// # Examples + /// Function `make_array` takes 0 or more arguments with arbitrary types, its `TypeSignature` + /// is `OneOf(vec![Any(0), VariadicAny])`. OneOf(Vec), + /// Specialized Signature for ArrayAppend and similar functions + /// The first argument should be List/LargeList, and the second argument should be non-list or list. + /// The second argument's list dimension should be one dimension less than the first argument's list dimension. + /// List dimension of the List/LargeList is equivalent to the number of List. + /// List dimension of the non-list is 0. + ArrayAndElement, + /// Specialized Signature for ArrayPrepend and similar functions + /// The first argument should be non-list or list, and the second argument should be List/LargeList. + /// The first argument's list dimension should be one dimension less than the second argument's list dimension. + ElementAndArray, } impl TypeSignature { @@ -83,11 +149,19 @@ impl TypeSignature { .collect::>() .join(", ")] } - TypeSignature::VariadicEqual => vec!["T, .., T".to_string()], + TypeSignature::VariadicEqual => { + vec!["CoercibleT, .., CoercibleT".to_string()] + } TypeSignature::VariadicAny => vec!["Any, .., Any".to_string()], TypeSignature::OneOf(sigs) => { sigs.iter().flat_map(|s| s.to_string_repr()).collect() } + TypeSignature::ArrayAndElement => { + vec!["ArrayAndElement(List, T)".to_string()] + } + TypeSignature::ElementAndArray => { + vec!["ElementAndArray(T, List)".to_string()] + } } } @@ -102,48 +176,62 @@ impl TypeSignature { .collect::>() .join(delimiter) } + + /// Check whether 0 input argument is valid for given `TypeSignature` + pub fn supports_zero_argument(&self) -> bool { + match &self { + TypeSignature::Exact(vec) => vec.is_empty(), + TypeSignature::Uniform(0, _) | TypeSignature::Any(0) => true, + TypeSignature::OneOf(types) => types + .iter() + .any(|type_sig| type_sig.supports_zero_argument()), + _ => false, + } + } } -/// The signature of a function defines the supported argument types -/// and its volatility. +/// Defines the supported argument types ([`TypeSignature`]) and [`Volatility`] for a function. +/// +/// DataFusion will automatically coerce (cast) argument types to one of the supported +/// function signatures, if possible. #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Signature { - /// type_signature - The types that the function accepts. See [TypeSignature] for more information. + /// The data types that the function accepts. See [TypeSignature] for more information. pub type_signature: TypeSignature, - /// volatility - The volatility of the function. See [Volatility] for more information. + /// The volatility of the function. See [Volatility] for more information. pub volatility: Volatility, } impl Signature { - /// new - Creates a new Signature from any type signature and the volatility. + /// Creates a new Signature from a given type signature and volatility. pub fn new(type_signature: TypeSignature, volatility: Volatility) -> Self { Signature { type_signature, volatility, } } - /// variadic - Creates a variadic signature that represents an arbitrary number of arguments all from a type in common_types. + /// An arbitrary number of arguments with the same type, from those listed in `common_types`. pub fn variadic(common_types: Vec, volatility: Volatility) -> Self { Self { type_signature: TypeSignature::Variadic(common_types), volatility, } } - /// variadic_equal - Creates a variadic signature that represents an arbitrary number of arguments of the same type. + /// An arbitrary number of arguments of the same type. pub fn variadic_equal(volatility: Volatility) -> Self { Self { type_signature: TypeSignature::VariadicEqual, volatility, } } - /// variadic_any - Creates a variadic signature that represents an arbitrary number of arguments of any type. + /// An arbitrary number of arguments of any type. pub fn variadic_any(volatility: Volatility) -> Self { Self { type_signature: TypeSignature::VariadicAny, volatility, } } - /// uniform - Creates a function with a fixed number of arguments of the same type, which must be from valid_types. + /// A fixed number of arguments of the same type, from those listed in `valid_types`. pub fn uniform( arg_count: usize, valid_types: Vec, @@ -154,21 +242,21 @@ impl Signature { volatility, } } - /// exact - Creates a signature which must match the types in exact_types in order. + /// Exactly matches the types in `exact_types`, in order. pub fn exact(exact_types: Vec, volatility: Volatility) -> Self { Signature { type_signature: TypeSignature::Exact(exact_types), volatility, } } - /// any - Creates a signature which can a be made of any type but of a specified number + /// A specified number of arguments of any type pub fn any(arg_count: usize, volatility: Volatility) -> Self { Signature { type_signature: TypeSignature::Any(arg_count), volatility, } } - /// one_of Creates a signature which can match any of the [TypeSignature]s which are passed in. + /// Any one of a list of [TypeSignature]s. pub fn one_of(type_signatures: Vec, volatility: Volatility) -> Self { Signature { type_signature: TypeSignature::OneOf(type_signatures), @@ -176,3 +264,59 @@ impl Signature { } } } + +/// Monotonicity of the `ScalarFunctionExpr` with respect to its arguments. +/// Each element of this vector corresponds to an argument and indicates whether +/// the function's behavior is monotonic, or non-monotonic/unknown for that argument, namely: +/// - `None` signifies unknown monotonicity or non-monotonicity. +/// - `Some(true)` indicates that the function is monotonically increasing w.r.t. the argument in question. +/// - Some(false) indicates that the function is monotonically decreasing w.r.t. the argument in question. +pub type FuncMonotonicity = Vec>; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn supports_zero_argument_tests() { + // Testing `TypeSignature`s which supports 0 arg + let positive_cases = vec![ + TypeSignature::Exact(vec![]), + TypeSignature::Uniform(0, vec![DataType::Float64]), + TypeSignature::Any(0), + TypeSignature::OneOf(vec![ + TypeSignature::Exact(vec![DataType::Int8]), + TypeSignature::Any(0), + TypeSignature::Uniform(1, vec![DataType::Int8]), + ]), + ]; + + for case in positive_cases { + assert!( + case.supports_zero_argument(), + "Expected {:?} to support zero arguments", + case + ); + } + + // Testing `TypeSignature`s which doesn't support 0 arg + let negative_cases = vec![ + TypeSignature::Exact(vec![DataType::Utf8]), + TypeSignature::Uniform(1, vec![DataType::Float64]), + TypeSignature::Any(1), + TypeSignature::VariadicAny, + TypeSignature::OneOf(vec![ + TypeSignature::Exact(vec![DataType::Int8]), + TypeSignature::Uniform(1, vec![DataType::Int8]), + ]), + ]; + + for case in negative_cases { + assert!( + !case.supports_zero_argument(), + "Expected {:?} not to support zero arguments", + case + ); + } + } +} diff --git a/datafusion/expr/src/table_source.rs b/datafusion/expr/src/table_source.rs index b83ce778133b..565f48c1c5a9 100644 --- a/datafusion/expr/src/table_source.rs +++ b/datafusion/expr/src/table_source.rs @@ -30,14 +30,14 @@ use std::any::Any; pub enum TableProviderFilterPushDown { /// The expression cannot be used by the provider. Unsupported, - /// The expression can be used to help minimise the data retrieved, - /// but the provider cannot guarantee that all returned tuples - /// satisfy the filter. The Filter plan node containing this expression - /// will be preserved. + /// The expression can be used to reduce the data retrieved, + /// but the provider cannot guarantee it will omit all tuples that + /// may be filtered. In this case, DataFusion will apply an additional + /// `Filter` operation after the scan to ensure all rows are filtered correctly. Inexact, - /// The provider guarantees that all returned data satisfies this - /// filter expression. The Filter plan node containing this expression - /// will be removed. + /// The provider **guarantees** that it will omit **all** tuples that are + /// filtered by the filter expression. This is the fastest option, if available + /// as DataFusion will not apply additional filtering. Exact, } @@ -103,4 +103,9 @@ pub trait TableSource: Sync + Send { fn get_logical_plan(&self) -> Option<&LogicalPlan> { None } + + /// Get the default value for a column, if available. + fn get_column_default(&self, _column: &str) -> Option<&Expr> { + None + } } diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index f74cc164a7a5..56388be58b8a 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -18,21 +18,20 @@ //! Tree node implementation for logical expr use crate::expr::{ - AggregateFunction, AggregateUDF, Alias, Between, BinaryExpr, Case, Cast, - GetIndexedField, GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, - ScalarUDF, Sort, TryCast, WindowFunction, + AggregateFunction, AggregateFunctionDefinition, Alias, Between, BinaryExpr, Case, + Cast, GetIndexedField, GroupingSet, InList, InSubquery, Like, Placeholder, + ScalarFunction, ScalarFunctionDefinition, Sort, TryCast, WindowFunction, }; -use crate::Expr; -use datafusion_common::tree_node::VisitRecursion; -use datafusion_common::{tree_node::TreeNode, Result}; +use crate::{Expr, GetFieldAccess}; +use std::borrow::Cow; + +use datafusion_common::tree_node::TreeNode; +use datafusion_common::{internal_err, DataFusionError, Result}; impl TreeNode for Expr { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - let children = match self { - Expr::Alias(Alias{expr,..}) + fn children_nodes(&self) -> Vec> { + match self { + Expr::Alias(Alias { expr, .. }) | Expr::Not(expr) | Expr::IsNotNull(expr) | Expr::IsTrue(expr) @@ -46,17 +45,26 @@ impl TreeNode for Expr { | Expr::Cast(Cast { expr, .. }) | Expr::TryCast(TryCast { expr, .. }) | Expr::Sort(Sort { expr, .. }) - | Expr::InSubquery(InSubquery{ expr, .. }) => vec![expr.as_ref().clone()], - Expr::GetIndexedField(GetIndexedField { expr, .. }) => { - vec![expr.as_ref().clone()] + | Expr::InSubquery(InSubquery { expr, .. }) => vec![Cow::Borrowed(expr)], + Expr::GetIndexedField(GetIndexedField { expr, field }) => { + let expr = Cow::Borrowed(expr.as_ref()); + match field { + GetFieldAccess::ListIndex { key } => { + vec![Cow::Borrowed(key.as_ref()), expr] + } + GetFieldAccess::ListRange { start, stop } => { + vec![Cow::Borrowed(start), Cow::Borrowed(stop), expr] + } + GetFieldAccess::NamedStructField { name: _name } => { + vec![expr] + } + } } Expr::GroupingSet(GroupingSet::Rollup(exprs)) - | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.clone(), - Expr::ScalarFunction (ScalarFunction{ args, .. } )| Expr::ScalarUDF(ScalarUDF { args, .. }) => { - args.clone() - } + | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.iter().map(Cow::Borrowed).collect(), + Expr::ScalarFunction(ScalarFunction { args, .. }) => args.iter().map(Cow::Borrowed).collect(), Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => { - lists_of_exprs.clone().into_iter().flatten().collect() + lists_of_exprs.iter().flatten().map(Cow::Borrowed).collect() } Expr::Column(_) // Treat OuterReferenceColumn as a leaf expression @@ -65,46 +73,49 @@ impl TreeNode for Expr { | Expr::Literal(_) | Expr::Exists { .. } | Expr::ScalarSubquery(_) - | Expr::Wildcard - | Expr::QualifiedWildcard { .. } - | Expr::Placeholder (_) => vec![], + | Expr::Wildcard { .. } + | Expr::Placeholder(_) => vec![], Expr::BinaryExpr(BinaryExpr { left, right, .. }) => { - vec![left.as_ref().clone(), right.as_ref().clone()] + vec![Cow::Borrowed(left), Cow::Borrowed(right)] } Expr::Like(Like { expr, pattern, .. }) | Expr::SimilarTo(Like { expr, pattern, .. }) => { - vec![expr.as_ref().clone(), pattern.as_ref().clone()] + vec![Cow::Borrowed(expr), Cow::Borrowed(pattern)] } Expr::Between(Between { expr, low, high, .. }) => vec![ - expr.as_ref().clone(), - low.as_ref().clone(), - high.as_ref().clone(), + Cow::Borrowed(expr), + Cow::Borrowed(low), + Cow::Borrowed(high), ], Expr::Case(case) => { let mut expr_vec = vec![]; if let Some(expr) = case.expr.as_ref() { - expr_vec.push(expr.as_ref().clone()); + expr_vec.push(Cow::Borrowed(expr.as_ref())); }; for (when, then) in case.when_then_expr.iter() { - expr_vec.push(when.as_ref().clone()); - expr_vec.push(then.as_ref().clone()); + expr_vec.push(Cow::Borrowed(when)); + expr_vec.push(Cow::Borrowed(then)); } if let Some(else_expr) = case.else_expr.as_ref() { - expr_vec.push(else_expr.as_ref().clone()); + expr_vec.push(Cow::Borrowed(else_expr)); } expr_vec } - Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }) - | Expr::AggregateUDF(AggregateUDF { args, filter, order_by, .. }) => { - let mut expr_vec = args.clone(); + Expr::AggregateFunction(AggregateFunction { + args, + filter, + order_by, + .. + }) => { + let mut expr_vec: Vec<_> = args.iter().map(Cow::Borrowed).collect(); if let Some(f) = filter { - expr_vec.push(f.as_ref().clone()); + expr_vec.push(Cow::Borrowed(f)); } if let Some(o) = order_by { - expr_vec.extend(o.clone()); + expr_vec.extend(o.iter().map(Cow::Borrowed).collect::>()); } expr_vec @@ -115,28 +126,17 @@ impl TreeNode for Expr { order_by, .. }) => { - let mut expr_vec = args.clone(); - expr_vec.extend(partition_by.clone()); - expr_vec.extend(order_by.clone()); + let mut expr_vec: Vec<_> = args.iter().map(Cow::Borrowed).collect(); + expr_vec.extend(partition_by.iter().map(Cow::Borrowed).collect::>()); + expr_vec.extend(order_by.iter().map(Cow::Borrowed).collect::>()); expr_vec } Expr::InList(InList { expr, list, .. }) => { - let mut expr_vec = vec![]; - expr_vec.push(expr.as_ref().clone()); - expr_vec.extend(list.clone()); + let mut expr_vec = vec![Cow::Borrowed(expr.as_ref())]; + expr_vec.extend(list.iter().map(Cow::Borrowed).collect::>()); expr_vec } - }; - - for child in children.iter() { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } } - - Ok(VisitRecursion::Continue) } fn map_children(self, transform: F) -> Result @@ -146,9 +146,11 @@ impl TreeNode for Expr { let mut transform = transform; Ok(match self { - Expr::Alias(Alias { expr, name, .. }) => { - Expr::Alias(Alias::new(transform(*expr)?, name)) - } + Expr::Alias(Alias { + expr, + relation, + name, + }) => Expr::Alias(Alias::new(transform(*expr)?, relation, name)), Expr::Column(_) => self, Expr::OuterReferenceColumn(_, _) => self, Expr::Exists { .. } => self, @@ -263,12 +265,19 @@ impl TreeNode for Expr { asc, nulls_first, )), - Expr::ScalarFunction(ScalarFunction { args, fun }) => Expr::ScalarFunction( - ScalarFunction::new(fun, transform_vec(args, &mut transform)?), - ), - Expr::ScalarUDF(ScalarUDF { args, fun }) => { - Expr::ScalarUDF(ScalarUDF::new(fun, transform_vec(args, &mut transform)?)) - } + Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { + ScalarFunctionDefinition::BuiltIn(fun) => Expr::ScalarFunction( + ScalarFunction::new(fun, transform_vec(args, &mut transform)?), + ), + ScalarFunctionDefinition::UDF(fun) => Expr::ScalarFunction( + ScalarFunction::new_udf(fun, transform_vec(args, &mut transform)?), + ), + ScalarFunctionDefinition::Name(_) => { + return internal_err!( + "Function `Expr` with name should be resolved." + ); + } + }, Expr::WindowFunction(WindowFunction { args, fun, @@ -284,17 +293,40 @@ impl TreeNode for Expr { )), Expr::AggregateFunction(AggregateFunction { args, - fun, + func_def, distinct, filter, order_by, - }) => Expr::AggregateFunction(AggregateFunction::new( - fun, - transform_vec(args, &mut transform)?, - distinct, - transform_option_box(filter, &mut transform)?, - transform_option_vec(order_by, &mut transform)?, - )), + }) => match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + Expr::AggregateFunction(AggregateFunction::new( + fun, + transform_vec(args, &mut transform)?, + distinct, + transform_option_box(filter, &mut transform)?, + transform_option_vec(order_by, &mut transform)?, + )) + } + AggregateFunctionDefinition::UDF(fun) => { + let order_by = if let Some(order_by) = order_by { + Some(transform_vec(order_by, &mut transform)?) + } else { + None + }; + Expr::AggregateFunction(AggregateFunction::new_udf( + fun, + transform_vec(args, &mut transform)?, + false, + transform_option_box(filter, &mut transform)?, + transform_option_vec(order_by, &mut transform)?, + )) + } + AggregateFunctionDefinition::Name(_) => { + return internal_err!( + "Function `Expr` with name should be resolved." + ); + } + }, Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => Expr::GroupingSet(GroupingSet::Rollup( transform_vec(exprs, &mut transform)?, @@ -311,24 +343,7 @@ impl TreeNode for Expr { )) } }, - Expr::AggregateUDF(AggregateUDF { - args, - fun, - filter, - order_by, - }) => { - let order_by = if let Some(order_by) = order_by { - Some(transform_vec(order_by, &mut transform)?) - } else { - None - }; - Expr::AggregateUDF(AggregateUDF::new( - fun, - transform_vec(args, &mut transform)?, - transform_option_box(filter, &mut transform)?, - transform_option_vec(order_by, &mut transform)?, - )) - } + Expr::InList(InList { expr, list, @@ -338,10 +353,7 @@ impl TreeNode for Expr { transform_vec(list, &mut transform)?, negated, )), - Expr::Wildcard => Expr::Wildcard, - Expr::QualifiedWildcard { qualifier } => { - Expr::QualifiedWildcard { qualifier } - } + Expr::Wildcard { qualifier } => Expr::Wildcard { qualifier }, Expr::GetIndexedField(GetIndexedField { expr, field }) => { Expr::GetIndexedField(GetIndexedField::new( transform_boxed(expr, &mut transform)?, diff --git a/datafusion/expr/src/tree_node/plan.rs b/datafusion/expr/src/tree_node/plan.rs index c7621bc17833..208a8b57d7b0 100644 --- a/datafusion/expr/src/tree_node/plan.rs +++ b/datafusion/expr/src/tree_node/plan.rs @@ -20,8 +20,13 @@ use crate::LogicalPlan; use datafusion_common::tree_node::{TreeNodeVisitor, VisitRecursion}; use datafusion_common::{tree_node::TreeNode, Result}; +use std::borrow::Cow; impl TreeNode for LogicalPlan { + fn children_nodes(&self) -> Vec> { + self.inputs().into_iter().map(Cow::Borrowed).collect() + } + fn apply(&self, op: &mut F) -> Result where F: FnMut(&Self) -> Result, @@ -91,21 +96,6 @@ impl TreeNode for LogicalPlan { visitor.post_visit(self) } - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in self.inputs() { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) - } - fn map_children(self, transform: F) -> Result where F: FnMut(Self) -> Result, @@ -123,7 +113,7 @@ impl TreeNode for LogicalPlan { .zip(new_children.iter()) .any(|(c1, c2)| c1 != &c2) { - self.with_new_inputs(new_children.as_slice()) + self.with_new_exprs(self.expressions(), new_children.as_slice()) } else { Ok(self) } diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 261c406d5d5e..7128b575978a 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -298,6 +298,23 @@ pub fn coerce_types( | AggregateFunction::FirstValue | AggregateFunction::LastValue => Ok(input_types.to_vec()), AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]), + AggregateFunction::StringAgg => { + if !is_string_agg_supported_arg_type(&input_types[0]) { + return plan_err!( + "The function {:?} does not support inputs of type {:?}", + agg_fun, + input_types[0] + ); + } + if !is_string_agg_supported_arg_type(&input_types[1]) { + return plan_err!( + "The function {:?} does not support inputs of type {:?}", + agg_fun, + input_types[1] + ); + } + Ok(vec![LargeUtf8, input_types[1].clone()]) + } } } @@ -565,6 +582,15 @@ pub fn is_approx_percentile_cont_supported_arg_type(arg_type: &DataType) -> bool ) } +/// Return `true` if `arg_type` is of a [`DataType`] that the +/// [`AggregateFunction::StringAgg`] aggregation can operate on. +pub fn is_string_agg_supported_arg_type(arg_type: &DataType) -> bool { + matches!( + arg_type, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Null + ) +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index 64f814cd958b..6bacc1870079 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -17,20 +17,26 @@ //! Coercion rules for matching argument types for binary operators +use std::sync::Arc; + +use crate::Operator; + use arrow::array::{new_empty_array, Array}; use arrow::compute::can_cast_types; use arrow::datatypes::{ - DataType, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, + DataType, Field, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, }; -use datafusion_common::Result; -use datafusion_common::{plan_err, DataFusionError}; - -use crate::type_coercion::is_numeric; -use crate::Operator; +use datafusion_common::{ + exec_datafusion_err, plan_datafusion_err, plan_err, DataFusionError, Result, +}; -/// The type signature of an instantiation of binary expression +/// The type signature of an instantiation of binary operator expression such as +/// `lhs + rhs` +/// +/// Note this is different than [`crate::signature::Signature`] which +/// describes the type signature of a function. struct Signature { /// The type to coerce the left argument to lhs: DataType, @@ -62,83 +68,75 @@ impl Signature { /// Returns a [`Signature`] for applying `op` to arguments of type `lhs` and `rhs` fn signature(lhs: &DataType, op: &Operator, rhs: &DataType) -> Result { + use arrow::datatypes::DataType::*; + use Operator::*; match op { - Operator::Eq | - Operator::NotEq | - Operator::Lt | - Operator::LtEq | - Operator::Gt | - Operator::GtEq | - Operator::IsDistinctFrom | - Operator::IsNotDistinctFrom => { + Eq | + NotEq | + Lt | + LtEq | + Gt | + GtEq | + IsDistinctFrom | + IsNotDistinctFrom => { comparison_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| { - DataFusionError::Plan(format!( + plan_datafusion_err!( "Cannot infer common argument type for comparison operation {lhs} {op} {rhs}" - )) + ) }) } - Operator::And | Operator::Or => match (lhs, rhs) { - // logical binary boolean operators can only be evaluated in bools or nulls - (DataType::Boolean, DataType::Boolean) - | (DataType::Null, DataType::Null) - | (DataType::Boolean, DataType::Null) - | (DataType::Null, DataType::Boolean) => Ok(Signature::uniform(DataType::Boolean)), - _ => plan_err!( + And | Or => if matches!((lhs, rhs), (Boolean | Null, Boolean | Null)) { + // Logical binary boolean operators can only be evaluated for + // boolean or null arguments. + Ok(Signature::uniform(DataType::Boolean)) + } else { + plan_err!( "Cannot infer common argument type for logical boolean operation {lhs} {op} {rhs}" - ), - }, - Operator::RegexMatch | - Operator::RegexIMatch | - Operator::RegexNotMatch | - Operator::RegexNotIMatch => { + ) + } + RegexMatch | RegexIMatch | RegexNotMatch | RegexNotIMatch => { regex_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| { - DataFusionError::Plan(format!( + plan_datafusion_err!( "Cannot infer common argument type for regex operation {lhs} {op} {rhs}" - )) + ) }) } - Operator::BitwiseAnd - | Operator::BitwiseOr - | Operator::BitwiseXor - | Operator::BitwiseShiftRight - | Operator::BitwiseShiftLeft => { + BitwiseAnd | BitwiseOr | BitwiseXor | BitwiseShiftRight | BitwiseShiftLeft => { bitwise_coercion(lhs, rhs).map(Signature::uniform).ok_or_else(|| { - DataFusionError::Plan(format!( + plan_datafusion_err!( "Cannot infer common type for bitwise operation {lhs} {op} {rhs}" - )) + ) }) } - Operator::StringConcat => { + StringConcat => { string_concat_coercion(lhs, rhs).map(Signature::uniform).ok_or_else(|| { - DataFusionError::Plan(format!( + plan_datafusion_err!( "Cannot infer common string type for string concat operation {lhs} {op} {rhs}" - )) + ) }) } - Operator::AtArrow - | Operator::ArrowAt => { - array_coercion(lhs, rhs).map(Signature::uniform).ok_or_else(|| { - DataFusionError::Plan(format!( + AtArrow | ArrowAt => { + // ArrowAt and AtArrow check for whether one array is contained in another. + // The result type is boolean. Signature::comparison defines this signature. + // Operation has nothing to do with comparison + array_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| { + plan_datafusion_err!( "Cannot infer common array type for arrow operation {lhs} {op} {rhs}" - )) + ) }) } - Operator::Plus | - Operator::Minus | - Operator::Multiply | - Operator::Divide| - Operator::Modulo => { + Plus | Minus | Multiply | Divide | Modulo => { let get_result = |lhs, rhs| { use arrow::compute::kernels::numeric::*; let l = new_empty_array(lhs); let r = new_empty_array(rhs); let result = match op { - Operator::Plus => add_wrapping(&l, &r), - Operator::Minus => sub_wrapping(&l, &r), - Operator::Multiply => mul_wrapping(&l, &r), - Operator::Divide => div(&l, &r), - Operator::Modulo => rem(&l, &r), + Plus => add_wrapping(&l, &r), + Minus => sub_wrapping(&l, &r), + Multiply => mul_wrapping(&l, &r), + Divide => div(&l, &r), + Modulo => rem(&l, &r), _ => unreachable!(), }; result.map(|x| x.data_type().clone()) @@ -155,9 +153,9 @@ fn signature(lhs: &DataType, op: &Operator, rhs: &DataType) -> Result // Temporal arithmetic by first coercing to a common time representation // e.g. Date32 - Timestamp let ret = get_result(&coerced, &coerced).map_err(|e| { - DataFusionError::Plan(format!( + plan_datafusion_err!( "Cannot get result type for temporal operation {coerced} {op} {coerced}: {e}" - )) + ) })?; Ok(Signature{ lhs: coerced.clone(), @@ -167,9 +165,9 @@ fn signature(lhs: &DataType, op: &Operator, rhs: &DataType) -> Result } else if let Some((lhs, rhs)) = math_decimal_coercion(lhs, rhs) { // Decimal arithmetic, e.g. Decimal(10, 2) + Decimal(10, 0) let ret = get_result(&lhs, &rhs).map_err(|e| { - DataFusionError::Plan(format!( + plan_datafusion_err!( "Cannot get result type for decimal operation {lhs} {op} {rhs}: {e}" - )) + ) })?; Ok(Signature{ lhs, @@ -225,7 +223,7 @@ fn math_decimal_coercion( (Null, dec_type @ Decimal128(_, _)) | (dec_type @ Decimal128(_, _), Null) => { Some((dec_type.clone(), dec_type.clone())) } - (Decimal128(_, _), Decimal128(_, _)) => { + (Decimal128(_, _), Decimal128(_, _)) | (Decimal256(_, _), Decimal256(_, _)) => { Some((lhs_type.clone(), rhs_type.clone())) } // Unlike with comparison we don't coerce to a decimal in the case of floating point @@ -236,9 +234,6 @@ fn math_decimal_coercion( (Int8 | Int16 | Int32 | Int64, Decimal128(_, _)) => { Some((coerce_numeric_type_to_decimal(lhs_type)?, rhs_type.clone())) } - (Decimal256(_, _), Decimal256(_, _)) => { - Some((lhs_type.clone(), rhs_type.clone())) - } (Decimal256(_, _), Int8 | Int16 | Int32 | Int64) => Some(( lhs_type.clone(), coerce_numeric_type_to_decimal256(rhs_type)?, @@ -310,10 +305,10 @@ pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option { use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { - (Utf8, _) if is_numeric(rhs_type) => Some(Utf8), - (LargeUtf8, _) if is_numeric(rhs_type) => Some(LargeUtf8), - (_, Utf8) if is_numeric(lhs_type) => Some(Utf8), - (_, LargeUtf8) if is_numeric(lhs_type) => Some(LargeUtf8), + (Utf8, _) if rhs_type.is_numeric() => Some(Utf8), + (LargeUtf8, _) if rhs_type.is_numeric() => Some(LargeUtf8), + (_, Utf8) if lhs_type.is_numeric() => Some(Utf8), + (_, LargeUtf8) if lhs_type.is_numeric() => Some(LargeUtf8), _ => None, } } @@ -336,26 +331,27 @@ fn string_temporal_coercion( rhs_type: &DataType, ) -> Option { use arrow::datatypes::DataType::*; - match (lhs_type, rhs_type) { - (Utf8, Date32) | (Date32, Utf8) => Some(Date32), - (Utf8, Date64) | (Date64, Utf8) => Some(Date64), - (Utf8, Time32(unit)) | (Time32(unit), Utf8) => { - match is_time_with_valid_unit(Time32(unit.clone())) { - false => None, - true => Some(Time32(unit.clone())), - } - } - (Utf8, Time64(unit)) | (Time64(unit), Utf8) => { - match is_time_with_valid_unit(Time64(unit.clone())) { - false => None, - true => Some(Time64(unit.clone())), - } - } - (Timestamp(_, tz), Utf8) | (Utf8, Timestamp(_, tz)) => { - Some(Timestamp(TimeUnit::Nanosecond, tz.clone())) + + fn match_rule(l: &DataType, r: &DataType) -> Option { + match (l, r) { + // Coerce Utf8/LargeUtf8 to Date32/Date64/Time32/Time64/Timestamp + (Utf8, temporal) | (LargeUtf8, temporal) => match temporal { + Date32 | Date64 => Some(temporal.clone()), + Time32(_) | Time64(_) => { + if is_time_with_valid_unit(temporal.to_owned()) { + Some(temporal.to_owned()) + } else { + None + } + } + Timestamp(_, tz) => Some(Timestamp(TimeUnit::Nanosecond, tz.clone())), + _ => None, + }, + _ => None, } - _ => None, } + + match_rule(lhs_type, rhs_type).or_else(|| match_rule(rhs_type, lhs_type)) } /// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation @@ -365,7 +361,7 @@ fn comparison_binary_numeric_coercion( rhs_type: &DataType, ) -> Option { use arrow::datatypes::DataType::*; - if !is_numeric(lhs_type) || !is_numeric(rhs_type) { + if !lhs_type.is_numeric() || !rhs_type.is_numeric() { return None; }; @@ -470,6 +466,54 @@ fn get_wider_decimal_type( } } +/// Returns the wider type among arguments `lhs` and `rhs`. +/// The wider type is the type that can safely represent values from both types +/// without information loss. Returns an Error if types are incompatible. +pub fn get_wider_type(lhs: &DataType, rhs: &DataType) -> Result { + use arrow::datatypes::DataType::*; + Ok(match (lhs, rhs) { + (lhs, rhs) if lhs == rhs => lhs.clone(), + // Right UInt is larger than left UInt. + (UInt8, UInt16 | UInt32 | UInt64) | (UInt16, UInt32 | UInt64) | (UInt32, UInt64) | + // Right Int is larger than left Int. + (Int8, Int16 | Int32 | Int64) | (Int16, Int32 | Int64) | (Int32, Int64) | + // Right Float is larger than left Float. + (Float16, Float32 | Float64) | (Float32, Float64) | + // Right String is larger than left String. + (Utf8, LargeUtf8) | + // Any right type is wider than a left hand side Null. + (Null, _) => rhs.clone(), + // Left UInt is larger than right UInt. + (UInt16 | UInt32 | UInt64, UInt8) | (UInt32 | UInt64, UInt16) | (UInt64, UInt32) | + // Left Int is larger than right Int. + (Int16 | Int32 | Int64, Int8) | (Int32 | Int64, Int16) | (Int64, Int32) | + // Left Float is larger than right Float. + (Float32 | Float64, Float16) | (Float64, Float32) | + // Left String is larget than right String. + (LargeUtf8, Utf8) | + // Any left type is wider than a right hand side Null. + (_, Null) => lhs.clone(), + (List(lhs_field), List(rhs_field)) => { + let field_type = + get_wider_type(lhs_field.data_type(), rhs_field.data_type())?; + if lhs_field.name() != rhs_field.name() { + return Err(exec_datafusion_err!( + "There is no wider type that can represent both {lhs} and {rhs}." + )); + } + assert_eq!(lhs_field.name(), rhs_field.name()); + let field_name = lhs_field.name(); + let nullable = lhs_field.is_nullable() | rhs_field.is_nullable(); + List(Arc::new(Field::new(field_name, field_type, nullable))) + } + (_, _) => { + return Err(exec_datafusion_err!( + "There is no wider type that can represent both {lhs} and {rhs}." + )); + } + }) +} + /// Convert the numeric data type to the decimal data type. /// Now, we just support the signed integer type and floating-point type. fn coerce_numeric_type_to_decimal(numeric_type: &DataType) -> Option { @@ -563,14 +607,18 @@ fn create_decimal256_type(precision: u8, scale: i8) -> DataType { fn both_numeric_or_null_and_numeric(lhs_type: &DataType, rhs_type: &DataType) -> bool { use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { - (_, Null) => is_numeric(lhs_type), - (Null, _) => is_numeric(rhs_type), + (_, Null) => lhs_type.is_numeric(), + (Null, _) => rhs_type.is_numeric(), (Dictionary(_, lhs_value_type), Dictionary(_, rhs_value_type)) => { - is_numeric(lhs_value_type) && is_numeric(rhs_value_type) + lhs_value_type.is_numeric() && rhs_value_type.is_numeric() + } + (Dictionary(_, value_type), _) => { + value_type.is_numeric() && rhs_type.is_numeric() } - (Dictionary(_, value_type), _) => is_numeric(value_type) && is_numeric(rhs_type), - (_, Dictionary(_, value_type)) => is_numeric(lhs_type) && is_numeric(value_type), - _ => is_numeric(lhs_type) && is_numeric(rhs_type), + (_, Dictionary(_, value_type)) => { + lhs_type.is_numeric() && value_type.is_numeric() + } + _ => lhs_type.is_numeric() && rhs_type.is_numeric(), } } @@ -619,8 +667,6 @@ fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { string_concat_internal_coercion(from_type, &LargeUtf8) } - // TODO: cast between array elements (#6558) - (List(_), from_type) | (from_type, List(_)) => Some(from_type.to_owned()), _ => None, }) } @@ -645,8 +691,9 @@ fn string_concat_internal_coercion( } } -/// Coercion rules for Strings: the type that both lhs and rhs can be -/// casted to for the purpose of a string computation +/// Coercion rules for string types (Utf8/LargeUtf8): If at least one argument is +/// a string type and both arguments can be coerced into a string type, coerce +/// to string type. fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { @@ -662,8 +709,30 @@ fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option } } -/// Coercion rules for Binaries: the type that both lhs and rhs can be -/// casted to for the purpose of a computation +/// Coercion rules for binary (Binary/LargeBinary) to string (Utf8/LargeUtf8): +/// If one argument is binary and the other is a string then coerce to string +/// (e.g. for `like`) +fn binary_to_string_coercion( + lhs_type: &DataType, + rhs_type: &DataType, +) -> Option { + use arrow::datatypes::DataType::*; + match (lhs_type, rhs_type) { + (Binary, Utf8) => Some(Utf8), + (Binary, LargeUtf8) => Some(LargeUtf8), + (LargeBinary, Utf8) => Some(LargeUtf8), + (LargeBinary, LargeUtf8) => Some(LargeUtf8), + (Utf8, Binary) => Some(Utf8), + (Utf8, LargeBinary) => Some(LargeUtf8), + (LargeUtf8, Binary) => Some(LargeUtf8), + (LargeUtf8, LargeBinary) => Some(LargeUtf8), + _ => None, + } +} + +/// Coercion rules for binary types (Binary/LargeBinary): If at least one argument is +/// a binary type and both arguments can be coerced into a binary type, coerce +/// to binary type. fn binary_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { @@ -678,6 +747,7 @@ fn binary_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option /// This is a union of string coercion rules and dictionary coercion rules pub fn like_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { string_coercion(lhs_type, rhs_type) + .or_else(|| binary_to_string_coercion(lhs_type, rhs_type)) .or_else(|| dictionary_coercion(lhs_type, rhs_type, false)) .or_else(|| null_coercion(lhs_type, rhs_type)) } @@ -711,9 +781,14 @@ fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Some(Interval(MonthDayNano)), (Date64, Date32) | (Date32, Date64) => Some(Date64), + (Timestamp(_, None), Date64) | (Date64, Timestamp(_, None)) => { + Some(Timestamp(Nanosecond, None)) + } + (Timestamp(_, _tz), Date64) | (Date64, Timestamp(_, _tz)) => { + Some(Timestamp(Nanosecond, None)) + } (Timestamp(_, None), Date32) | (Date32, Timestamp(_, None)) => { Some(Timestamp(Nanosecond, None)) } @@ -760,7 +835,7 @@ fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option { match (lhs_type, rhs_type) { @@ -777,14 +852,11 @@ fn null_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { #[cfg(test)] mod tests { - use arrow::datatypes::DataType; - - use datafusion_common::assert_contains; - use datafusion_common::Result; - + use super::*; use crate::Operator; - use super::*; + use arrow::datatypes::DataType; + use datafusion_common::{assert_contains, Result}; #[test] fn test_coercion_error() -> Result<()> { @@ -914,11 +986,33 @@ mod tests { ); } + /// Test coercion rules for binary operators + /// + /// Applies coercion rules for `$LHS_TYPE $OP $RHS_TYPE` and asserts that the + /// the result type is `$RESULT_TYPE` macro_rules! test_coercion_binary_rule { - ($A_TYPE:expr, $B_TYPE:expr, $OP:expr, $C_TYPE:expr) => {{ - let (lhs, rhs) = get_input_types(&$A_TYPE, &$OP, &$B_TYPE)?; - assert_eq!(lhs, $C_TYPE); - assert_eq!(rhs, $C_TYPE); + ($LHS_TYPE:expr, $RHS_TYPE:expr, $OP:expr, $RESULT_TYPE:expr) => {{ + let (lhs, rhs) = get_input_types(&$LHS_TYPE, &$OP, &$RHS_TYPE)?; + assert_eq!(lhs, $RESULT_TYPE); + assert_eq!(rhs, $RESULT_TYPE); + }}; + } + + /// Test coercion rules for like + /// + /// Applies coercion rules for both + /// * `$LHS_TYPE LIKE $RHS_TYPE` + /// * `$RHS_TYPE LIKE $LHS_TYPE` + /// + /// And asserts the result type is `$RESULT_TYPE` + macro_rules! test_like_rule { + ($LHS_TYPE:expr, $RHS_TYPE:expr, $RESULT_TYPE:expr) => {{ + println!("Coercing {} LIKE {}", $LHS_TYPE, $RHS_TYPE); + let result = like_coercion(&$LHS_TYPE, &$RHS_TYPE); + assert_eq!(result, $RESULT_TYPE); + // reverse the order + let result = like_coercion(&$RHS_TYPE, &$LHS_TYPE); + assert_eq!(result, $RESULT_TYPE); }}; } @@ -945,11 +1039,46 @@ mod tests { } #[test] - fn test_type_coercion() -> Result<()> { - // test like coercion rule - let result = like_coercion(&DataType::Utf8, &DataType::Utf8); - assert_eq!(result, Some(DataType::Utf8)); + fn test_like_coercion() { + // string coerce to strings + test_like_rule!(DataType::Utf8, DataType::Utf8, Some(DataType::Utf8)); + test_like_rule!( + DataType::LargeUtf8, + DataType::Utf8, + Some(DataType::LargeUtf8) + ); + test_like_rule!( + DataType::Utf8, + DataType::LargeUtf8, + Some(DataType::LargeUtf8) + ); + test_like_rule!( + DataType::LargeUtf8, + DataType::LargeUtf8, + Some(DataType::LargeUtf8) + ); + // Also coerce binary to strings + test_like_rule!(DataType::Binary, DataType::Utf8, Some(DataType::Utf8)); + test_like_rule!( + DataType::LargeBinary, + DataType::Utf8, + Some(DataType::LargeUtf8) + ); + test_like_rule!( + DataType::Binary, + DataType::LargeUtf8, + Some(DataType::LargeUtf8) + ); + test_like_rule!( + DataType::LargeBinary, + DataType::LargeUtf8, + Some(DataType::LargeUtf8) + ); + } + + #[test] + fn test_type_coercion() -> Result<()> { test_coercion_binary_rule!( DataType::Utf8, DataType::Date32, diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 883ca2b39362..63908d539bd0 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -15,12 +15,16 @@ // specific language governing permissions and limitations // under the License. +use crate::signature::TIMEZONE_WILDCARD; use crate::{Signature, TypeSignature}; use arrow::{ compute::can_cast_types, datatypes::{DataType, TimeUnit}, }; -use datafusion_common::{plan_err, DataFusionError, Result}; +use datafusion_common::utils::list_ndims; +use datafusion_common::{internal_err, plan_err, DataFusionError, Result}; + +use super::binary::comparison_coercion; /// Performs type coercion for function arguments. /// @@ -34,8 +38,17 @@ pub fn data_types( signature: &Signature, ) -> Result> { if current_types.is_empty() { - return Ok(vec![]); + if signature.type_signature.supports_zero_argument() { + return Ok(vec![]); + } else { + return plan_err!( + "Coercion from {:?} to the signature {:?} failed.", + current_types, + &signature.type_signature + ); + } } + let valid_types = get_valid_types(&signature.type_signature, current_types)?; if valid_types @@ -45,6 +58,8 @@ pub fn data_types( return Ok(current_types.to_vec()); } + // Try and coerce the argument types to match the signature, returning the + // coerced types from the first matching signature. for valid_types in valid_types { if let Some(types) = maybe_data_types(&valid_types, current_types) { return Ok(types); @@ -59,10 +74,60 @@ pub fn data_types( ) } +/// Returns a Vec of all possible valid argument types for the given signature. fn get_valid_types( signature: &TypeSignature, current_types: &[DataType], ) -> Result>> { + fn array_append_or_prepend_valid_types( + current_types: &[DataType], + is_append: bool, + ) -> Result>> { + if current_types.len() != 2 { + return Ok(vec![vec![]]); + } + + let (array_type, elem_type) = if is_append { + (¤t_types[0], ¤t_types[1]) + } else { + (¤t_types[1], ¤t_types[0]) + }; + + // We follow Postgres on `array_append(Null, T)`, which is not valid. + if array_type.eq(&DataType::Null) { + return Ok(vec![vec![]]); + } + + // We need to find the coerced base type, mainly for cases like: + // `array_append(List(null), i64)` -> `List(i64)` + let array_base_type = datafusion_common::utils::base_type(array_type); + let elem_base_type = datafusion_common::utils::base_type(elem_type); + let new_base_type = comparison_coercion(&array_base_type, &elem_base_type); + + if new_base_type.is_none() { + return internal_err!( + "Coercion from {array_base_type:?} to {elem_base_type:?} not supported." + ); + } + let new_base_type = new_base_type.unwrap(); + + let array_type = datafusion_common::utils::coerced_type_with_base_type_only( + array_type, + &new_base_type, + ); + + match array_type { + DataType::List(ref field) | DataType::LargeList(ref field) => { + let elem_type = field.data_type(); + if is_append { + Ok(vec![vec![array_type.clone(), elem_type.to_owned()]]) + } else { + Ok(vec![vec![elem_type.to_owned(), array_type.clone()]]) + } + } + _ => Ok(vec![vec![]]), + } + } let valid_types = match signature { TypeSignature::Variadic(valid_types) => valid_types .iter() @@ -73,16 +138,34 @@ fn get_valid_types( .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect()) .collect(), TypeSignature::VariadicEqual => { - // one entry with the same len as current_types, whose type is `current_types[0]`. - vec![current_types - .iter() - .map(|_| current_types[0].clone()) - .collect()] + let new_type = current_types.iter().skip(1).try_fold( + current_types.first().unwrap().clone(), + |acc, x| { + let coerced_type = comparison_coercion(&acc, x); + if let Some(coerced_type) = coerced_type { + Ok(coerced_type) + } else { + internal_err!("Coercion from {acc:?} to {x:?} failed.") + } + }, + ); + + match new_type { + Ok(new_type) => vec![vec![new_type; current_types.len()]], + Err(e) => return Err(e), + } } TypeSignature::VariadicAny => { vec![current_types.to_vec()] } + TypeSignature::Exact(valid_types) => vec![valid_types.clone()], + TypeSignature::ArrayAndElement => { + return array_append_or_prepend_valid_types(current_types, true) + } + TypeSignature::ElementAndArray => { + return array_append_or_prepend_valid_types(current_types, false) + } TypeSignature::Any(number) => { if current_types.len() != *number { return plan_err!( @@ -103,7 +186,12 @@ fn get_valid_types( Ok(valid_types) } -/// Try to coerce current_types into valid_types. +/// Try to coerce the current argument types to match the given `valid_types`. +/// +/// For example, if a function `func` accepts arguments of `(int64, int64)`, +/// but was called with `(int32, int64)`, this function could match the +/// valid_types by coercing the first argument to `int64`, and would return +/// `Some([int64, int64])`. fn maybe_data_types( valid_types: &[DataType], current_types: &[DataType], @@ -121,7 +209,7 @@ fn maybe_data_types( } else { // attempt to coerce if let Some(valid_type) = coerced_from(valid_type, current_type) { - new_type.push(valid_type.clone()) + new_type.push(valid_type) } else { // not possible return None; @@ -140,7 +228,7 @@ pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool { return true; } if let Some(coerced) = coerced_from(type_into, type_from) { - return coerced == type_into; + return coerced == *type_into; } false } @@ -148,15 +236,17 @@ pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool { fn coerced_from<'a>( type_into: &'a DataType, type_from: &'a DataType, -) -> Option<&'a DataType> { +) -> Option { use self::DataType::*; match type_into { // coerced into type_into - Int8 if matches!(type_from, Null | Int8) => Some(type_into), - Int16 if matches!(type_from, Null | Int8 | Int16 | UInt8) => Some(type_into), + Int8 if matches!(type_from, Null | Int8) => Some(type_into.clone()), + Int16 if matches!(type_from, Null | Int8 | Int16 | UInt8) => { + Some(type_into.clone()) + } Int32 if matches!(type_from, Null | Int8 | Int16 | Int32 | UInt8 | UInt16) => { - Some(type_into) + Some(type_into.clone()) } Int64 if matches!( @@ -164,13 +254,15 @@ fn coerced_from<'a>( Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 ) => { - Some(type_into) + Some(type_into.clone()) + } + UInt8 if matches!(type_from, Null | UInt8) => Some(type_into.clone()), + UInt16 if matches!(type_from, Null | UInt8 | UInt16) => Some(type_into.clone()), + UInt32 if matches!(type_from, Null | UInt8 | UInt16 | UInt32) => { + Some(type_into.clone()) } - UInt8 if matches!(type_from, Null | UInt8) => Some(type_into), - UInt16 if matches!(type_from, Null | UInt8 | UInt16) => Some(type_into), - UInt32 if matches!(type_from, Null | UInt8 | UInt16 | UInt32) => Some(type_into), UInt64 if matches!(type_from, Null | UInt8 | UInt16 | UInt32 | UInt64) => { - Some(type_into) + Some(type_into.clone()) } Float32 if matches!( @@ -186,7 +278,7 @@ fn coerced_from<'a>( | Float32 ) => { - Some(type_into) + Some(type_into.clone()) } Float64 if matches!( @@ -204,7 +296,7 @@ fn coerced_from<'a>( | Decimal128(_, _) ) => { - Some(type_into) + Some(type_into.clone()) } Timestamp(TimeUnit::Nanosecond, None) if matches!( @@ -212,17 +304,41 @@ fn coerced_from<'a>( Null | Timestamp(_, None) | Date32 | Utf8 | LargeUtf8 ) => { - Some(type_into) + Some(type_into.clone()) } - Interval(_) if matches!(type_from, Utf8 | LargeUtf8) => Some(type_into), - Utf8 | LargeUtf8 => Some(type_into), - Null if can_cast_types(type_from, type_into) => Some(type_into), + Interval(_) if matches!(type_from, Utf8 | LargeUtf8) => Some(type_into.clone()), + // Any type can be coerced into strings + Utf8 | LargeUtf8 => Some(type_into.clone()), + Null if can_cast_types(type_from, type_into) => Some(type_into.clone()), - // Coerce to consistent timezones, if the `type_from` timezone exists. - Timestamp(TimeUnit::Nanosecond, Some(_)) - if matches!(type_from, Timestamp(TimeUnit::Nanosecond, Some(_))) => + // Only accept list and largelist with the same number of dimensions unless the type is Null. + // List or LargeList with different dimensions should be handled in TypeSignature or other places before this. + List(_) | LargeList(_) + if datafusion_common::utils::base_type(type_from).eq(&Null) + || list_ndims(type_from) == list_ndims(type_into) => { - Some(type_from) + Some(type_into.clone()) + } + + Timestamp(unit, Some(tz)) if tz.as_ref() == TIMEZONE_WILDCARD => { + match type_from { + Timestamp(_, Some(from_tz)) => { + Some(Timestamp(unit.clone(), Some(from_tz.clone()))) + } + Null | Date32 | Utf8 | LargeUtf8 | Timestamp(_, None) => { + // In the absence of any other information assume the time zone is "+00" (UTC). + Some(Timestamp(unit.clone(), Some("+00".into()))) + } + _ => None, + } + } + Timestamp(_, Some(_)) + if matches!( + type_from, + Null | Timestamp(_, _) | Date32 | Utf8 | LargeUtf8 + ) => + { + Some(type_into.clone()) } // cannot coerce @@ -233,7 +349,7 @@ fn coerced_from<'a>( #[cfg(test)] mod tests { use super::*; - use arrow::datatypes::DataType; + use arrow::datatypes::{DataType, TimeUnit}; #[test] fn test_maybe_data_types() { @@ -265,6 +381,20 @@ mod tests { vec![DataType::Boolean, DataType::UInt16], Some(vec![DataType::Boolean, DataType::UInt32]), ), + // UTF8 -> Timestamp + ( + vec![ + DataType::Timestamp(TimeUnit::Nanosecond, None), + DataType::Timestamp(TimeUnit::Nanosecond, Some("+TZ".into())), + DataType::Timestamp(TimeUnit::Nanosecond, Some("+01".into())), + ], + vec![DataType::Utf8, DataType::Utf8, DataType::Utf8], + Some(vec![ + DataType::Timestamp(TimeUnit::Nanosecond, None), + DataType::Timestamp(TimeUnit::Nanosecond, Some("+00".into())), + DataType::Timestamp(TimeUnit::Nanosecond, Some("+01".into())), + ]), + ), ]; for case in cases { diff --git a/datafusion/expr/src/type_coercion/mod.rs b/datafusion/expr/src/type_coercion/mod.rs index d72d9c50edd2..86005da3dafa 100644 --- a/datafusion/expr/src/type_coercion/mod.rs +++ b/datafusion/expr/src/type_coercion/mod.rs @@ -58,15 +58,6 @@ pub fn is_null(dt: &DataType) -> bool { *dt == DataType::Null } -/// Determine whether the given data type `dt` represents numeric values. -pub fn is_numeric(dt: &DataType) -> bool { - is_signed_numeric(dt) - || matches!( - dt, - DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 - ) -} - /// Determine whether the given data type `dt` is a `Timestamp`. pub fn is_timestamp(dt: &DataType) -> bool { matches!(dt, DataType::Timestamp(_, _)) diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 84e238a1215b..cfbca4ab1337 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -15,12 +15,14 @@ // specific language governing permissions and limitations // under the License. -//! Udaf module contains functions and structs supporting user-defined aggregate functions. +//! [`AggregateUDF`]: User Defined Aggregate Functions -use crate::Expr; +use crate::{Accumulator, Expr}; use crate::{ AccumulatorFactoryFunction, ReturnTypeFunction, Signature, StateTypeFunction, }; +use arrow::datatypes::DataType; +use datafusion_common::Result; use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; @@ -46,15 +48,15 @@ use std::sync::Arc; #[derive(Clone)] pub struct AggregateUDF { /// name - pub name: String, + name: String, /// Signature (input arguments) - pub signature: Signature, + signature: Signature, /// Return type - pub return_type: ReturnTypeFunction, + return_type: ReturnTypeFunction, /// actual implementation - pub accumulator: AccumulatorFactoryFunction, + accumulator: AccumulatorFactoryFunction, /// the accumulator's state's description as a function of the return type - pub state_type: StateTypeFunction, + state_type: StateTypeFunction, } impl Debug for AggregateUDF { @@ -105,11 +107,43 @@ impl AggregateUDF { /// This utility allows using the UDAF without requiring access to /// the registry, such as with the DataFrame API. pub fn call(&self, args: Vec) -> Expr { - Expr::AggregateUDF(crate::expr::AggregateUDF { - fun: Arc::new(self.clone()), + Expr::AggregateFunction(crate::expr::AggregateFunction::new_udf( + Arc::new(self.clone()), args, - filter: None, - order_by: None, - }) + false, + None, + None, + )) + } + + /// Returns this function's name + pub fn name(&self) -> &str { + &self.name + } + + /// Returns this function's signature (what input types are accepted) + pub fn signature(&self) -> &Signature { + &self.signature + } + + /// Return the type of the function given its input types + pub fn return_type(&self, args: &[DataType]) -> Result { + // Old API returns an Arc of the datatype for some reason + let res = (self.return_type)(args)?; + Ok(res.as_ref().clone()) + } + + /// Return an accumualator the given aggregate, given + /// its return datatype. + pub fn accumulator(&self, return_type: &DataType) -> Result> { + (self.accumulator)(return_type) + } + + /// Return the type of the intermediate state used by this aggregator, given + /// its return datatype. Supports multi-phase aggregations + pub fn state_type(&self, return_type: &DataType) -> Result> { + // old API returns an Arc for some reason, try and unwrap it here + let res = (self.state_type)(return_type)?; + Ok(Arc::try_unwrap(res).unwrap_or_else(|res| res.as_ref().clone())) } } diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index be6c90aa5985..8b35d5834c61 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -15,47 +15,46 @@ // specific language governing permissions and limitations // under the License. -//! Udf module contains foundational types that are used to represent UDFs in DataFusion. +//! [`ScalarUDF`]: Scalar User Defined Functions -use crate::{Expr, ReturnTypeFunction, ScalarFunctionImplementation, Signature}; +use crate::{ + ColumnarValue, Expr, ReturnTypeFunction, ScalarFunctionImplementation, Signature, +}; +use arrow::datatypes::DataType; +use datafusion_common::Result; +use std::any::Any; use std::fmt; use std::fmt::Debug; use std::fmt::Formatter; use std::sync::Arc; -/// Logical representation of a UDF. -#[derive(Clone)] +/// Logical representation of a Scalar User Defined Function. +/// +/// A scalar function produces a single row output for each row of input. This +/// struct contains the information DataFusion needs to plan and invoke +/// functions you supply such name, type signature, return type, and actual +/// implementation. +/// +/// 1. For simple (less performant) use cases, use [`create_udf`] and [`simple_udf.rs`]. +/// +/// 2. For advanced use cases, use [`ScalarUDFImpl`] and [`advanced_udf.rs`]. +/// +/// # API Note +/// +/// This is a separate struct from `ScalarUDFImpl` to maintain backwards +/// compatibility with the older API. +/// +/// [`create_udf`]: crate::expr_fn::create_udf +/// [`simple_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udf.rs +/// [`advanced_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs +#[derive(Debug, Clone)] pub struct ScalarUDF { - /// name - pub name: String, - /// signature - pub signature: Signature, - /// Return type - pub return_type: ReturnTypeFunction, - /// actual implementation - /// - /// The fn param is the wrapped function but be aware that the function will - /// be passed with the slice / vec of columnar values (either scalar or array) - /// with the exception of zero param function, where a singular element vec - /// will be passed. In that case the single element is a null array to indicate - /// the batch's row count (so that the generative zero-argument function can know - /// the result array size). - pub fun: ScalarFunctionImplementation, -} - -impl Debug for ScalarUDF { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - f.debug_struct("ScalarUDF") - .field("name", &self.name) - .field("signature", &self.signature) - .field("fun", &"") - .finish() - } + inner: Arc, } impl PartialEq for ScalarUDF { fn eq(&self, other: &Self) -> bool { - self.name == other.name && self.signature == other.signature + self.name() == other.name() && self.signature() == other.signature() } } @@ -63,30 +62,316 @@ impl Eq for ScalarUDF {} impl std::hash::Hash for ScalarUDF { fn hash(&self, state: &mut H) { - self.name.hash(state); - self.signature.hash(state); + self.name().hash(state); + self.signature().hash(state); } } impl ScalarUDF { - /// Create a new ScalarUDF + /// Create a new ScalarUDF from low level details. + /// + /// See [`ScalarUDFImpl`] for a more convenient way to create a + /// `ScalarUDF` using trait objects + #[deprecated(since = "34.0.0", note = "please implement ScalarUDFImpl instead")] pub fn new( name: &str, signature: &Signature, return_type: &ReturnTypeFunction, fun: &ScalarFunctionImplementation, ) -> Self { - Self { + Self::new_from_impl(ScalarUdfLegacyWrapper { name: name.to_owned(), signature: signature.clone(), return_type: return_type.clone(), fun: fun.clone(), + }) + } + + /// Create a new `ScalarUDF` from a `[ScalarUDFImpl]` trait object + /// + /// Note this is the same as using the `From` impl (`ScalarUDF::from`) + pub fn new_from_impl(fun: F) -> ScalarUDF + where + F: ScalarUDFImpl + 'static, + { + Self { + inner: Arc::new(fun), } } - /// creates a logical expression with a call of the UDF + /// Return the underlying [`ScalarUDFImpl`] trait object for this function + pub fn inner(&self) -> Arc { + self.inner.clone() + } + + /// Adds additional names that can be used to invoke this function, in + /// addition to `name` + /// + /// If you implement [`ScalarUDFImpl`] directly you should return aliases directly. + pub fn with_aliases(self, aliases: impl IntoIterator) -> Self { + Self::new_from_impl(AliasedScalarUDFImpl::new(self, aliases)) + } + + /// Returns a [`Expr`] logical expression to call this UDF with specified + /// arguments. + /// /// This utility allows using the UDF without requiring access to the registry. pub fn call(&self, args: Vec) -> Expr { - Expr::ScalarUDF(crate::expr::ScalarUDF::new(Arc::new(self.clone()), args)) + Expr::ScalarFunction(crate::expr::ScalarFunction::new_udf( + Arc::new(self.clone()), + args, + )) + } + + /// Returns this function's name. + /// + /// See [`ScalarUDFImpl::name`] for more details. + pub fn name(&self) -> &str { + self.inner.name() + } + + /// Returns the aliases for this function. + /// + /// See [`ScalarUDF::with_aliases`] for more details + pub fn aliases(&self) -> &[String] { + self.inner.aliases() + } + + /// Returns this function's [`Signature`] (what input types are accepted). + /// + /// See [`ScalarUDFImpl::signature`] for more details. + pub fn signature(&self) -> &Signature { + self.inner.signature() + } + + /// The datatype this function returns given the input argument input types. + /// + /// See [`ScalarUDFImpl::return_type`] for more details. + pub fn return_type(&self, args: &[DataType]) -> Result { + self.inner.return_type(args) + } + + /// Invoke the function on `args`, returning the appropriate result. + /// + /// See [`ScalarUDFImpl::invoke`] for more details. + pub fn invoke(&self, args: &[ColumnarValue]) -> Result { + self.inner.invoke(args) + } + + /// Returns a `ScalarFunctionImplementation` that can invoke the function + /// during execution + pub fn fun(&self) -> ScalarFunctionImplementation { + let captured = self.inner.clone(); + Arc::new(move |args| captured.invoke(args)) + } +} + +impl From for ScalarUDF +where + F: ScalarUDFImpl + Send + Sync + 'static, +{ + fn from(fun: F) -> Self { + Self::new_from_impl(fun) + } +} + +/// Trait for implementing [`ScalarUDF`]. +/// +/// This trait exposes the full API for implementing user defined functions and +/// can be used to implement any function. +/// +/// See [`advanced_udf.rs`] for a full example with complete implementation and +/// [`ScalarUDF`] for other available options. +/// +/// +/// [`advanced_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs +/// # Basic Example +/// ``` +/// # use std::any::Any; +/// # use arrow::datatypes::DataType; +/// # use datafusion_common::{DataFusionError, plan_err, Result}; +/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility}; +/// # use datafusion_expr::{ScalarUDFImpl, ScalarUDF}; +/// #[derive(Debug)] +/// struct AddOne { +/// signature: Signature +/// }; +/// +/// impl AddOne { +/// fn new() -> Self { +/// Self { +/// signature: Signature::uniform(1, vec![DataType::Int32], Volatility::Immutable) +/// } +/// } +/// } +/// +/// /// Implement the ScalarUDFImpl trait for AddOne +/// impl ScalarUDFImpl for AddOne { +/// fn as_any(&self) -> &dyn Any { self } +/// fn name(&self) -> &str { "add_one" } +/// fn signature(&self) -> &Signature { &self.signature } +/// fn return_type(&self, args: &[DataType]) -> Result { +/// if !matches!(args.get(0), Some(&DataType::Int32)) { +/// return plan_err!("add_one only accepts Int32 arguments"); +/// } +/// Ok(DataType::Int32) +/// } +/// // The actual implementation would add one to the argument +/// fn invoke(&self, args: &[ColumnarValue]) -> Result { unimplemented!() } +/// } +/// +/// // Create a new ScalarUDF from the implementation +/// let add_one = ScalarUDF::from(AddOne::new()); +/// +/// // Call the function `add_one(col)` +/// let expr = add_one.call(vec![col("a")]); +/// ``` +pub trait ScalarUDFImpl: Debug + Send + Sync { + /// Returns this object as an [`Any`] trait object + fn as_any(&self) -> &dyn Any; + + /// Returns this function's name + fn name(&self) -> &str; + + /// Returns the function's [`Signature`] for information about what input + /// types are accepted and the function's Volatility. + fn signature(&self) -> &Signature; + + /// What [`DataType`] will be returned by this function, given the types of + /// the arguments + fn return_type(&self, arg_types: &[DataType]) -> Result; + + /// Invoke the function on `args`, returning the appropriate result + /// + /// The function will be invoked passed with the slice of [`ColumnarValue`] + /// (either scalar or array). + /// + /// # Zero Argument Functions + /// If the function has zero parameters (e.g. `now()`) it will be passed a + /// single element slice which is a a null array to indicate the batch's row + /// count (so the function can know the resulting array size). + /// + /// # Performance + /// + /// For the best performance, the implementations of `invoke` should handle + /// the common case when one or more of their arguments are constant values + /// (aka [`ColumnarValue::Scalar`]). Calling [`ColumnarValue::into_array`] + /// and treating all arguments as arrays will work, but will be slower. + fn invoke(&self, args: &[ColumnarValue]) -> Result; + + /// Returns any aliases (alternate names) for this function. + /// + /// Aliases can be used to invoke the same function using different names. + /// For example in some databases `now()` and `current_timestamp()` are + /// aliases for the same function. This behavior can be obtained by + /// returning `current_timestamp` as an alias for the `now` function. + /// + /// Note: `aliases` should only include names other than [`Self::name`]. + /// Defaults to `[]` (no aliases) + fn aliases(&self) -> &[String] { + &[] + } +} + +/// ScalarUDF that adds an alias to the underlying function. It is better to +/// implement [`ScalarUDFImpl`], which supports aliases, directly if possible. +#[derive(Debug)] +struct AliasedScalarUDFImpl { + inner: ScalarUDF, + aliases: Vec, +} + +impl AliasedScalarUDFImpl { + pub fn new( + inner: ScalarUDF, + new_aliases: impl IntoIterator, + ) -> Self { + let mut aliases = inner.aliases().to_vec(); + aliases.extend(new_aliases.into_iter().map(|s| s.to_string())); + + Self { inner, aliases } + } +} + +impl ScalarUDFImpl for AliasedScalarUDFImpl { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + self.inner.name() + } + + fn signature(&self) -> &Signature { + self.inner.signature() + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + self.inner.return_type(arg_types) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + self.inner.invoke(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Implementation of [`ScalarUDFImpl`] that wraps the function style pointers +/// of the older API (see +/// for more details) +struct ScalarUdfLegacyWrapper { + /// The name of the function + name: String, + /// The signature (the types of arguments that are supported) + signature: Signature, + /// Function that returns the return type given the argument types + return_type: ReturnTypeFunction, + /// actual implementation + /// + /// The fn param is the wrapped function but be aware that the function will + /// be passed with the slice / vec of columnar values (either scalar or array) + /// with the exception of zero param function, where a singular element vec + /// will be passed. In that case the single element is a null array to indicate + /// the batch's row count (so that the generative zero-argument function can know + /// the result array size). + fun: ScalarFunctionImplementation, +} + +impl Debug for ScalarUdfLegacyWrapper { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("ScalarUDF") + .field("name", &self.name) + .field("signature", &self.signature) + .field("fun", &"") + .finish() + } +} + +impl ScalarUDFImpl for ScalarUdfLegacyWrapper { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + // Old API returns an Arc of the datatype for some reason + let res = (self.return_type)(arg_types)?; + Ok(res.as_ref().clone()) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + (self.fun)(args) + } + + fn aliases(&self) -> &[String] { + &[] } } diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index c0a2a8205a08..239a5e24cbf2 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -15,56 +15,52 @@ // specific language governing permissions and limitations // under the License. -//! Support for user-defined window (UDWF) window functions +//! [`WindowUDF`]: User Defined Window Functions +use crate::{ + Expr, PartitionEvaluator, PartitionEvaluatorFactory, ReturnTypeFunction, Signature, + WindowFrame, +}; +use arrow::datatypes::DataType; +use datafusion_common::Result; use std::{ + any::Any, fmt::{self, Debug, Display, Formatter}, sync::Arc, }; -use crate::{ - Expr, PartitionEvaluatorFactory, ReturnTypeFunction, Signature, WindowFrame, -}; - /// Logical representation of a user-defined window function (UDWF) /// A UDWF is different from a UDF in that it is stateful across batches. /// /// See the documetnation on [`PartitionEvaluator`] for more details /// +/// 1. For simple (less performant) use cases, use [`create_udwf`] and [`simple_udwf.rs`]. +/// +/// 2. For advanced use cases, use [`WindowUDFImpl`] and [`advanced_udf.rs`]. +/// +/// # API Note +/// This is a separate struct from `WindowUDFImpl` to maintain backwards +/// compatibility with the older API. +/// /// [`PartitionEvaluator`]: crate::PartitionEvaluator -#[derive(Clone)] +/// [`create_udwf`]: crate::expr_fn::create_udwf +/// [`simple_udwf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udwf.rs +/// [`advanced_udwf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udwf.rs +#[derive(Debug, Clone)] pub struct WindowUDF { - /// name - pub name: String, - /// signature - pub signature: Signature, - /// Return type - pub return_type: ReturnTypeFunction, - /// Return the partition evaluator - pub partition_evaluator_factory: PartitionEvaluatorFactory, -} - -impl Debug for WindowUDF { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - f.debug_struct("WindowUDF") - .field("name", &self.name) - .field("signature", &self.signature) - .field("return_type", &"") - .field("partition_evaluator_factory", &"") - .finish_non_exhaustive() - } + inner: Arc, } /// Defines how the WindowUDF is shown to users impl Display for WindowUDF { fn fmt(&self, f: &mut Formatter) -> fmt::Result { - write!(f, "{}", self.name) + write!(f, "{}", self.name()) } } impl PartialEq for WindowUDF { fn eq(&self, other: &Self) -> bool { - self.name == other.name && self.signature == other.signature + self.name() == other.name() && self.signature() == other.signature() } } @@ -72,27 +68,48 @@ impl Eq for WindowUDF {} impl std::hash::Hash for WindowUDF { fn hash(&self, state: &mut H) { - self.name.hash(state); - self.signature.hash(state); + self.name().hash(state); + self.signature().hash(state); } } impl WindowUDF { - /// Create a new WindowUDF + /// Create a new WindowUDF from low level details. + /// + /// See [`WindowUDFImpl`] for a more convenient way to create a + /// `WindowUDF` using trait objects + #[deprecated(since = "34.0.0", note = "please implement ScalarUDFImpl instead")] pub fn new( name: &str, signature: &Signature, return_type: &ReturnTypeFunction, partition_evaluator_factory: &PartitionEvaluatorFactory, ) -> Self { - Self { + Self::new_from_impl(WindowUDFLegacyWrapper { name: name.to_owned(), signature: signature.clone(), return_type: return_type.clone(), partition_evaluator_factory: partition_evaluator_factory.clone(), + }) + } + + /// Create a new `WindowUDF` from a `[WindowUDFImpl]` trait object + /// + /// Note this is the same as using the `From` impl (`WindowUDF::from`) + pub fn new_from_impl(fun: F) -> WindowUDF + where + F: WindowUDFImpl + 'static, + { + Self { + inner: Arc::new(fun), } } + /// Return the underlying [`WindowUDFImpl`] trait object for this function + pub fn inner(&self) -> Arc { + self.inner.clone() + } + /// creates a [`Expr`] that calls the window function given /// the `partition_by`, `order_by`, and `window_frame` definition /// @@ -105,7 +122,7 @@ impl WindowUDF { order_by: Vec, window_frame: WindowFrame, ) -> Expr { - let fun = crate::WindowFunction::WindowUDF(Arc::new(self.clone())); + let fun = crate::WindowFunctionDefinition::WindowUDF(Arc::new(self.clone())); Expr::WindowFunction(crate::expr::WindowFunction { fun, @@ -115,4 +132,163 @@ impl WindowUDF { window_frame, }) } + + /// Returns this function's name + /// + /// See [`WindowUDFImpl::name`] for more details. + pub fn name(&self) -> &str { + self.inner.name() + } + + /// Returns this function's signature (what input types are accepted) + /// + /// See [`WindowUDFImpl::signature`] for more details. + pub fn signature(&self) -> &Signature { + self.inner.signature() + } + + /// Return the type of the function given its input types + /// + /// See [`WindowUDFImpl::return_type`] for more details. + pub fn return_type(&self, args: &[DataType]) -> Result { + self.inner.return_type(args) + } + + /// Return a `PartitionEvaluator` for evaluating this window function + pub fn partition_evaluator_factory(&self) -> Result> { + self.inner.partition_evaluator() + } +} + +impl From for WindowUDF +where + F: WindowUDFImpl + Send + Sync + 'static, +{ + fn from(fun: F) -> Self { + Self::new_from_impl(fun) + } +} + +/// Trait for implementing [`WindowUDF`]. +/// +/// This trait exposes the full API for implementing user defined window functions and +/// can be used to implement any function. +/// +/// See [`advanced_udwf.rs`] for a full example with complete implementation and +/// [`WindowUDF`] for other available options. +/// +/// +/// [`advanced_udwf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udwf.rs +/// # Basic Example +/// ``` +/// # use std::any::Any; +/// # use arrow::datatypes::DataType; +/// # use datafusion_common::{DataFusionError, plan_err, Result}; +/// # use datafusion_expr::{col, Signature, Volatility, PartitionEvaluator, WindowFrame}; +/// # use datafusion_expr::{WindowUDFImpl, WindowUDF}; +/// #[derive(Debug, Clone)] +/// struct SmoothIt { +/// signature: Signature +/// }; +/// +/// impl SmoothIt { +/// fn new() -> Self { +/// Self { +/// signature: Signature::uniform(1, vec![DataType::Int32], Volatility::Immutable) +/// } +/// } +/// } +/// +/// /// Implement the WindowUDFImpl trait for AddOne +/// impl WindowUDFImpl for SmoothIt { +/// fn as_any(&self) -> &dyn Any { self } +/// fn name(&self) -> &str { "smooth_it" } +/// fn signature(&self) -> &Signature { &self.signature } +/// fn return_type(&self, args: &[DataType]) -> Result { +/// if !matches!(args.get(0), Some(&DataType::Int32)) { +/// return plan_err!("smooth_it only accepts Int32 arguments"); +/// } +/// Ok(DataType::Int32) +/// } +/// // The actual implementation would add one to the argument +/// fn partition_evaluator(&self) -> Result> { unimplemented!() } +/// } +/// +/// // Create a new ScalarUDF from the implementation +/// let smooth_it = WindowUDF::from(SmoothIt::new()); +/// +/// // Call the function `add_one(col)` +/// let expr = smooth_it.call( +/// vec![col("speed")], // smooth_it(speed) +/// vec![col("car")], // PARTITION BY car +/// vec![col("time").sort(true, true)], // ORDER BY time ASC +/// WindowFrame::new(false), +/// ); +/// ``` +pub trait WindowUDFImpl: Debug + Send + Sync { + /// Returns this object as an [`Any`] trait object + fn as_any(&self) -> &dyn Any; + + /// Returns this function's name + fn name(&self) -> &str; + + /// Returns the function's [`Signature`] for information about what input + /// types are accepted and the function's Volatility. + fn signature(&self) -> &Signature; + + /// What [`DataType`] will be returned by this function, given the types of + /// the arguments + fn return_type(&self, arg_types: &[DataType]) -> Result; + + /// Invoke the function, returning the [`PartitionEvaluator`] instance + fn partition_evaluator(&self) -> Result>; +} + +/// Implementation of [`WindowUDFImpl`] that wraps the function style pointers +/// of the older API (see +/// for more details) +pub struct WindowUDFLegacyWrapper { + /// name + name: String, + /// signature + signature: Signature, + /// Return type + return_type: ReturnTypeFunction, + /// Return the partition evaluator + partition_evaluator_factory: PartitionEvaluatorFactory, +} + +impl Debug for WindowUDFLegacyWrapper { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("WindowUDF") + .field("name", &self.name) + .field("signature", &self.signature) + .field("return_type", &"") + .field("partition_evaluator_factory", &"") + .finish_non_exhaustive() + } +} + +impl WindowUDFImpl for WindowUDFLegacyWrapper { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + // Old API returns an Arc of the datatype for some reason + let res = (self.return_type)(arg_types)?; + Ok(res.as_ref().clone()) + } + + fn partition_evaluator(&self) -> Result> { + (self.partition_evaluator_factory)() + } } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 54a1ce348bf9..914b354d2950 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -17,19 +17,28 @@ //! Expression utilities +use std::cmp::Ordering; +use std::collections::HashSet; +use std::sync::Arc; + use crate::expr::{Alias, Sort, WindowFunction}; +use crate::expr_rewriter::strip_outer_reference; use crate::logical_plan::Aggregate; use crate::signature::{Signature, TypeSignature}; -use crate::{Cast, Expr, ExprSchemable, GroupingSet, LogicalPlan, TryCast}; +use crate::{ + and, BinaryExpr, Cast, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, + Operator, TryCast, +}; + use arrow::datatypes::{DataType, TimeUnit}; use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::utils::get_at_indices; use datafusion_common::{ - internal_err, plan_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, - Result, ScalarValue, TableReference, + internal_err, plan_datafusion_err, plan_err, Column, DFField, DFSchema, DFSchemaRef, + DataFusionError, Result, ScalarValue, TableReference, }; + use sqlparser::ast::{ExceptSelectItem, ExcludeSelectItem, WildcardAdditionalOptions}; -use std::cmp::Ordering; -use std::collections::HashSet; /// The value to which `COUNT(*)` is expanded to in /// `COUNT()` expressions @@ -198,8 +207,8 @@ pub fn enumerate_grouping_sets(group_expr: Vec) -> Result> { grouping_sets.iter().map(|e| e.iter().collect()).collect() } Expr::GroupingSet(GroupingSet::Cube(group_exprs)) => { - let grouping_sets = - powerset(group_exprs).map_err(DataFusionError::Plan)?; + let grouping_sets = powerset(group_exprs) + .map_err(|e| plan_datafusion_err!("{}", e))?; check_grouping_sets_size_limit(grouping_sets.len())?; grouping_sets } @@ -283,17 +292,14 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::TryCast { .. } | Expr::Sort { .. } | Expr::ScalarFunction(..) - | Expr::ScalarUDF(..) | Expr::WindowFunction { .. } | Expr::AggregateFunction { .. } | Expr::GroupingSet(_) - | Expr::AggregateUDF { .. } | Expr::InList { .. } | Expr::Exists { .. } | Expr::InSubquery(_) | Expr::ScalarSubquery(_) - | Expr::Wildcard - | Expr::QualifiedWildcard { .. } + | Expr::Wildcard { .. } | Expr::GetIndexedField { .. } | Expr::Placeholder(_) | Expr::OuterReferenceColumn { .. } => {} @@ -420,18 +426,18 @@ pub fn expand_qualified_wildcard( wildcard_options: Option<&WildcardAdditionalOptions>, ) -> Result> { let qualifier = TableReference::from(qualifier); - let qualified_fields: Vec = schema - .fields_with_qualified(&qualifier) - .into_iter() - .cloned() - .collect(); + let qualified_indices = schema.fields_indices_with_qualified(&qualifier); + let projected_func_dependencies = schema + .functional_dependencies() + .project_functional_dependencies(&qualified_indices, qualified_indices.len()); + let qualified_fields = get_at_indices(schema.fields(), &qualified_indices)?; if qualified_fields.is_empty() { return plan_err!("Invalid qualifier {qualifier}"); } let qualified_schema = DFSchema::new_with_metadata(qualified_fields, schema.metadata().clone())? // We can use the functional dependencies as is, since it only stores indices: - .with_functional_dependencies(schema.functional_dependencies().clone()); + .with_functional_dependencies(projected_func_dependencies)?; let excluded_columns = if let Some(WildcardAdditionalOptions { opt_exclude, opt_except, @@ -499,7 +505,6 @@ pub fn generate_sort_key( let res = final_sort_keys .into_iter() .zip(is_partition_flag) - .map(|(lhs, rhs)| (lhs, rhs)) .collect::>(); Ok(res) } @@ -570,14 +575,14 @@ pub fn compare_sort_expr( /// group a slice of window expression expr by their order by expressions pub fn group_window_expr_by_sort_keys( - window_expr: &[Expr], -) -> Result)>> { + window_expr: Vec, +) -> Result)>> { let mut result = vec![]; - window_expr.iter().try_for_each(|expr| match expr { - Expr::WindowFunction(WindowFunction{ partition_by, order_by, .. }) => { + window_expr.into_iter().try_for_each(|expr| match &expr { + Expr::WindowFunction( WindowFunction{ partition_by, order_by, .. }) => { let sort_key = generate_sort_key(partition_by, order_by)?; if let Some((_, values)) = result.iter_mut().find( - |group: &&mut (WindowSortKey, Vec<&Expr>)| matches!(group, (key, _) if *key == sort_key), + |group: &&mut (WindowSortKey, Vec)| matches!(group, (key, _) if *key == sort_key), ) { values.push(expr); } else { @@ -592,15 +597,12 @@ pub fn group_window_expr_by_sort_keys( Ok(result) } -/// Collect all deeply nested `Expr::AggregateFunction` and -/// `Expr::AggregateUDF`. They are returned in order of occurrence (depth +/// Collect all deeply nested `Expr::AggregateFunction`. +/// They are returned in order of occurrence (depth /// first), with duplicates omitted. pub fn find_aggregate_exprs(exprs: &[Expr]) -> Vec { find_exprs_in_exprs(exprs, &|nested_expr| { - matches!( - nested_expr, - Expr::AggregateFunction { .. } | Expr::AggregateUDF { .. } - ) + matches!(nested_expr, Expr::AggregateFunction { .. }) }) } @@ -732,11 +734,7 @@ fn agg_cols(agg: &Aggregate) -> Vec { .collect() } -fn exprlist_to_fields_aggregate( - exprs: &[Expr], - plan: &LogicalPlan, - agg: &Aggregate, -) -> Result> { +fn exprlist_to_fields_aggregate(exprs: &[Expr], agg: &Aggregate) -> Result> { let agg_cols = agg_cols(agg); let mut fields = vec![]; for expr in exprs { @@ -745,7 +743,7 @@ fn exprlist_to_fields_aggregate( // resolve against schema of input to aggregate fields.push(expr.to_field(agg.input.schema())?); } - _ => fields.push(expr.to_field(plan.schema())?), + _ => fields.push(expr.to_field(&agg.schema)?), } } Ok(fields) @@ -762,15 +760,7 @@ pub fn exprlist_to_fields<'a>( // `GROUPING(person.state)` so in order to resolve `person.state` in this case we need to // look at the input to the aggregate instead. let fields = match plan { - LogicalPlan::Aggregate(agg) => { - Some(exprlist_to_fields_aggregate(&exprs, plan, agg)) - } - LogicalPlan::Window(window) => match window.input.as_ref() { - LogicalPlan::Aggregate(agg) => { - Some(exprlist_to_fields_aggregate(&exprs, plan, agg)) - } - _ => None, - }, + LogicalPlan::Aggregate(agg) => Some(exprlist_to_fields_aggregate(&exprs, agg)), _ => None, }; if let Some(fields) = fields { @@ -801,9 +791,11 @@ pub fn columnize_expr(e: Expr, input_schema: &DFSchema) -> Expr { match e { Expr::Column(_) => e, Expr::OuterReferenceColumn(_, _) => e, - Expr::Alias(Alias { expr, name, .. }) => { - columnize_expr(*expr, input_schema).alias(name) - } + Expr::Alias(Alias { + expr, + relation, + name, + }) => columnize_expr(*expr, input_schema).alias_qualified(relation, name), Expr::Cast(Cast { expr, data_type }) => Expr::Cast(Cast { expr: Box::new(columnize_expr(*expr, input_schema)), data_type, @@ -900,7 +892,7 @@ pub fn can_hash(data_type: &DataType) -> bool { DataType::UInt64 => true, DataType::Float32 => true, DataType::Float64 => true, - DataType::Timestamp(time_unit, None) => match time_unit { + DataType::Timestamp(time_unit, _) => match time_unit { TimeUnit::Second => true, TimeUnit::Millisecond => true, TimeUnit::Microsecond => true, @@ -1004,19 +996,251 @@ pub fn generate_signature_error_msg( ) } +/// Splits a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` +/// +/// See [`split_conjunction_owned`] for more details and an example. +pub fn split_conjunction(expr: &Expr) -> Vec<&Expr> { + split_conjunction_impl(expr, vec![]) +} + +fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<&'a Expr> { + match expr { + Expr::BinaryExpr(BinaryExpr { + right, + op: Operator::And, + left, + }) => { + let exprs = split_conjunction_impl(left, exprs); + split_conjunction_impl(right, exprs) + } + Expr::Alias(Alias { expr, .. }) => split_conjunction_impl(expr, exprs), + other => { + exprs.push(other); + exprs + } + } +} + +/// Splits an owned conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` +/// +/// This is often used to "split" filter expressions such as `col1 = 5 +/// AND col2 = 10` into [`col1 = 5`, `col2 = 10`]; +/// +/// # Example +/// ``` +/// # use datafusion_expr::{col, lit}; +/// # use datafusion_expr::utils::split_conjunction_owned; +/// // a=1 AND b=2 +/// let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2))); +/// +/// // [a=1, b=2] +/// let split = vec![ +/// col("a").eq(lit(1)), +/// col("b").eq(lit(2)), +/// ]; +/// +/// // use split_conjunction_owned to split them +/// assert_eq!(split_conjunction_owned(expr), split); +/// ``` +pub fn split_conjunction_owned(expr: Expr) -> Vec { + split_binary_owned(expr, Operator::And) +} + +/// Splits an owned binary operator tree [`Expr`] such as `A B C` => `[A, B, C]` +/// +/// This is often used to "split" expressions such as `col1 = 5 +/// AND col2 = 10` into [`col1 = 5`, `col2 = 10`]; +/// +/// # Example +/// ``` +/// # use datafusion_expr::{col, lit, Operator}; +/// # use datafusion_expr::utils::split_binary_owned; +/// # use std::ops::Add; +/// // a=1 + b=2 +/// let expr = col("a").eq(lit(1)).add(col("b").eq(lit(2))); +/// +/// // [a=1, b=2] +/// let split = vec![ +/// col("a").eq(lit(1)), +/// col("b").eq(lit(2)), +/// ]; +/// +/// // use split_binary_owned to split them +/// assert_eq!(split_binary_owned(expr, Operator::Plus), split); +/// ``` +pub fn split_binary_owned(expr: Expr, op: Operator) -> Vec { + split_binary_owned_impl(expr, op, vec![]) +} + +fn split_binary_owned_impl( + expr: Expr, + operator: Operator, + mut exprs: Vec, +) -> Vec { + match expr { + Expr::BinaryExpr(BinaryExpr { right, op, left }) if op == operator => { + let exprs = split_binary_owned_impl(*left, operator, exprs); + split_binary_owned_impl(*right, operator, exprs) + } + Expr::Alias(Alias { expr, .. }) => { + split_binary_owned_impl(*expr, operator, exprs) + } + other => { + exprs.push(other); + exprs + } + } +} + +/// Splits an binary operator tree [`Expr`] such as `A B C` => `[A, B, C]` +/// +/// See [`split_binary_owned`] for more details and an example. +pub fn split_binary(expr: &Expr, op: Operator) -> Vec<&Expr> { + split_binary_impl(expr, op, vec![]) +} + +fn split_binary_impl<'a>( + expr: &'a Expr, + operator: Operator, + mut exprs: Vec<&'a Expr>, +) -> Vec<&'a Expr> { + match expr { + Expr::BinaryExpr(BinaryExpr { right, op, left }) if *op == operator => { + let exprs = split_binary_impl(left, operator, exprs); + split_binary_impl(right, operator, exprs) + } + Expr::Alias(Alias { expr, .. }) => split_binary_impl(expr, operator, exprs), + other => { + exprs.push(other); + exprs + } + } +} + +/// Combines an array of filter expressions into a single filter +/// expression consisting of the input filter expressions joined with +/// logical AND. +/// +/// Returns None if the filters array is empty. +/// +/// # Example +/// ``` +/// # use datafusion_expr::{col, lit}; +/// # use datafusion_expr::utils::conjunction; +/// // a=1 AND b=2 +/// let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2))); +/// +/// // [a=1, b=2] +/// let split = vec![ +/// col("a").eq(lit(1)), +/// col("b").eq(lit(2)), +/// ]; +/// +/// // use conjunction to join them together with `AND` +/// assert_eq!(conjunction(split), Some(expr)); +/// ``` +pub fn conjunction(filters: impl IntoIterator) -> Option { + filters.into_iter().reduce(|accum, expr| accum.and(expr)) +} + +/// Combines an array of filter expressions into a single filter +/// expression consisting of the input filter expressions joined with +/// logical OR. +/// +/// Returns None if the filters array is empty. +pub fn disjunction(filters: impl IntoIterator) -> Option { + filters.into_iter().reduce(|accum, expr| accum.or(expr)) +} + +/// returns a new [LogicalPlan] that wraps `plan` in a [LogicalPlan::Filter] with +/// its predicate be all `predicates` ANDed. +pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result { + // reduce filters to a single filter with an AND + let predicate = predicates + .iter() + .skip(1) + .fold(predicates[0].clone(), |acc, predicate| { + and(acc, (*predicate).to_owned()) + }); + + Ok(LogicalPlan::Filter(Filter::try_new( + predicate, + Arc::new(plan), + )?)) +} + +/// Looks for correlating expressions: for example, a binary expression with one field from the subquery, and +/// one not in the subquery (closed upon from outer scope) +/// +/// # Arguments +/// +/// * `exprs` - List of expressions that may or may not be joins +/// +/// # Return value +/// +/// Tuple of (expressions containing joins, remaining non-join expressions) +pub fn find_join_exprs(exprs: Vec<&Expr>) -> Result<(Vec, Vec)> { + let mut joins = vec![]; + let mut others = vec![]; + for filter in exprs.into_iter() { + // If the expression contains correlated predicates, add it to join filters + if filter.contains_outer() { + if !matches!(filter, Expr::BinaryExpr(BinaryExpr{ left, op: Operator::Eq, right }) if left.eq(right)) + { + joins.push(strip_outer_reference((*filter).clone())); + } + } else { + others.push((*filter).clone()); + } + } + + Ok((joins, others)) +} + +/// Returns the first (and only) element in a slice, or an error +/// +/// # Arguments +/// +/// * `slice` - The slice to extract from +/// +/// # Return value +/// +/// The first element, or an error +pub fn only_or_err(slice: &[T]) -> Result<&T> { + match slice { + [it] => Ok(it), + [] => plan_err!("No items found!"), + _ => plan_err!("More than one item found!"), + } +} + +/// merge inputs schema into a single schema. +pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema { + if inputs.len() == 1 { + inputs[0].schema().clone().as_ref().clone() + } else { + inputs.iter().map(|input| input.schema()).fold( + DFSchema::empty(), + |mut lhs, rhs| { + lhs.merge(rhs); + lhs + }, + ) + } +} + #[cfg(test)] mod tests { use super::*; - use crate::expr_vec_fmt; use crate::{ - col, cube, expr, grouping_set, rollup, AggregateFunction, WindowFrame, - WindowFunction, + col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, AggregateFunction, + WindowFrame, WindowFunctionDefinition, }; #[test] fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> { - let result = group_window_expr_by_sort_keys(&[])?; - let expected: Vec<(WindowSortKey, Vec<&Expr>)> = vec![]; + let result = group_window_expr_by_sort_keys(vec![])?; + let expected: Vec<(WindowSortKey, Vec)> = vec![]; assert_eq!(expected, result); Ok(()) } @@ -1024,38 +1248,38 @@ mod tests { #[test] fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> { let max1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], vec![], vec![], WindowFrame::new(false), )); let max2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], vec![], vec![], WindowFrame::new(false), )); let min3 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Min), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), vec![col("name")], vec![], vec![], WindowFrame::new(false), )); let sum4 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Sum), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum), vec![col("age")], vec![], vec![], WindowFrame::new(false), )); let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; - let result = group_window_expr_by_sort_keys(exprs)?; + let result = group_window_expr_by_sort_keys(exprs.to_vec())?; let key = vec![]; - let expected: Vec<(WindowSortKey, Vec<&Expr>)> = - vec![(key, vec![&max1, &max2, &min3, &sum4])]; + let expected: Vec<(WindowSortKey, Vec)> = + vec![(key, vec![max1, max2, min3, sum4])]; assert_eq!(expected, result); Ok(()) } @@ -1067,28 +1291,28 @@ mod tests { let created_at_desc = Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)); let max1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], vec![], vec![age_asc.clone(), name_desc.clone()], WindowFrame::new(true), )); let max2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], vec![], vec![], WindowFrame::new(false), )); let min3 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Min), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), vec![col("name")], vec![], vec![age_asc.clone(), name_desc.clone()], WindowFrame::new(true), )); let sum4 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Sum), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum), vec![col("age")], vec![], vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()], @@ -1096,7 +1320,7 @@ mod tests { )); // FIXME use as_ref let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; - let result = group_window_expr_by_sort_keys(exprs)?; + let result = group_window_expr_by_sort_keys(exprs.to_vec())?; let key1 = vec![(age_asc.clone(), false), (name_desc.clone(), false)]; let key2 = vec![]; @@ -1106,10 +1330,10 @@ mod tests { (created_at_desc, false), ]; - let expected: Vec<(WindowSortKey, Vec<&Expr>)> = vec![ - (key1, vec![&max1, &min3]), - (key2, vec![&max2]), - (key3, vec![&sum4]), + let expected: Vec<(WindowSortKey, Vec)> = vec![ + (key1, vec![max1, min3]), + (key2, vec![max2]), + (key3, vec![sum4]), ]; assert_eq!(expected, result); Ok(()) @@ -1119,7 +1343,7 @@ mod tests { fn test_find_sort_exprs() -> Result<()> { let exprs = &[ Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], vec![], vec![ @@ -1129,7 +1353,7 @@ mod tests { WindowFrame::new(true), )), Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Sum), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum), vec![col("age")], vec![], vec![ @@ -1322,4 +1546,143 @@ mod tests { Ok(()) } + #[test] + fn test_split_conjunction() { + let expr = col("a"); + let result = split_conjunction(&expr); + assert_eq!(result, vec![&expr]); + } + + #[test] + fn test_split_conjunction_two() { + let expr = col("a").eq(lit(5)).and(col("b")); + let expr1 = col("a").eq(lit(5)); + let expr2 = col("b"); + + let result = split_conjunction(&expr); + assert_eq!(result, vec![&expr1, &expr2]); + } + + #[test] + fn test_split_conjunction_alias() { + let expr = col("a").eq(lit(5)).and(col("b").alias("the_alias")); + let expr1 = col("a").eq(lit(5)); + let expr2 = col("b"); // has no alias + + let result = split_conjunction(&expr); + assert_eq!(result, vec![&expr1, &expr2]); + } + + #[test] + fn test_split_conjunction_or() { + let expr = col("a").eq(lit(5)).or(col("b")); + let result = split_conjunction(&expr); + assert_eq!(result, vec![&expr]); + } + + #[test] + fn test_split_binary_owned() { + let expr = col("a"); + assert_eq!(split_binary_owned(expr.clone(), Operator::And), vec![expr]); + } + + #[test] + fn test_split_binary_owned_two() { + assert_eq!( + split_binary_owned(col("a").eq(lit(5)).and(col("b")), Operator::And), + vec![col("a").eq(lit(5)), col("b")] + ); + } + + #[test] + fn test_split_binary_owned_different_op() { + let expr = col("a").eq(lit(5)).or(col("b")); + assert_eq!( + // expr is connected by OR, but pass in AND + split_binary_owned(expr.clone(), Operator::And), + vec![expr] + ); + } + + #[test] + fn test_split_conjunction_owned() { + let expr = col("a"); + assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]); + } + + #[test] + fn test_split_conjunction_owned_two() { + assert_eq!( + split_conjunction_owned(col("a").eq(lit(5)).and(col("b"))), + vec![col("a").eq(lit(5)), col("b")] + ); + } + + #[test] + fn test_split_conjunction_owned_alias() { + assert_eq!( + split_conjunction_owned(col("a").eq(lit(5)).and(col("b").alias("the_alias"))), + vec![ + col("a").eq(lit(5)), + // no alias on b + col("b"), + ] + ); + } + + #[test] + fn test_conjunction_empty() { + assert_eq!(conjunction(vec![]), None); + } + + #[test] + fn test_conjunction() { + // `[A, B, C]` + let expr = conjunction(vec![col("a"), col("b"), col("c")]); + + // --> `(A AND B) AND C` + assert_eq!(expr, Some(col("a").and(col("b")).and(col("c")))); + + // which is different than `A AND (B AND C)` + assert_ne!(expr, Some(col("a").and(col("b").and(col("c"))))); + } + + #[test] + fn test_disjunction_empty() { + assert_eq!(disjunction(vec![]), None); + } + + #[test] + fn test_disjunction() { + // `[A, B, C]` + let expr = disjunction(vec![col("a"), col("b"), col("c")]); + + // --> `(A OR B) OR C` + assert_eq!(expr, Some(col("a").or(col("b")).or(col("c")))); + + // which is different than `A OR (B OR C)` + assert_ne!(expr, Some(col("a").or(col("b").or(col("c"))))); + } + + #[test] + fn test_split_conjunction_owned_or() { + let expr = col("a").eq(lit(5)).or(col("b")); + assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]); + } + + #[test] + fn test_collect_expr() -> Result<()> { + let mut accum: HashSet = HashSet::new(); + expr_to_columns( + &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)), + &mut accum, + )?; + expr_to_columns( + &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)), + &mut accum, + )?; + assert_eq!(1, accum.len()); + assert!(accum.contains(&Column::from_name("a"))); + Ok(()) + } } diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index 5f161b85dd9a..2701ca1ecf3b 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -23,6 +23,8 @@ //! - An ending frame boundary, //! - An EXCLUDE clause. +use crate::expr::Sort; +use crate::Expr; use datafusion_common::{plan_err, sql_err, DataFusionError, Result, ScalarValue}; use sqlparser::ast; use sqlparser::parser::ParserError::ParserError; @@ -142,31 +144,57 @@ impl WindowFrame { } } -/// Construct equivalent explicit window frames for implicit corner cases. -/// With this processing, we may assume in downstream code that RANGE/GROUPS -/// frames contain an appropriate ORDER BY clause. -pub fn regularize(mut frame: WindowFrame, order_bys: usize) -> Result { - if frame.units == WindowFrameUnits::Range && order_bys != 1 { +/// Regularizes ORDER BY clause for window definition for implicit corner cases. +pub fn regularize_window_order_by( + frame: &WindowFrame, + order_by: &mut Vec, +) -> Result<()> { + if frame.units == WindowFrameUnits::Range && order_by.len() != 1 { // Normally, RANGE frames require an ORDER BY clause with exactly one - // column. However, an ORDER BY clause may be absent in two edge cases. + // column. However, an ORDER BY clause may be absent or present but with + // more than one column in two edge cases: + // 1. start bound is UNBOUNDED or CURRENT ROW + // 2. end bound is CURRENT ROW or UNBOUNDED. + // In these cases, we regularize the ORDER BY clause if the ORDER BY clause + // is absent. If an ORDER BY clause is present but has more than one column, + // the ORDER BY clause is unchanged. Note that this follows Postgres behavior. if (frame.start_bound.is_unbounded() || frame.start_bound == WindowFrameBound::CurrentRow) && (frame.end_bound == WindowFrameBound::CurrentRow || frame.end_bound.is_unbounded()) { - if order_bys == 0 { - frame.units = WindowFrameUnits::Rows; - frame.start_bound = - WindowFrameBound::Preceding(ScalarValue::UInt64(None)); - frame.end_bound = WindowFrameBound::Following(ScalarValue::UInt64(None)); + // If an ORDER BY clause is absent, it is equivalent to a ORDER BY clause + // with constant value as sort key. + // If an ORDER BY clause is present but has more than one column, it is + // unchanged. + if order_by.is_empty() { + order_by.push(Expr::Sort(Sort::new( + Box::new(Expr::Literal(ScalarValue::UInt64(Some(1)))), + true, + false, + ))); } - } else { + } + } + Ok(()) +} + +/// Checks if given window frame is valid. In particular, if the frame is RANGE +/// with offset PRECEDING/FOLLOWING, it must have exactly one ORDER BY column. +pub fn check_window_frame(frame: &WindowFrame, order_bys: usize) -> Result<()> { + if frame.units == WindowFrameUnits::Range && order_bys != 1 { + // See `regularize_window_order_by`. + if !(frame.start_bound.is_unbounded() + || frame.start_bound == WindowFrameBound::CurrentRow) + || !(frame.end_bound == WindowFrameBound::CurrentRow + || frame.end_bound.is_unbounded()) + { plan_err!("RANGE requires exactly one ORDER BY column")? } } else if frame.units == WindowFrameUnits::Groups && order_bys == 0 { plan_err!("GROUPS requires an ORDER BY clause")? }; - Ok(frame) + Ok(()) } /// There are five ways to describe starting and ending frame boundaries: diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs deleted file mode 100644 index 1f36ebdd6b54..000000000000 --- a/datafusion/expr/src/window_function.rs +++ /dev/null @@ -1,447 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Window functions provide the ability to perform calculations across -//! sets of rows that are related to the current query row. -//! -//! see also - -use crate::aggregate_function::AggregateFunction; -use crate::type_coercion::functions::data_types; -use crate::utils; -use crate::{AggregateUDF, Signature, TypeSignature, Volatility, WindowUDF}; -use arrow::datatypes::DataType; -use datafusion_common::{plan_err, DataFusionError, Result}; -use std::sync::Arc; -use std::{fmt, str::FromStr}; -use strum_macros::EnumIter; - -/// WindowFunction -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum WindowFunction { - /// A built in aggregate function that leverages an aggregate function - AggregateFunction(AggregateFunction), - /// A a built-in window function - BuiltInWindowFunction(BuiltInWindowFunction), - /// A user defined aggregate function - AggregateUDF(Arc), - /// A user defined aggregate function - WindowUDF(Arc), -} - -/// Find DataFusion's built-in window function by name. -pub fn find_df_window_func(name: &str) -> Option { - let name = name.to_lowercase(); - // Code paths for window functions leveraging ordinary aggregators and - // built-in window functions are quite different, and the same function - // may have different implementations for these cases. If the sought - // function is not found among built-in window functions, we search for - // it among aggregate functions. - if let Ok(built_in_function) = BuiltInWindowFunction::from_str(name.as_str()) { - Some(WindowFunction::BuiltInWindowFunction(built_in_function)) - } else if let Ok(aggregate) = AggregateFunction::from_str(name.as_str()) { - Some(WindowFunction::AggregateFunction(aggregate)) - } else { - None - } -} - -impl fmt::Display for BuiltInWindowFunction { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.name()) - } -} - -impl fmt::Display for WindowFunction { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - WindowFunction::AggregateFunction(fun) => fun.fmt(f), - WindowFunction::BuiltInWindowFunction(fun) => fun.fmt(f), - WindowFunction::AggregateUDF(fun) => std::fmt::Debug::fmt(fun, f), - WindowFunction::WindowUDF(fun) => fun.fmt(f), - } - } -} - -/// A [window function] built in to DataFusion -/// -/// [window function]: https://en.wikipedia.org/wiki/Window_function_(SQL) -#[derive(Debug, Clone, PartialEq, Eq, Hash, EnumIter)] -pub enum BuiltInWindowFunction { - /// number of the current row within its partition, counting from 1 - RowNumber, - /// rank of the current row with gaps; same as row_number of its first peer - Rank, - /// rank of the current row without gaps; this function counts peer groups - DenseRank, - /// relative rank of the current row: (rank - 1) / (total rows - 1) - PercentRank, - /// relative rank of the current row: (number of rows preceding or peer with current row) / (total rows) - CumeDist, - /// integer ranging from 1 to the argument value, dividing the partition as equally as possible - Ntile, - /// returns value evaluated at the row that is offset rows before the current row within the partition; - /// if there is no such row, instead return default (which must be of the same type as value). - /// Both offset and default are evaluated with respect to the current row. - /// If omitted, offset defaults to 1 and default to null - Lag, - /// returns value evaluated at the row that is offset rows after the current row within the partition; - /// if there is no such row, instead return default (which must be of the same type as value). - /// Both offset and default are evaluated with respect to the current row. - /// If omitted, offset defaults to 1 and default to null - Lead, - /// returns value evaluated at the row that is the first row of the window frame - FirstValue, - /// returns value evaluated at the row that is the last row of the window frame - LastValue, - /// returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row - NthValue, -} - -impl BuiltInWindowFunction { - fn name(&self) -> &str { - use BuiltInWindowFunction::*; - match self { - RowNumber => "ROW_NUMBER", - Rank => "RANK", - DenseRank => "DENSE_RANK", - PercentRank => "PERCENT_RANK", - CumeDist => "CUME_DIST", - Ntile => "NTILE", - Lag => "LAG", - Lead => "LEAD", - FirstValue => "FIRST_VALUE", - LastValue => "LAST_VALUE", - NthValue => "NTH_VALUE", - } - } -} - -impl FromStr for BuiltInWindowFunction { - type Err = DataFusionError; - fn from_str(name: &str) -> Result { - Ok(match name.to_uppercase().as_str() { - "ROW_NUMBER" => BuiltInWindowFunction::RowNumber, - "RANK" => BuiltInWindowFunction::Rank, - "DENSE_RANK" => BuiltInWindowFunction::DenseRank, - "PERCENT_RANK" => BuiltInWindowFunction::PercentRank, - "CUME_DIST" => BuiltInWindowFunction::CumeDist, - "NTILE" => BuiltInWindowFunction::Ntile, - "LAG" => BuiltInWindowFunction::Lag, - "LEAD" => BuiltInWindowFunction::Lead, - "FIRST_VALUE" => BuiltInWindowFunction::FirstValue, - "LAST_VALUE" => BuiltInWindowFunction::LastValue, - "NTH_VALUE" => BuiltInWindowFunction::NthValue, - _ => return plan_err!("There is no built-in window function named {name}"), - }) - } -} - -/// Returns the datatype of the window function -#[deprecated( - since = "27.0.0", - note = "please use `WindowFunction::return_type` instead" -)] -pub fn return_type( - fun: &WindowFunction, - input_expr_types: &[DataType], -) -> Result { - fun.return_type(input_expr_types) -} - -impl WindowFunction { - /// Returns the datatype of the window function - pub fn return_type(&self, input_expr_types: &[DataType]) -> Result { - match self { - WindowFunction::AggregateFunction(fun) => fun.return_type(input_expr_types), - WindowFunction::BuiltInWindowFunction(fun) => { - fun.return_type(input_expr_types) - } - WindowFunction::AggregateUDF(fun) => { - Ok((*(fun.return_type)(input_expr_types)?).clone()) - } - WindowFunction::WindowUDF(fun) => { - Ok((*(fun.return_type)(input_expr_types)?).clone()) - } - } - } -} - -/// Returns the datatype of the built-in window function -impl BuiltInWindowFunction { - pub fn return_type(&self, input_expr_types: &[DataType]) -> Result { - // Note that this function *must* return the same type that the respective physical expression returns - // or the execution panics. - - // verify that this is a valid set of data types for this function - data_types(input_expr_types, &self.signature()) - // original errors are all related to wrong function signature - // aggregate them for better error message - .map_err(|_| { - DataFusionError::Plan(utils::generate_signature_error_msg( - &format!("{self}"), - self.signature(), - input_expr_types, - )) - })?; - - match self { - BuiltInWindowFunction::RowNumber - | BuiltInWindowFunction::Rank - | BuiltInWindowFunction::DenseRank => Ok(DataType::UInt64), - BuiltInWindowFunction::PercentRank | BuiltInWindowFunction::CumeDist => { - Ok(DataType::Float64) - } - BuiltInWindowFunction::Ntile => Ok(DataType::UInt32), - BuiltInWindowFunction::Lag - | BuiltInWindowFunction::Lead - | BuiltInWindowFunction::FirstValue - | BuiltInWindowFunction::LastValue - | BuiltInWindowFunction::NthValue => Ok(input_expr_types[0].clone()), - } - } -} - -/// the signatures supported by the function `fun`. -#[deprecated( - since = "27.0.0", - note = "please use `WindowFunction::signature` instead" -)] -pub fn signature(fun: &WindowFunction) -> Signature { - fun.signature() -} - -impl WindowFunction { - /// the signatures supported by the function `fun`. - pub fn signature(&self) -> Signature { - match self { - WindowFunction::AggregateFunction(fun) => fun.signature(), - WindowFunction::BuiltInWindowFunction(fun) => fun.signature(), - WindowFunction::AggregateUDF(fun) => fun.signature.clone(), - WindowFunction::WindowUDF(fun) => fun.signature.clone(), - } - } -} - -/// the signatures supported by the built-in window function `fun`. -#[deprecated( - since = "27.0.0", - note = "please use `BuiltInWindowFunction::signature` instead" -)] -pub fn signature_for_built_in(fun: &BuiltInWindowFunction) -> Signature { - fun.signature() -} - -impl BuiltInWindowFunction { - /// the signatures supported by the built-in window function `fun`. - pub fn signature(&self) -> Signature { - // note: the physical expression must accept the type returned by this function or the execution panics. - match self { - BuiltInWindowFunction::RowNumber - | BuiltInWindowFunction::Rank - | BuiltInWindowFunction::DenseRank - | BuiltInWindowFunction::PercentRank - | BuiltInWindowFunction::CumeDist => Signature::any(0, Volatility::Immutable), - BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead => { - Signature::one_of( - vec![ - TypeSignature::Any(1), - TypeSignature::Any(2), - TypeSignature::Any(3), - ], - Volatility::Immutable, - ) - } - BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue => { - Signature::any(1, Volatility::Immutable) - } - BuiltInWindowFunction::Ntile => Signature::any(1, Volatility::Immutable), - BuiltInWindowFunction::NthValue => Signature::any(2, Volatility::Immutable), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_count_return_type() -> Result<()> { - let fun = find_df_window_func("count").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Int64, observed); - - let observed = fun.return_type(&[DataType::UInt64])?; - assert_eq!(DataType::Int64, observed); - - Ok(()) - } - - #[test] - fn test_first_value_return_type() -> Result<()> { - let fun = find_df_window_func("first_value").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::UInt64])?; - assert_eq!(DataType::UInt64, observed); - - Ok(()) - } - - #[test] - fn test_last_value_return_type() -> Result<()> { - let fun = find_df_window_func("last_value").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::Float64])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_lead_return_type() -> Result<()> { - let fun = find_df_window_func("lead").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::Float64])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_lag_return_type() -> Result<()> { - let fun = find_df_window_func("lag").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::Float64])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_nth_value_return_type() -> Result<()> { - let fun = find_df_window_func("nth_value").unwrap(); - let observed = fun.return_type(&[DataType::Utf8, DataType::UInt64])?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::Float64, DataType::UInt64])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_percent_rank_return_type() -> Result<()> { - let fun = find_df_window_func("percent_rank").unwrap(); - let observed = fun.return_type(&[])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_cume_dist_return_type() -> Result<()> { - let fun = find_df_window_func("cume_dist").unwrap(); - let observed = fun.return_type(&[])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_window_function_case_insensitive() -> Result<()> { - let names = vec![ - "row_number", - "rank", - "dense_rank", - "percent_rank", - "cume_dist", - "ntile", - "lag", - "lead", - "first_value", - "last_value", - "nth_value", - "min", - "max", - "count", - "avg", - "sum", - ]; - for name in names { - let fun = find_df_window_func(name).unwrap(); - let fun2 = find_df_window_func(name.to_uppercase().as_str()).unwrap(); - assert_eq!(fun, fun2); - assert_eq!(fun.to_string(), name.to_uppercase()); - } - Ok(()) - } - - #[test] - fn test_find_df_window_function() { - assert_eq!( - find_df_window_func("max"), - Some(WindowFunction::AggregateFunction(AggregateFunction::Max)) - ); - assert_eq!( - find_df_window_func("min"), - Some(WindowFunction::AggregateFunction(AggregateFunction::Min)) - ); - assert_eq!( - find_df_window_func("avg"), - Some(WindowFunction::AggregateFunction(AggregateFunction::Avg)) - ); - assert_eq!( - find_df_window_func("cume_dist"), - Some(WindowFunction::BuiltInWindowFunction( - BuiltInWindowFunction::CumeDist - )) - ); - assert_eq!( - find_df_window_func("first_value"), - Some(WindowFunction::BuiltInWindowFunction( - BuiltInWindowFunction::FirstValue - )) - ); - assert_eq!( - find_df_window_func("LAST_value"), - Some(WindowFunction::BuiltInWindowFunction( - BuiltInWindowFunction::LastValue - )) - ); - assert_eq!( - find_df_window_func("LAG"), - Some(WindowFunction::BuiltInWindowFunction( - BuiltInWindowFunction::Lag - )) - ); - assert_eq!( - find_df_window_func("LEAD"), - Some(WindowFunction::BuiltInWindowFunction( - BuiltInWindowFunction::Lead - )) - ); - assert_eq!(find_df_window_func("not_exist"), None) - } -} diff --git a/datafusion/expr/src/window_state.rs b/datafusion/expr/src/window_state.rs index 4ea9ecea5fc6..de88396d9b0e 100644 --- a/datafusion/expr/src/window_state.rs +++ b/datafusion/expr/src/window_state.rs @@ -98,7 +98,7 @@ impl WindowAggState { } pub fn new(out_type: &DataType) -> Result { - let empty_out_col = ScalarValue::try_from(out_type)?.to_array_of_size(0); + let empty_out_col = ScalarValue::try_from(out_type)?.to_array_of_size(0)?; Ok(Self { window_frame_range: Range { start: 0, end: 0 }, window_frame_ctx: None, diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index e1ffcb41ba6e..b350d41d3fe3 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -19,9 +19,9 @@ name = "datafusion-optimizer" description = "DataFusion Query Optimizer" keywords = [ "datafusion", "query", "optimizer" ] +readme = "README.md" version = { workspace = true } edition = { workspace = true } -readme = { workspace = true } homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } @@ -40,17 +40,17 @@ unicode_expressions = ["datafusion-physical-expr/unicode_expressions"] [dependencies] arrow = { workspace = true } -async-trait = "0.1.41" +async-trait = { workspace = true } chrono = { workspace = true } -datafusion-common = { path = "../common", version = "31.0.0", default-features = false } -datafusion-expr = { path = "../expr", version = "31.0.0" } -datafusion-physical-expr = { path = "../physical-expr", version = "31.0.0", default-features = false } +datafusion-common = { workspace = true } +datafusion-expr = { workspace = true } +datafusion-physical-expr = { path = "../physical-expr", version = "34.0.0", default-features = false } hashbrown = { version = "0.14", features = ["raw"] } -itertools = "0.11" -log = "^0.4" -regex-syntax = "0.7.1" +itertools = { workspace = true } +log = { workspace = true } +regex-syntax = "0.8.0" [dev-dependencies] -ctor = "0.2.0" -datafusion-sql = { path = "../sql", version = "31.0.0" } +ctor = { workspace = true } +datafusion-sql = { path = "../sql", version = "34.0.0" } env_logger = "0.10.0" diff --git a/datafusion/optimizer/README.md b/datafusion/optimizer/README.md index c8baae03efa2..4f9e0fb98526 100644 --- a/datafusion/optimizer/README.md +++ b/datafusion/optimizer/README.md @@ -19,7 +19,7 @@ # DataFusion Query Optimizer -[DataFusion](df) is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. DataFusion has modular design, allowing individual crates to be re-used in other projects. @@ -153,7 +153,7 @@ Looking at the `EXPLAIN` output we can see that the optimizer has effectively re | logical_plan | Projection: Int64(3) AS Int64(1) + Int64(2) | | | EmptyRelation | | physical_plan | ProjectionExec: expr=[3 as Int64(1) + Int64(2)] | -| | EmptyExec: produce_one_row=true | +| | PlaceholderRowExec | | | | +---------------+-------------------------------------------------+ ``` @@ -318,7 +318,7 @@ In the following example, the `type_coercion` and `simplify_expressions` passes | logical_plan | Projection: Utf8("3.2") AS foo | | | EmptyRelation | | initial_physical_plan | ProjectionExec: expr=[3.2 as foo] | -| | EmptyExec: produce_one_row=true | +| | PlaceholderRowExec | | | | | physical_plan after aggregate_statistics | SAME TEXT AS ABOVE | | physical_plan after join_selection | SAME TEXT AS ABOVE | @@ -326,7 +326,7 @@ In the following example, the `type_coercion` and `simplify_expressions` passes | physical_plan after repartition | SAME TEXT AS ABOVE | | physical_plan after add_merge_exec | SAME TEXT AS ABOVE | | physical_plan | ProjectionExec: expr=[3.2 as foo] | -| | EmptyExec: produce_one_row=true | +| | PlaceholderRowExec | | | | +------------------------------------------------------------+---------------------------------------------------------------------------+ ``` diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 912ac069e0b6..953716713e41 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -19,12 +19,12 @@ use crate::analyzer::AnalyzerRule; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; use datafusion_common::Result; -use datafusion_expr::expr::{AggregateFunction, InSubquery}; +use datafusion_expr::expr::{AggregateFunction, AggregateFunctionDefinition, InSubquery}; use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_expr::Expr::ScalarSubquery; use datafusion_expr::{ - aggregate_function, expr, lit, window_function, Aggregate, Expr, Filter, LogicalPlan, + aggregate_function, expr, lit, Aggregate, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Projection, Sort, Subquery, }; use std::sync::Arc; @@ -121,7 +121,7 @@ impl TreeNodeRewriter for CountWildcardRewriter { let new_expr = match old_expr.clone() { Expr::WindowFunction(expr::WindowFunction { fun: - window_function::WindowFunction::AggregateFunction( + expr::WindowFunctionDefinition::AggregateFunction( aggregate_function::AggregateFunction::Count, ), args, @@ -129,32 +129,39 @@ impl TreeNodeRewriter for CountWildcardRewriter { order_by, window_frame, }) if args.len() == 1 => match args[0] { - Expr::Wildcard => Expr::WindowFunction(expr::WindowFunction { - fun: window_function::WindowFunction::AggregateFunction( - aggregate_function::AggregateFunction::Count, - ), - args: vec![lit(COUNT_STAR_EXPANSION)], - partition_by, - order_by, - window_frame, - }), + Expr::Wildcard { qualifier: None } => { + Expr::WindowFunction(expr::WindowFunction { + fun: expr::WindowFunctionDefinition::AggregateFunction( + aggregate_function::AggregateFunction::Count, + ), + args: vec![lit(COUNT_STAR_EXPANSION)], + partition_by, + order_by, + window_frame, + }) + } _ => old_expr, }, Expr::AggregateFunction(AggregateFunction { - fun: aggregate_function::AggregateFunction::Count, + func_def: + AggregateFunctionDefinition::BuiltIn( + aggregate_function::AggregateFunction::Count, + ), args, distinct, filter, order_by, }) if args.len() == 1 => match args[0] { - Expr::Wildcard => Expr::AggregateFunction(AggregateFunction { - fun: aggregate_function::AggregateFunction::Count, - args: vec![lit(COUNT_STAR_EXPANSION)], - distinct, - filter, - order_by, - }), + Expr::Wildcard { qualifier: None } => { + Expr::AggregateFunction(AggregateFunction::new( + aggregate_function::AggregateFunction::Count, + vec![lit(COUNT_STAR_EXPANSION)], + distinct, + filter, + order_by, + )) + } _ => old_expr, }, @@ -221,8 +228,8 @@ mod tests { use datafusion_expr::expr::Sort; use datafusion_expr::{ col, count, exists, expr, in_subquery, lit, logical_plan::LogicalPlanBuilder, - max, out_ref_col, scalar_subquery, AggregateFunction, Expr, WindowFrame, - WindowFrameBound, WindowFrameUnits, WindowFunction, + max, out_ref_col, scalar_subquery, wildcard, AggregateFunction, Expr, + WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; fn assert_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { @@ -237,9 +244,9 @@ mod tests { fn test_count_wildcard_on_sort() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .aggregate(vec![col("b")], vec![count(Expr::Wildcard)])? - .project(vec![count(Expr::Wildcard)])? - .sort(vec![count(Expr::Wildcard).sort(true, false)])? + .aggregate(vec![col("b")], vec![count(wildcard())])? + .project(vec![count(wildcard())])? + .sort(vec![count(wildcard()).sort(true, false)])? .build()?; let expected = "Sort: COUNT(*) ASC NULLS LAST [COUNT(*):Int64;N]\ \n Projection: COUNT(*) [COUNT(*):Int64;N]\ @@ -258,8 +265,8 @@ mod tests { col("a"), Arc::new( LogicalPlanBuilder::from(table_scan_t2) - .aggregate(Vec::::new(), vec![count(Expr::Wildcard)])? - .project(vec![count(Expr::Wildcard)])? + .aggregate(Vec::::new(), vec![count(wildcard())])? + .project(vec![count(wildcard())])? .build()?, ), ))? @@ -282,8 +289,8 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan_t1) .filter(exists(Arc::new( LogicalPlanBuilder::from(table_scan_t2) - .aggregate(Vec::::new(), vec![count(Expr::Wildcard)])? - .project(vec![count(Expr::Wildcard)])? + .aggregate(Vec::::new(), vec![count(wildcard())])? + .project(vec![count(wildcard())])? .build()?, )))? .build()?; @@ -335,8 +342,8 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .window(vec![Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Count), - vec![Expr::Wildcard], + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), + vec![wildcard()], vec![], vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], WindowFrame { @@ -347,7 +354,7 @@ mod tests { end_bound: WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), }, ))])? - .project(vec![count(Expr::Wildcard)])? + .project(vec![count(wildcard())])? .build()?; let expected = "Projection: COUNT(UInt8(1)) AS COUNT(*) [COUNT(*):Int64;N]\ @@ -360,8 +367,8 @@ mod tests { fn test_count_wildcard_on_aggregate() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .aggregate(Vec::::new(), vec![count(Expr::Wildcard)])? - .project(vec![count(Expr::Wildcard)])? + .aggregate(Vec::::new(), vec![count(wildcard())])? + .project(vec![count(wildcard())])? .build()?; let expected = "Projection: COUNT(*) [COUNT(*):Int64;N]\ @@ -374,8 +381,8 @@ mod tests { fn test_count_wildcard_on_nesting() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .aggregate(Vec::::new(), vec![max(count(Expr::Wildcard))])? - .project(vec![count(Expr::Wildcard)])? + .aggregate(Vec::::new(), vec![max(count(wildcard()))])? + .project(vec![count(wildcard())])? .build()?; let expected = "Projection: COUNT(UInt8(1)) AS COUNT(*) [COUNT(*):Int64;N]\ diff --git a/datafusion/optimizer/src/analyzer/inline_table_scan.rs b/datafusion/optimizer/src/analyzer/inline_table_scan.rs index 3d0dabdd377c..90af7aec8293 100644 --- a/datafusion/optimizer/src/analyzer/inline_table_scan.rs +++ b/datafusion/optimizer/src/analyzer/inline_table_scan.rs @@ -126,7 +126,7 @@ fn generate_projection_expr( )); } } else { - exprs.push(Expr::Wildcard); + exprs.push(Expr::Wildcard { qualifier: None }); } Ok(exprs) } diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index 14d5ddf47378..9d47299a5616 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -17,6 +17,7 @@ pub mod count_wildcard_rule; pub mod inline_table_scan; +pub mod rewrite_expr; pub mod subquery; pub mod type_coercion; @@ -37,6 +38,8 @@ use log::debug; use std::sync::Arc; use std::time::Instant; +use self::rewrite_expr::OperatorToFunction; + /// [`AnalyzerRule`]s transform [`LogicalPlan`]s in some way to make /// the plan valid prior to the rest of the DataFusion optimization process. /// @@ -72,6 +75,9 @@ impl Analyzer { pub fn new() -> Self { let rules: Vec> = vec![ Arc::new(InlineTableScan::new()), + // OperatorToFunction should be run before TypeCoercion, since it rewrite based on the argument types (List or Scalar), + // and TypeCoercion may cast the argument types from Scalar to List. + Arc::new(OperatorToFunction::new()), Arc::new(TypeCoercion::new()), Arc::new(CountWildcardRule::new()), ]; diff --git a/datafusion/optimizer/src/analyzer/rewrite_expr.rs b/datafusion/optimizer/src/analyzer/rewrite_expr.rs new file mode 100644 index 000000000000..8f1c844ed062 --- /dev/null +++ b/datafusion/optimizer/src/analyzer/rewrite_expr.rs @@ -0,0 +1,321 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Analyzer rule for to replace operators with function calls (e.g `||` to array_concat`) + +use std::sync::Arc; + +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::TreeNodeRewriter; +use datafusion_common::utils::list_ndims; +use datafusion_common::DFSchema; +use datafusion_common::DFSchemaRef; +use datafusion_common::Result; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::expr_rewriter::rewrite_preserving_name; +use datafusion_expr::utils::merge_schema; +use datafusion_expr::BuiltinScalarFunction; +use datafusion_expr::Operator; +use datafusion_expr::ScalarFunctionDefinition; +use datafusion_expr::{BinaryExpr, Expr, LogicalPlan}; + +use super::AnalyzerRule; + +#[derive(Default)] +pub struct OperatorToFunction {} + +impl OperatorToFunction { + pub fn new() -> Self { + Self {} + } +} + +impl AnalyzerRule for OperatorToFunction { + fn name(&self) -> &str { + "operator_to_function" + } + + fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { + analyze_internal(&plan) + } +} + +fn analyze_internal(plan: &LogicalPlan) -> Result { + // optimize child plans first + let new_inputs = plan + .inputs() + .iter() + .map(|p| analyze_internal(p)) + .collect::>>()?; + + // get schema representing all available input fields. This is used for data type + // resolution only, so order does not matter here + let mut schema = merge_schema(new_inputs.iter().collect()); + + if let LogicalPlan::TableScan(ts) = plan { + let source_schema = + DFSchema::try_from_qualified_schema(&ts.table_name, &ts.source.schema())?; + schema.merge(&source_schema); + } + + let mut expr_rewrite = OperatorToFunctionRewriter { + schema: Arc::new(schema), + }; + + let new_expr = plan + .expressions() + .into_iter() + .map(|expr| { + // ensure names don't change: + // https://github.com/apache/arrow-datafusion/issues/3555 + rewrite_preserving_name(expr, &mut expr_rewrite) + }) + .collect::>>()?; + + plan.with_new_exprs(new_expr, &new_inputs) +} + +pub(crate) struct OperatorToFunctionRewriter { + pub(crate) schema: DFSchemaRef, +} + +impl TreeNodeRewriter for OperatorToFunctionRewriter { + type N = Expr; + + fn mutate(&mut self, expr: Expr) -> Result { + match expr { + Expr::BinaryExpr(BinaryExpr { + ref left, + op, + ref right, + }) => { + if let Some(fun) = rewrite_array_concat_operator_to_func_for_column( + left.as_ref(), + op, + right.as_ref(), + self.schema.as_ref(), + )? + .or_else(|| { + rewrite_array_concat_operator_to_func( + left.as_ref(), + op, + right.as_ref(), + ) + }) { + // Convert &Box -> Expr + let left = (**left).clone(); + let right = (**right).clone(); + return Ok(Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn(fun), + args: vec![left, right], + })); + } + + Ok(expr) + } + _ => Ok(expr), + } + } +} + +/// Summary of the logic below: +/// +/// 1) array || array -> array concat +/// +/// 2) array || scalar -> array append +/// +/// 3) scalar || array -> array prepend +/// +/// 4) (arry concat, array append, array prepend) || array -> array concat +/// +/// 5) (arry concat, array append, array prepend) || scalar -> array append +fn rewrite_array_concat_operator_to_func( + left: &Expr, + op: Operator, + right: &Expr, +) -> Option { + // Convert `Array StringConcat Array` to ScalarFunction::ArrayConcat + + if op != Operator::StringConcat { + return None; + } + + match (left, right) { + // Chain concat operator (a || b) || array, + // (arry concat, array append, array prepend) || array -> array concat + ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayConcat), + args: _left_args, + }), + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), + args: _right_args, + }), + ) + | ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayAppend), + args: _left_args, + }), + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), + args: _right_args, + }), + ) + | ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayPrepend), + args: _left_args, + }), + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), + args: _right_args, + }), + ) => Some(BuiltinScalarFunction::ArrayConcat), + // Chain concat operator (a || b) || scalar, + // (arry concat, array append, array prepend) || scalar -> array append + ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayConcat), + args: _left_args, + }), + _scalar, + ) + | ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayAppend), + args: _left_args, + }), + _scalar, + ) + | ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayPrepend), + args: _left_args, + }), + _scalar, + ) => Some(BuiltinScalarFunction::ArrayAppend), + // array || array -> array concat + ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), + args: _left_args, + }), + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), + args: _right_args, + }), + ) => Some(BuiltinScalarFunction::ArrayConcat), + // array || scalar -> array append + ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), + args: _left_args, + }), + _right_scalar, + ) => Some(BuiltinScalarFunction::ArrayAppend), + // scalar || array -> array prepend + ( + _left_scalar, + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), + args: _right_args, + }), + ) => Some(BuiltinScalarFunction::ArrayPrepend), + + _ => None, + } +} + +/// Summary of the logic below: +/// +/// 1) (arry concat, array append, array prepend) || column -> (array append, array concat) +/// +/// 2) column1 || column2 -> (array prepend, array append, array concat) +fn rewrite_array_concat_operator_to_func_for_column( + left: &Expr, + op: Operator, + right: &Expr, + schema: &DFSchema, +) -> Result> { + if op != Operator::StringConcat { + return Ok(None); + } + + match (left, right) { + // Column cases: + // 1) array_prepend/append/concat || column + ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayPrepend), + args: _left_args, + }), + Expr::Column(c), + ) + | ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayAppend), + args: _left_args, + }), + Expr::Column(c), + ) + | ( + Expr::ScalarFunction(ScalarFunction { + func_def: + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayConcat), + args: _left_args, + }), + Expr::Column(c), + ) => { + let d = schema.field_from_column(c)?.data_type(); + let ndim = list_ndims(d); + match ndim { + 0 => Ok(Some(BuiltinScalarFunction::ArrayAppend)), + _ => Ok(Some(BuiltinScalarFunction::ArrayConcat)), + } + } + // 2) select column1 || column2 + (Expr::Column(c1), Expr::Column(c2)) => { + let d1 = schema.field_from_column(c1)?.data_type(); + let d2 = schema.field_from_column(c2)?.data_type(); + let ndim1 = list_ndims(d1); + let ndim2 = list_ndims(d2); + match (ndim1, ndim2) { + (0, _) => Ok(Some(BuiltinScalarFunction::ArrayPrepend)), + (_, 0) => Ok(Some(BuiltinScalarFunction::ArrayAppend)), + _ => Ok(Some(BuiltinScalarFunction::ArrayConcat)), + } + } + _ => Ok(None), + } +} diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs index 6b8b1020cd6d..7c5b70b19af0 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/optimizer/src/analyzer/subquery.rs @@ -16,10 +16,11 @@ // under the License. use crate::analyzer::check_plan; -use crate::utils::{collect_subquery_cols, split_conjunction}; +use crate::utils::collect_subquery_cols; use datafusion_common::tree_node::{TreeNode, VisitRecursion}; use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_expr::expr_rewriter::strip_outer_reference; +use datafusion_expr::utils::split_conjunction; use datafusion_expr::{ Aggregate, BinaryExpr, Cast, Expr, Filter, Join, JoinType, LogicalPlan, Operator, Window, diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 5e239f8e9934..6f1da5f4e6d9 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -24,12 +24,12 @@ use arrow::datatypes::{DataType, IntervalUnit}; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter}; use datafusion_common::{ - exec_err, internal_err, plan_err, DFSchema, DFSchemaRef, DataFusionError, Result, - ScalarValue, + exec_err, internal_err, plan_datafusion_err, plan_err, DFSchema, DFSchemaRef, + DataFusionError, Result, ScalarValue, }; use datafusion_expr::expr::{ - self, Between, BinaryExpr, Case, Exists, InList, InSubquery, Like, ScalarFunction, - ScalarUDF, WindowFunction, + self, AggregateFunctionDefinition, Between, BinaryExpr, Case, Exists, InList, + InSubquery, Like, ScalarFunction, WindowFunction, }; use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::expr_schema::cast_subquery; @@ -41,16 +41,16 @@ use datafusion_expr::type_coercion::functions::data_types; use datafusion_expr::type_coercion::other::{ get_coerce_type_for_case_expression, get_coerce_type_for_list, }; -use datafusion_expr::type_coercion::{is_datetime, is_numeric, is_utf8_or_large_utf8}; +use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_large_utf8}; +use datafusion_expr::utils::merge_schema; use datafusion_expr::{ is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, - type_coercion, window_function, AggregateFunction, BuiltinScalarFunction, Expr, - LogicalPlan, Operator, Projection, WindowFrame, WindowFrameBound, WindowFrameUnits, + type_coercion, AggregateFunction, BuiltinScalarFunction, Expr, ExprSchemable, + LogicalPlan, Operator, Projection, ScalarFunctionDefinition, Signature, WindowFrame, + WindowFrameBound, WindowFrameUnits, }; -use datafusion_expr::{ExprSchemable, Signature}; use crate::analyzer::AnalyzerRule; -use crate::utils::merge_schema; #[derive(Default)] pub struct TypeCoercion {} @@ -162,11 +162,10 @@ impl TreeNodeRewriter for TypeCoercionRewriter { let new_plan = analyze_internal(&self.schema, &subquery.subquery)?; let expr_type = expr.get_type(&self.schema)?; let subquery_type = new_plan.schema().field(0).data_type(); - let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(DataFusionError::Plan( - format!( + let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(plan_datafusion_err!( "expr type {expr_type:?} can't cast to {subquery_type:?} in InSubquery" ), - ))?; + )?; let new_subquery = Subquery { subquery: Arc::new(new_plan), outer_ref_columns: subquery.outer_ref_columns, @@ -218,9 +217,9 @@ impl TreeNodeRewriter for TypeCoercionRewriter { } else { "LIKE" }; - DataFusionError::Plan(format!( + plan_datafusion_err!( "There isn't a common type to coerce {left_type} and {right_type} in {op_name} expression" - )) + ) })?; let expr = Box::new(expr.cast_to(&coerced_type, &self.schema)?); let pattern = Box::new(pattern.cast_to(&coerced_type, &self.schema)?); @@ -320,58 +319,66 @@ impl TreeNodeRewriter for TypeCoercionRewriter { let case = coerce_case_expression(case, &self.schema)?; Ok(Expr::Case(case)) } - Expr::ScalarUDF(ScalarUDF { fun, args }) => { - let new_expr = coerce_arguments_for_signature( - args.as_slice(), - &self.schema, - &fun.signature, - )?; - Ok(Expr::ScalarUDF(ScalarUDF::new(fun, new_expr))) - } - Expr::ScalarFunction(ScalarFunction { fun, args }) => { - let new_args = coerce_arguments_for_signature( - args.as_slice(), - &self.schema, - &fun.signature(), - )?; - let new_args = - coerce_arguments_for_fun(new_args.as_slice(), &self.schema, &fun)?; - Ok(Expr::ScalarFunction(ScalarFunction::new(fun, new_args))) - } + Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { + ScalarFunctionDefinition::BuiltIn(fun) => { + let new_args = coerce_arguments_for_signature( + args.as_slice(), + &self.schema, + &fun.signature(), + )?; + let new_args = coerce_arguments_for_fun( + new_args.as_slice(), + &self.schema, + &fun, + )?; + Ok(Expr::ScalarFunction(ScalarFunction::new(fun, new_args))) + } + ScalarFunctionDefinition::UDF(fun) => { + let new_expr = coerce_arguments_for_signature( + args.as_slice(), + &self.schema, + fun.signature(), + )?; + Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, new_expr))) + } + ScalarFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") + } + }, Expr::AggregateFunction(expr::AggregateFunction { - fun, + func_def, args, distinct, filter, order_by, - }) => { - let new_expr = coerce_agg_exprs_for_signature( - &fun, - &args, - &self.schema, - &fun.signature(), - )?; - let expr = Expr::AggregateFunction(expr::AggregateFunction::new( - fun, new_expr, distinct, filter, order_by, - )); - Ok(expr) - } - Expr::AggregateUDF(expr::AggregateUDF { - fun, - args, - filter, - order_by, - }) => { - let new_expr = coerce_arguments_for_signature( - args.as_slice(), - &self.schema, - &fun.signature, - )?; - let expr = Expr::AggregateUDF(expr::AggregateUDF::new( - fun, new_expr, filter, order_by, - )); - Ok(expr) - } + }) => match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + let new_expr = coerce_agg_exprs_for_signature( + &fun, + &args, + &self.schema, + &fun.signature(), + )?; + let expr = Expr::AggregateFunction(expr::AggregateFunction::new( + fun, new_expr, distinct, filter, order_by, + )); + Ok(expr) + } + AggregateFunctionDefinition::UDF(fun) => { + let new_expr = coerce_arguments_for_signature( + args.as_slice(), + &self.schema, + fun.signature(), + )?; + let expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( + fun, new_expr, false, filter, order_by, + )); + Ok(expr) + } + AggregateFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") + } + }, Expr::WindowFunction(WindowFunction { fun, args, @@ -383,7 +390,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter { coerce_window_frame(window_frame, &self.schema, &order_by)?; let args = match &fun { - window_function::WindowFunction::AggregateFunction(fun) => { + expr::WindowFunctionDefinition::AggregateFunction(fun) => { coerce_agg_exprs_for_signature( fun, &args, @@ -496,7 +503,10 @@ fn coerce_window_frame( let target_type = match window_frame.units { WindowFrameUnits::Range => { if let Some(col_type) = current_types.first() { - if is_numeric(col_type) || is_utf8_or_large_utf8(col_type) { + if col_type.is_numeric() + || is_utf8_or_large_utf8(col_type) + || matches!(col_type, DataType::Null) + { col_type } else if is_datetime(col_type) { &DataType::Interval(IntervalUnit::MonthDayNano) @@ -580,26 +590,6 @@ fn coerce_arguments_for_fun( .collect::>>()?; } - if *fun == BuiltinScalarFunction::MakeArray { - // Find the final data type for the function arguments - let current_types = expressions - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - - let new_type = current_types - .iter() - .skip(1) - .fold(current_types.first().unwrap().clone(), |acc, x| { - comparison_coercion(&acc, x).unwrap_or(acc) - }); - - return expressions - .iter() - .zip(current_types) - .map(|(expr, from_type)| cast_array_expr(expr, &from_type, &new_type, schema)) - .collect(); - } Ok(expressions) } @@ -608,20 +598,6 @@ fn cast_expr(expr: &Expr, to_type: &DataType, schema: &DFSchema) -> Result expr.clone().cast_to(to_type, schema) } -/// Cast array `expr` to the specified type, if possible -fn cast_array_expr( - expr: &Expr, - from_type: &DataType, - to_type: &DataType, - schema: &DFSchema, -) -> Result { - if from_type.equals_datatype(&DataType::Null) { - Ok(expr.clone()) - } else { - cast_expr(expr, to_type, schema) - } -} - /// Returns the coerced exprs for each `input_exprs`. /// Get the coerced data type from `aggregate_rule::coerce_types` and add `try_cast` if the /// data type of `input_exprs` need to be coerced. @@ -709,20 +685,20 @@ fn coerce_case_expression(case: Case, schema: &DFSchemaRef) -> Result { let coerced_type = get_coerce_type_for_case_expression(&when_types, Some(case_type)); coerced_type.ok_or_else(|| { - DataFusionError::Plan(format!( + plan_datafusion_err!( "Failed to coerce case ({case_type:?}) and when ({when_types:?}) \ to common types in CASE WHEN expression" - )) + ) }) }) .transpose()?; let then_else_coerce_type = get_coerce_type_for_case_expression(&then_types, else_type.as_ref()).ok_or_else( || { - DataFusionError::Plan(format!( + plan_datafusion_err!( "Failed to coerce then ({then_types:?}) and else ({else_type:?}) \ to common types in CASE WHEN expression" - )) + ) }, )?; @@ -762,8 +738,10 @@ fn coerce_case_expression(case: Case, schema: &DFSchemaRef) -> Result { #[cfg(test)] mod test { - use std::sync::Arc; + use std::any::Any; + use std::sync::{Arc, OnceLock}; + use arrow::array::{FixedSizeListArray, Int32Array}; use arrow::datatypes::{DataType, TimeUnit}; use arrow::datatypes::Field; @@ -773,13 +751,13 @@ mod test { use datafusion_expr::{ cast, col, concat, concat_ws, create_udaf, is_true, AccumulatorFactoryFunction, AggregateFunction, AggregateUDF, BinaryExpr, BuiltinScalarFunction, Case, - ColumnarValue, ExprSchemable, Filter, Operator, StateTypeFunction, Subquery, + ColumnarValue, ExprSchemable, Filter, Operator, ScalarUDFImpl, StateTypeFunction, + Subquery, }; use datafusion_expr::{ lit, logical_plan::{EmptyRelation, Projection}, - Expr, LogicalPlan, ReturnTypeFunction, ScalarFunctionImplementation, ScalarUDF, - Signature, Volatility, + Expr, LogicalPlan, ReturnTypeFunction, ScalarUDF, Signature, Volatility, }; use datafusion_physical_expr::expressions::AvgAccumulator; @@ -831,22 +809,37 @@ mod test { assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected) } + static TEST_SIGNATURE: OnceLock = OnceLock::new(); + + #[derive(Debug, Clone, Default)] + struct TestScalarUDF {} + impl ScalarUDFImpl for TestScalarUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "TestScalarUDF" + } + fn signature(&self) -> &Signature { + TEST_SIGNATURE.get_or_init(|| { + Signature::uniform(1, vec![DataType::Float32], Volatility::Stable) + }) + } + fn return_type(&self, _args: &[DataType]) -> Result { + Ok(DataType::Utf8) + } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + Ok(ColumnarValue::Scalar(ScalarValue::from("a"))) + } + } + #[test] fn scalar_udf() -> Result<()> { let empty = empty(); - let return_type: ReturnTypeFunction = - Arc::new(move |_| Ok(Arc::new(DataType::Utf8))); - let fun: ScalarFunctionImplementation = - Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::new_utf8("a")))); - let udf = Expr::ScalarUDF(expr::ScalarUDF::new( - Arc::new(ScalarUDF::new( - "TestScalarUDF", - &Signature::uniform(1, vec![DataType::Float32], Volatility::Stable), - &return_type, - &fun, - )), - vec![lit(123_i32)], - )); + + let udf = ScalarUDF::from(TestScalarUDF {}).call(vec![lit(123_i32)]); let plan = LogicalPlan::Projection(Projection::try_new(vec![udf], empty)?); let expected = "Projection: TestScalarUDF(CAST(Int32(123) AS Float32))\n EmptyRelation"; @@ -856,26 +849,15 @@ mod test { #[test] fn scalar_udf_invalid_input() -> Result<()> { let empty = empty(); - let return_type: ReturnTypeFunction = - Arc::new(move |_| Ok(Arc::new(DataType::Utf8))); - let fun: ScalarFunctionImplementation = Arc::new(move |_| unimplemented!()); - let udf = Expr::ScalarUDF(expr::ScalarUDF::new( - Arc::new(ScalarUDF::new( - "TestScalarUDF", - &Signature::uniform(1, vec![DataType::Int32], Volatility::Stable), - &return_type, - &fun, - )), - vec![lit("Apple")], - )); + let udf = ScalarUDF::from(TestScalarUDF {}).call(vec![lit("Apple")]); let plan = LogicalPlan::Projection(Projection::try_new(vec![udf], empty)?); let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, "") .err() .unwrap(); assert_eq!( - "type_coercion\ncaused by\nError during planning: Coercion from [Utf8] to the signature Uniform(1, [Int32]) failed.", - err.strip_backtrace() - ); + "type_coercion\ncaused by\nError during planning: Coercion from [Utf8] to the signature Uniform(1, [Float32]) failed.", + err.strip_backtrace() + ); Ok(()) } @@ -906,9 +888,10 @@ mod test { Arc::new(|_| Ok(Box::::default())), Arc::new(vec![DataType::UInt64, DataType::Float64]), ); - let udaf = Expr::AggregateUDF(expr::AggregateUDF::new( + let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf( Arc::new(my_avg), vec![lit(10i64)], + false, None, None, )); @@ -933,9 +916,10 @@ mod test { &accumulator, &state_type, ); - let udaf = Expr::AggregateUDF(expr::AggregateUDF::new( + let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf( Arc::new(my_avg), vec![lit("10")], + false, None, None, )); @@ -991,10 +975,13 @@ mod test { None, None, )); - let err = Projection::try_new(vec![agg_expr], empty).err().unwrap(); + let err = Projection::try_new(vec![agg_expr], empty) + .err() + .unwrap() + .strip_backtrace(); assert_eq!( - "Plan(\"No function matches the given name and argument types 'AVG(Utf8)'. You might need to add explicit type casts.\\n\\tCandidate functions:\\n\\tAVG(Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64)\")", - &format!("{err:?}") + "Error during planning: No function matches the given name and argument types 'AVG(Utf8)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tAVG(Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64)", + err ); Ok(()) } @@ -1235,19 +1222,18 @@ mod test { #[test] fn test_casting_for_fixed_size_list() -> Result<()> { - let val = lit(ScalarValue::Fixedsizelist( - Some(vec![ - ScalarValue::from(1i32), - ScalarValue::from(2i32), - ScalarValue::from(3i32), - ]), - Arc::new(Field::new("item", DataType::Int32, true)), - 3, + let val = lit(ScalarValue::FixedSizeList(Arc::new( + FixedSizeListArray::new( + Arc::new(Field::new("item", DataType::Int32, true)), + 3, + Arc::new(Int32Array::from(vec![1, 2, 3])), + None, + ), + ))); + let expr = Expr::ScalarFunction(ScalarFunction::new( + BuiltinScalarFunction::MakeArray, + vec![val.clone()], )); - let expr = Expr::ScalarFunction(ScalarFunction { - fun: BuiltinScalarFunction::MakeArray, - args: vec![val.clone()], - }); let schema = Arc::new(DFSchema::new_with_metadata( vec![DFField::new_unqualified( "item", @@ -1276,10 +1262,10 @@ mod test { &schema, )?; - let expected = Expr::ScalarFunction(ScalarFunction { - fun: BuiltinScalarFunction::MakeArray, - args: vec![expected_casted_expr], - }); + let expected = Expr::ScalarFunction(ScalarFunction::new( + BuiltinScalarFunction::MakeArray, + vec![expected_casted_expr], + )); assert_eq!(result, expected); Ok(()) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index c6b138f8ca36..1e089257c61a 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -20,6 +20,8 @@ use std::collections::{BTreeSet, HashMap}; use std::sync::Arc; +use crate::{utils, OptimizerConfig, OptimizerRule}; + use arrow::datatypes::DataType; use datafusion_common::tree_node::{ RewriteRecursion, TreeNode, TreeNodeRewriter, TreeNodeVisitor, VisitRecursion, @@ -27,14 +29,11 @@ use datafusion_common::tree_node::{ use datafusion_common::{ internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, }; -use datafusion_expr::expr::Alias; -use datafusion_expr::{ - col, - logical_plan::{Aggregate, Filter, LogicalPlan, Projection, Sort, Window}, - Expr, ExprSchemable, +use datafusion_expr::expr::{is_volatile, Alias}; +use datafusion_expr::logical_plan::{ + Aggregate, Filter, LogicalPlan, Projection, Sort, Window, }; - -use crate::{utils, OptimizerConfig, OptimizerRule}; +use datafusion_expr::{col, Expr, ExprSchemable}; /// A map from expression's identifier to tuple including /// - the expression itself (cloned) @@ -111,24 +110,19 @@ impl CommonSubexprEliminate { projection: &Projection, config: &dyn OptimizerConfig, ) -> Result { - let Projection { - expr, - input, - schema, - .. - } = projection; + let Projection { expr, input, .. } = projection; let input_schema = Arc::clone(input.schema()); let mut expr_set = ExprSet::new(); + + // Visit expr list and build expr identifier to occuring count map (`expr_set`). let arrays = to_arrays(expr, input_schema, &mut expr_set, ExprMask::Normal)?; let (mut new_expr, new_input) = self.rewrite_expr(&[expr], &[&arrays], input, &expr_set, config)?; - Ok(LogicalPlan::Projection(Projection::try_new_with_schema( - pop_expr(&mut new_expr)?, - Arc::new(new_input), - schema.clone(), - )?)) + // Since projection expr changes, schema changes also. Use try_new method. + Projection::try_new(pop_expr(&mut new_expr)?, Arc::new(new_input)) + .map(LogicalPlan::Projection) } fn try_optimize_filter( @@ -201,7 +195,6 @@ impl CommonSubexprEliminate { group_expr, aggr_expr, input, - schema, .. } = aggregate; let mut expr_set = ExprSet::new(); @@ -247,12 +240,17 @@ impl CommonSubexprEliminate { let rewritten = pop_expr(&mut rewritten)?; if affected_id.is_empty() { - Ok(LogicalPlan::Aggregate(Aggregate::try_new_with_schema( - Arc::new(new_input), - new_group_expr, - new_aggr_expr, - schema.clone(), - )?)) + // Alias aggregation expressions if they have changed + let new_aggr_expr = new_aggr_expr + .iter() + .zip(aggr_expr.iter()) + .map(|(new_expr, old_expr)| { + new_expr.clone().alias_if_changed(old_expr.display_name()?) + }) + .collect::>>()?; + // Since group_epxr changes, schema changes also. Use try_new method. + Aggregate::try_new(Arc::new(new_input), new_group_expr, new_aggr_expr) + .map(LogicalPlan::Aggregate) } else { let mut agg_exprs = vec![]; @@ -379,7 +377,7 @@ impl OptimizerRule for CommonSubexprEliminate { Ok(Some(build_recover_project_plan( &original_schema, optimized_plan, - ))) + )?)) } plan => Ok(plan), } @@ -470,16 +468,19 @@ fn build_common_expr_project_plan( /// the "intermediate" projection plan built in [build_common_expr_project_plan]. /// /// This is for those plans who don't keep its own output schema like `Filter` or `Sort`. -fn build_recover_project_plan(schema: &DFSchema, input: LogicalPlan) -> LogicalPlan { +fn build_recover_project_plan( + schema: &DFSchema, + input: LogicalPlan, +) -> Result { let col_exprs = schema .fields() .iter() .map(|field| Expr::Column(field.qualified_column())) .collect(); - LogicalPlan::Projection( - Projection::try_new(col_exprs, Arc::new(input)) - .expect("Cannot build projection plan from an invalid schema"), - ) + Ok(LogicalPlan::Projection(Projection::try_new( + col_exprs, + Arc::new(input), + )?)) } fn extract_expressions( @@ -510,15 +511,14 @@ enum ExprMask { /// - [`Sort`](Expr::Sort) /// - [`Wildcard`](Expr::Wildcard) /// - [`AggregateFunction`](Expr::AggregateFunction) - /// - [`AggregateUDF`](Expr::AggregateUDF) Normal, - /// Like [`Normal`](Self::Normal), but includes [`AggregateFunction`](Expr::AggregateFunction) and [`AggregateUDF`](Expr::AggregateUDF). + /// Like [`Normal`](Self::Normal), but includes [`AggregateFunction`](Expr::AggregateFunction). NormalAndAggregates, } impl ExprMask { - fn ignores(&self, expr: &Expr) -> bool { + fn ignores(&self, expr: &Expr) -> Result { let is_normal_minus_aggregates = matches!( expr, Expr::Literal(..) @@ -526,18 +526,17 @@ impl ExprMask { | Expr::ScalarVariable(..) | Expr::Alias(..) | Expr::Sort { .. } - | Expr::Wildcard + | Expr::Wildcard { .. } ); - let is_aggr = matches!( - expr, - Expr::AggregateFunction(..) | Expr::AggregateUDF { .. } - ); + let is_volatile = is_volatile(expr)?; - match self { - Self::Normal => is_normal_minus_aggregates || is_aggr, - Self::NormalAndAggregates => is_normal_minus_aggregates, - } + let is_aggr = matches!(expr, Expr::AggregateFunction(..)); + + Ok(match self { + Self::Normal => is_volatile || is_normal_minus_aggregates || is_aggr, + Self::NormalAndAggregates => is_volatile || is_normal_minus_aggregates, + }) } } @@ -629,7 +628,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { let (idx, sub_expr_desc) = self.pop_enter_mark(); // skip exprs should not be recognize. - if self.expr_mask.ignores(expr) { + if self.expr_mask.ignores(expr)? { self.id_array[idx].0 = self.series_number; let desc = Self::desc_expr(expr); self.visit_stack.push(VisitRecord::ExprItem(desc)); @@ -909,7 +908,7 @@ mod test { let accumulator: AccumulatorFactoryFunction = Arc::new(|_| unimplemented!()); let state_type: StateTypeFunction = Arc::new(|_| unimplemented!()); let udf_agg = |inner: Expr| { - Expr::AggregateUDF(datafusion_expr::expr::AggregateUDF::new( + Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( Arc::new(AggregateUDF::new( "my_agg", &Signature::exact(vec![DataType::UInt32], Volatility::Stable), @@ -918,6 +917,7 @@ mod test { &state_type, )), vec![inner], + false, None, None, )) diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index b5cf73733896..b1000f042c98 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -16,15 +16,14 @@ // under the License. use crate::simplify_expressions::{ExprSimplifier, SimplifyContext}; -use crate::utils::{ - collect_subquery_cols, conjunction, find_join_exprs, split_conjunction, -}; +use crate::utils::collect_subquery_cols; use datafusion_common::tree_node::{ RewriteRecursion, Transformed, TreeNode, TreeNodeRewriter, }; use datafusion_common::{plan_err, Result}; use datafusion_common::{Column, DFSchemaRef, DataFusionError, ScalarValue}; -use datafusion_expr::expr::Alias; +use datafusion_expr::expr::{AggregateFunctionDefinition, Alias}; +use datafusion_expr::utils::{conjunction, find_join_exprs, split_conjunction}; use datafusion_expr::{expr, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder}; use datafusion_physical_expr::execution_props::ExecutionProps; use std::collections::{BTreeSet, HashMap}; @@ -227,10 +226,9 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { )?; if !expr_result_map_for_count_bug.is_empty() { // has count bug - let un_matched_row = Expr::Alias(Alias::new( - Expr::Literal(ScalarValue::Boolean(Some(true))), - UN_MATCHED_ROW_INDICATOR.to_string(), - )); + let un_matched_row = + Expr::Literal(ScalarValue::Boolean(Some(true))) + .alias(UN_MATCHED_ROW_INDICATOR); // add the unmatched rows indicator to the Aggregation's group expressions missing_exprs.push(un_matched_row); } @@ -374,16 +372,25 @@ fn agg_exprs_evaluation_result_on_empty_batch( for e in agg_expr.iter() { let result_expr = e.clone().transform_up(&|expr| { let new_expr = match expr { - Expr::AggregateFunction(expr::AggregateFunction { fun, .. }) => { - if matches!(fun, datafusion_expr::AggregateFunction::Count) { - Transformed::Yes(Expr::Literal(ScalarValue::Int64(Some(0)))) - } else { - Transformed::Yes(Expr::Literal(ScalarValue::Null)) + Expr::AggregateFunction(expr::AggregateFunction { func_def, .. }) => { + match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + if matches!(fun, datafusion_expr::AggregateFunction::Count) { + Transformed::Yes(Expr::Literal(ScalarValue::Int64(Some( + 0, + )))) + } else { + Transformed::Yes(Expr::Literal(ScalarValue::Null)) + } + } + AggregateFunctionDefinition::UDF { .. } => { + Transformed::Yes(Expr::Literal(ScalarValue::Null)) + } + AggregateFunctionDefinition::Name(_) => { + Transformed::Yes(Expr::Literal(ScalarValue::Null)) + } } } - Expr::AggregateUDF(_) => { - Transformed::Yes(Expr::Literal(ScalarValue::Null)) - } _ => Transformed::No(expr), }; Ok(new_expr) diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index 432d7f053aef..450336376a23 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -17,14 +17,15 @@ use crate::decorrelate::PullUpCorrelatedExpr; use crate::optimizer::ApplyOrder; -use crate::utils::{conjunction, replace_qualified_name, split_conjunction}; +use crate::utils::replace_qualified_name; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::TreeNode; -use datafusion_common::{plan_err, Column, DataFusionError, Result}; +use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_expr::expr::{Exists, InSubquery}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; use datafusion_expr::logical_plan::{JoinType, Subquery}; +use datafusion_expr::utils::{conjunction, split_conjunction}; use datafusion_expr::{ exists, in_subquery, not_exists, not_in_subquery, BinaryExpr, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Operator, @@ -282,12 +283,7 @@ fn build_join( false => JoinType::LeftSemi, }; let new_plan = LogicalPlanBuilder::from(left.clone()) - .join( - sub_query_alias, - join_type, - (Vec::::new(), Vec::::new()), - Some(join_filter), - )? + .join_on(sub_query_alias, join_type, Some(join_filter))? .build()?; debug!( "predicate subquery optimized:\n{}", diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index cf9a59d6b892..d9e96a9f2543 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -20,6 +20,7 @@ use std::collections::HashSet; use std::sync::Arc; use crate::{utils, OptimizerConfig, OptimizerRule}; + use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_expr::expr::{BinaryExpr, Expr}; use datafusion_expr::logical_plan::{ @@ -44,84 +45,97 @@ impl EliminateCrossJoin { /// 'select ... from a, b where (a.x = b.y and b.xx = 100) or (a.x = b.y and b.xx = 200);' /// 'select ... from a, b, c where (a.x = b.y and b.xx = 100 and a.z = c.z) /// or (a.x = b.y and b.xx = 200 and a.z=c.z);' +/// 'select ... from a, b where a.x > b.y' /// For above queries, the join predicate is available in filters and they are moved to /// join nodes appropriately /// This fix helps to improve the performance of TPCH Q19. issue#78 -/// impl OptimizerRule for EliminateCrossJoin { fn try_optimize( &self, plan: &LogicalPlan, config: &dyn OptimizerConfig, ) -> Result> { - match plan { + let mut possible_join_keys: Vec<(Expr, Expr)> = vec![]; + let mut all_inputs: Vec = vec![]; + let parent_predicate = match plan { LogicalPlan::Filter(filter) => { - let input = filter.input.as_ref().clone(); - - let mut possible_join_keys: Vec<(Expr, Expr)> = vec![]; - let mut all_inputs: Vec = vec![]; - let did_flat_successfully = match &input { + let input = filter.input.as_ref(); + match input { LogicalPlan::Join(Join { join_type: JoinType::Inner, .. }) - | LogicalPlan::CrossJoin(_) => try_flatten_join_inputs( - &input, - &mut possible_join_keys, - &mut all_inputs, - )?, + | LogicalPlan::CrossJoin(_) => { + if !try_flatten_join_inputs( + input, + &mut possible_join_keys, + &mut all_inputs, + )? { + return Ok(None); + } + extract_possible_join_keys( + &filter.predicate, + &mut possible_join_keys, + )?; + Some(&filter.predicate) + } _ => { return utils::optimize_children(self, plan, config); } - }; - - if !did_flat_successfully { + } + } + LogicalPlan::Join(Join { + join_type: JoinType::Inner, + .. + }) => { + if !try_flatten_join_inputs( + plan, + &mut possible_join_keys, + &mut all_inputs, + )? { return Ok(None); } + None + } + _ => return utils::optimize_children(self, plan, config), + }; - let predicate = &filter.predicate; - // join keys are handled locally - let mut all_join_keys: HashSet<(Expr, Expr)> = HashSet::new(); - - extract_possible_join_keys(predicate, &mut possible_join_keys)?; + // Join keys are handled locally: + let mut all_join_keys = HashSet::<(Expr, Expr)>::new(); + let mut left = all_inputs.remove(0); + while !all_inputs.is_empty() { + left = find_inner_join( + &left, + &mut all_inputs, + &mut possible_join_keys, + &mut all_join_keys, + )?; + } - let mut left = all_inputs.remove(0); - while !all_inputs.is_empty() { - left = find_inner_join( - &left, - &mut all_inputs, - &mut possible_join_keys, - &mut all_join_keys, - )?; - } + left = utils::optimize_children(self, &left, config)?.unwrap_or(left); - left = utils::optimize_children(self, &left, config)?.unwrap_or(left); + if plan.schema() != left.schema() { + left = LogicalPlan::Projection(Projection::new_from_schema( + Arc::new(left), + plan.schema().clone(), + )); + } - if plan.schema() != left.schema() { - left = LogicalPlan::Projection(Projection::new_from_schema( - Arc::new(left.clone()), - plan.schema().clone(), - )); - } + let Some(predicate) = parent_predicate else { + return Ok(Some(left)); + }; - // if there are no join keys then do nothing. - if all_join_keys.is_empty() { - Ok(Some(LogicalPlan::Filter(Filter::try_new( - predicate.clone(), - Arc::new(left), - )?))) - } else { - // remove join expressions from filter - match remove_join_expressions(predicate, &all_join_keys)? { - Some(filter_expr) => Ok(Some(LogicalPlan::Filter( - Filter::try_new(filter_expr, Arc::new(left))?, - ))), - _ => Ok(Some(left)), - } - } + // If there are no join keys then do nothing: + if all_join_keys.is_empty() { + Filter::try_new(predicate.clone(), Arc::new(left)) + .map(|f| Some(LogicalPlan::Filter(f))) + } else { + // Remove join expressions from filter: + match remove_join_expressions(predicate, &all_join_keys)? { + Some(filter_expr) => Filter::try_new(filter_expr, Arc::new(left)) + .map(|f| Some(LogicalPlan::Filter(f))), + _ => Ok(Some(left)), } - - _ => utils::optimize_children(self, plan, config), } } @@ -325,17 +339,16 @@ fn remove_join_expressions( #[cfg(test)] mod tests { + use super::*; + use crate::optimizer::OptimizerContext; + use crate::test::*; + use datafusion_expr::{ binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, Operator::{And, Or}, }; - use crate::optimizer::OptimizerContext; - use crate::test::*; - - use super::*; - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: Vec<&str>) { let rule = EliminateCrossJoin::new(); let optimized_plan = rule diff --git a/datafusion/optimizer/src/eliminate_filter.rs b/datafusion/optimizer/src/eliminate_filter.rs index c97906a81adf..fea14342ca77 100644 --- a/datafusion/optimizer/src/eliminate_filter.rs +++ b/datafusion/optimizer/src/eliminate_filter.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Optimizer rule to replace `where false` on a plan with an empty relation. +//! Optimizer rule to replace `where false or null` on a plan with an empty relation. //! This saves time in planning and executing the query. //! Note that this rule should be applied after simplify expressions optimizer rule. use crate::optimizer::ApplyOrder; @@ -27,7 +27,7 @@ use datafusion_expr::{ use crate::{OptimizerConfig, OptimizerRule}; -/// Optimization rule that eliminate the scalar value (true/false) filter with an [LogicalPlan::EmptyRelation] +/// Optimization rule that eliminate the scalar value (true/false/null) filter with an [LogicalPlan::EmptyRelation] #[derive(Default)] pub struct EliminateFilter; @@ -46,20 +46,22 @@ impl OptimizerRule for EliminateFilter { ) -> Result> { match plan { LogicalPlan::Filter(Filter { - predicate: Expr::Literal(ScalarValue::Boolean(Some(v))), + predicate: Expr::Literal(ScalarValue::Boolean(v)), input, .. }) => { match *v { // input also can be filter, apply again - true => Ok(Some( + Some(true) => Ok(Some( self.try_optimize(input, _config)? .unwrap_or_else(|| input.as_ref().clone()), )), - false => Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row: false, - schema: input.schema().clone(), - }))), + Some(false) | None => { + Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: input.schema().clone(), + }))) + } } } _ => Ok(None), @@ -105,6 +107,21 @@ mod tests { assert_optimized_plan_equal(&plan, expected) } + #[test] + fn filter_null() -> Result<()> { + let filter_expr = Expr::Literal(ScalarValue::Boolean(None)); + + let table_scan = test_table_scan().unwrap(); + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("a")], vec![sum(col("b"))])? + .filter(filter_expr)? + .build()?; + + // No aggregate / scan / limit + let expected = "EmptyRelation"; + assert_optimized_plan_equal(&plan, expected) + } + #[test] fn filter_false_nested() -> Result<()> { let filter_expr = Expr::Literal(ScalarValue::Boolean(Some(false))); diff --git a/datafusion/optimizer/src/eliminate_join.rs b/datafusion/optimizer/src/eliminate_join.rs index 00abcdcc68aa..0dbebcc8a051 100644 --- a/datafusion/optimizer/src/eliminate_join.rs +++ b/datafusion/optimizer/src/eliminate_join.rs @@ -77,7 +77,7 @@ impl OptimizerRule for EliminateJoin { mod tests { use crate::eliminate_join::EliminateJoin; use crate::test::*; - use datafusion_common::{Column, Result, ScalarValue}; + use datafusion_common::{Result, ScalarValue}; use datafusion_expr::JoinType::Inner; use datafusion_expr::{logical_plan::builder::LogicalPlanBuilder, Expr, LogicalPlan}; use std::sync::Arc; @@ -89,10 +89,9 @@ mod tests { #[test] fn join_on_false() -> Result<()> { let plan = LogicalPlanBuilder::empty(false) - .join( + .join_on( LogicalPlanBuilder::empty(false).build()?, Inner, - (Vec::::new(), Vec::::new()), Some(Expr::Literal(ScalarValue::Boolean(Some(false)))), )? .build()?; @@ -104,10 +103,9 @@ mod tests { #[test] fn join_on_true() -> Result<()> { let plan = LogicalPlanBuilder::empty(false) - .join( + .join_on( LogicalPlanBuilder::empty(false).build()?, Inner, - (Vec::::new(), Vec::::new()), Some(Expr::Literal(ScalarValue::Boolean(Some(true)))), )? .build()?; diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs index 7844ca7909fc..4386253740aa 100644 --- a/datafusion/optimizer/src/eliminate_limit.rs +++ b/datafusion/optimizer/src/eliminate_limit.rs @@ -97,7 +97,7 @@ mod tests { let optimizer = Optimizer::with_rules(vec![Arc::new(EliminateLimit::new())]); let optimized_plan = optimizer .optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), )? diff --git a/datafusion/optimizer/src/eliminate_nested_union.rs b/datafusion/optimizer/src/eliminate_nested_union.rs new file mode 100644 index 000000000000..5771ea2e19a2 --- /dev/null +++ b/datafusion/optimizer/src/eliminate_nested_union.rs @@ -0,0 +1,389 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Optimizer rule to replace nested unions to single union. +use crate::optimizer::ApplyOrder; +use crate::{OptimizerConfig, OptimizerRule}; +use datafusion_common::Result; +use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema; +use datafusion_expr::{Distinct, LogicalPlan, Union}; +use std::sync::Arc; + +#[derive(Default)] +/// An optimization rule that replaces nested unions with a single union. +pub struct EliminateNestedUnion; + +impl EliminateNestedUnion { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for EliminateNestedUnion { + fn try_optimize( + &self, + plan: &LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + match plan { + LogicalPlan::Union(Union { inputs, schema }) => { + let inputs = inputs + .iter() + .flat_map(extract_plans_from_union) + .collect::>(); + + Ok(Some(LogicalPlan::Union(Union { + inputs, + schema: schema.clone(), + }))) + } + LogicalPlan::Distinct(Distinct::All(plan)) => match plan.as_ref() { + LogicalPlan::Union(Union { inputs, schema }) => { + let inputs = inputs + .iter() + .map(extract_plan_from_distinct) + .flat_map(extract_plans_from_union) + .collect::>(); + + Ok(Some(LogicalPlan::Distinct(Distinct::All(Arc::new( + LogicalPlan::Union(Union { + inputs, + schema: schema.clone(), + }), + ))))) + } + _ => Ok(None), + }, + _ => Ok(None), + } + } + + fn name(&self) -> &str { + "eliminate_nested_union" + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::BottomUp) + } +} + +fn extract_plans_from_union(plan: &Arc) -> Vec> { + match plan.as_ref() { + LogicalPlan::Union(Union { inputs, schema }) => inputs + .iter() + .map(|plan| Arc::new(coerce_plan_expr_for_schema(plan, schema).unwrap())) + .collect::>(), + _ => vec![plan.clone()], + } +} + +fn extract_plan_from_distinct(plan: &Arc) -> &Arc { + match plan.as_ref() { + LogicalPlan::Distinct(Distinct::All(plan)) => plan, + _ => plan, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test::*; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_expr::{col, logical_plan::table_scan}; + + fn schema() -> Schema { + Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Float64, false), + ]) + } + + fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + assert_optimized_plan_eq(Arc::new(EliminateNestedUnion::new()), plan, expected) + } + + #[test] + fn eliminate_nothing() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union(plan_builder.clone().build()?)? + .build()?; + + let expected = "\ + Union\ + \n TableScan: table\ + \n TableScan: table"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn eliminate_distinct_nothing() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union_distinct(plan_builder.clone().build()?)? + .build()?; + + let expected = "Distinct:\ + \n Union\ + \n TableScan: table\ + \n TableScan: table"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn eliminate_nested_union() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union(plan_builder.clone().build()?)? + .union(plan_builder.clone().build()?)? + .union(plan_builder.clone().build()?)? + .build()?; + + let expected = "\ + Union\ + \n TableScan: table\ + \n TableScan: table\ + \n TableScan: table\ + \n TableScan: table"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn eliminate_nested_union_with_distinct_union() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union_distinct(plan_builder.clone().build()?)? + .union(plan_builder.clone().build()?)? + .union(plan_builder.clone().build()?)? + .build()?; + + let expected = "Union\ + \n Distinct:\ + \n Union\ + \n TableScan: table\ + \n TableScan: table\ + \n TableScan: table\ + \n TableScan: table"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn eliminate_nested_distinct_union() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union(plan_builder.clone().build()?)? + .union_distinct(plan_builder.clone().build()?)? + .union(plan_builder.clone().build()?)? + .union_distinct(plan_builder.clone().build()?)? + .build()?; + + let expected = "Distinct:\ + \n Union\ + \n TableScan: table\ + \n TableScan: table\ + \n TableScan: table\ + \n TableScan: table\ + \n TableScan: table"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn eliminate_nested_distinct_union_with_distinct_table() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union_distinct(plan_builder.clone().distinct()?.build()?)? + .union(plan_builder.clone().distinct()?.build()?)? + .union_distinct(plan_builder.clone().build()?)? + .build()?; + + let expected = "Distinct:\ + \n Union\ + \n TableScan: table\ + \n TableScan: table\ + \n TableScan: table\ + \n TableScan: table"; + assert_optimized_plan_equal(&plan, expected) + } + + // We don't need to use project_with_column_index in logical optimizer, + // after LogicalPlanBuilder::union, we already have all equal expression aliases + #[test] + fn eliminate_nested_union_with_projection() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union( + plan_builder + .clone() + .project(vec![col("id").alias("table_id"), col("key"), col("value")])? + .build()?, + )? + .union( + plan_builder + .clone() + .project(vec![col("id").alias("_id"), col("key"), col("value")])? + .build()?, + )? + .build()?; + + let expected = "Union\ + \n TableScan: table\ + \n Projection: table.id AS id, table.key, table.value\ + \n TableScan: table\ + \n Projection: table.id AS id, table.key, table.value\ + \n TableScan: table"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn eliminate_nested_distinct_union_with_projection() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union_distinct( + plan_builder + .clone() + .project(vec![col("id").alias("table_id"), col("key"), col("value")])? + .build()?, + )? + .union_distinct( + plan_builder + .clone() + .project(vec![col("id").alias("_id"), col("key"), col("value")])? + .build()?, + )? + .build()?; + + let expected = "Distinct:\ + \n Union\ + \n TableScan: table\ + \n Projection: table.id AS id, table.key, table.value\ + \n TableScan: table\ + \n Projection: table.id AS id, table.key, table.value\ + \n TableScan: table"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn eliminate_nested_union_with_type_cast_projection() -> Result<()> { + let table_1 = table_scan( + Some("table_1"), + &Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Float64, false), + ]), + None, + )?; + + let table_2 = table_scan( + Some("table_1"), + &Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Float32, false), + ]), + None, + )?; + + let table_3 = table_scan( + Some("table_1"), + &Schema::new(vec![ + Field::new("id", DataType::Int16, false), + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Float32, false), + ]), + None, + )?; + + let plan = table_1 + .union(table_2.build()?)? + .union(table_3.build()?)? + .build()?; + + let expected = "Union\ + \n TableScan: table_1\ + \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ + \n TableScan: table_1\ + \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ + \n TableScan: table_1"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn eliminate_nested_distinct_union_with_type_cast_projection() -> Result<()> { + let table_1 = table_scan( + Some("table_1"), + &Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Float64, false), + ]), + None, + )?; + + let table_2 = table_scan( + Some("table_1"), + &Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Float32, false), + ]), + None, + )?; + + let table_3 = table_scan( + Some("table_1"), + &Schema::new(vec![ + Field::new("id", DataType::Int16, false), + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Float32, false), + ]), + None, + )?; + + let plan = table_1 + .union_distinct(table_2.build()?)? + .union_distinct(table_3.build()?)? + .build()?; + + let expected = "Distinct:\ + \n Union\ + \n TableScan: table_1\ + \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ + \n TableScan: table_1\ + \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ + \n TableScan: table_1"; + assert_optimized_plan_equal(&plan, expected) + } +} diff --git a/datafusion/optimizer/src/eliminate_one_union.rs b/datafusion/optimizer/src/eliminate_one_union.rs new file mode 100644 index 000000000000..70ee490346ff --- /dev/null +++ b/datafusion/optimizer/src/eliminate_one_union.rs @@ -0,0 +1,118 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Optimizer rule to eliminate one union. +use crate::{OptimizerConfig, OptimizerRule}; +use datafusion_common::Result; +use datafusion_expr::logical_plan::{LogicalPlan, Union}; + +use crate::optimizer::ApplyOrder; + +#[derive(Default)] +/// An optimization rule that eliminates union with one element. +pub struct EliminateOneUnion; + +impl EliminateOneUnion { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for EliminateOneUnion { + fn try_optimize( + &self, + plan: &LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + match plan { + LogicalPlan::Union(Union { inputs, .. }) if inputs.len() == 1 => { + Ok(inputs.first().map(|input| input.as_ref().clone())) + } + _ => Ok(None), + } + } + + fn name(&self) -> &str { + "eliminate_one_union" + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::TopDown) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test::*; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::ToDFSchema; + use datafusion_expr::{ + expr_rewriter::coerce_plan_expr_for_schema, + logical_plan::{table_scan, Union}, + }; + use std::sync::Arc; + + fn schema() -> Schema { + Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Int32, false), + ]) + } + + fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + assert_optimized_plan_eq_with_rules( + vec![Arc::new(EliminateOneUnion::new())], + plan, + expected, + ) + } + + #[test] + fn eliminate_nothing() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union(plan_builder.clone().build()?)? + .build()?; + + let expected = "\ + Union\ + \n TableScan: table\ + \n TableScan: table"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn eliminate_one_union() -> Result<()> { + let table_plan = coerce_plan_expr_for_schema( + &table_scan(Some("table"), &schema(), None)?.build()?, + &schema().to_dfschema()?, + )?; + let schema = table_plan.schema().clone(); + let single_union_plan = LogicalPlan::Union(Union { + inputs: vec![Arc::new(table_plan)], + schema, + }); + + let expected = "TableScan: table"; + assert_optimized_plan_equal(&single_union_plan, expected) + } +} diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index e4d57f0209a4..53c4b3702b1e 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -106,7 +106,8 @@ impl OptimizerRule for EliminateOuterJoin { schema: join.schema.clone(), null_equals_null: join.null_equals_null, }); - let new_plan = plan.with_new_inputs(&[new_join])?; + let new_plan = + plan.with_new_exprs(plan.expressions(), &[new_join])?; Ok(Some(new_plan)) } _ => Ok(None), diff --git a/datafusion/optimizer/src/eliminate_project.rs b/datafusion/optimizer/src/eliminate_project.rs deleted file mode 100644 index d3226eaa78cf..000000000000 --- a/datafusion/optimizer/src/eliminate_project.rs +++ /dev/null @@ -1,94 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use crate::optimizer::ApplyOrder; -use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::{DFSchemaRef, Result}; -use datafusion_expr::logical_plan::LogicalPlan; -use datafusion_expr::{Expr, Projection}; - -/// Optimization rule that eliminate unnecessary [LogicalPlan::Projection]. -#[derive(Default)] -pub struct EliminateProjection; - -impl EliminateProjection { - #[allow(missing_docs)] - pub fn new() -> Self { - Self {} - } -} - -impl OptimizerRule for EliminateProjection { - fn try_optimize( - &self, - plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - match plan { - LogicalPlan::Projection(projection) => { - let child_plan = projection.input.as_ref(); - match child_plan { - LogicalPlan::Union(_) - | LogicalPlan::Filter(_) - | LogicalPlan::TableScan(_) - | LogicalPlan::SubqueryAlias(_) - | LogicalPlan::Sort(_) => { - if can_eliminate(projection, child_plan.schema()) { - Ok(Some(child_plan.clone())) - } else { - Ok(None) - } - } - _ => { - if plan.schema() == child_plan.schema() { - Ok(Some(child_plan.clone())) - } else { - Ok(None) - } - } - } - } - _ => Ok(None), - } - } - - fn name(&self) -> &str { - "eliminate_projection" - } - - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } -} - -pub(crate) fn can_eliminate(projection: &Projection, schema: &DFSchemaRef) -> bool { - if projection.expr.len() != schema.fields().len() { - return false; - } - for (i, e) in projection.expr.iter().enumerate() { - match e { - Expr::Column(c) => { - let d = schema.fields().get(i).unwrap(); - if c != &d.qualified_column() && c != &d.unqualified_column() { - return false; - } - } - _ => return false, - } - } - true -} diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs index e328eeeb00a1..24664d57c38d 100644 --- a/datafusion/optimizer/src/extract_equijoin_predicate.rs +++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs @@ -17,11 +17,10 @@ //! [`ExtractEquijoinPredicate`] rule that extracts equijoin predicates use crate::optimizer::ApplyOrder; -use crate::utils::split_conjunction; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::DFSchema; use datafusion_common::Result; -use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair}; +use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair, split_conjunction}; use datafusion_expr::{BinaryExpr, Expr, ExprSchemable, Join, LogicalPlan, Operator}; use std::sync::Arc; @@ -161,7 +160,6 @@ mod tests { use super::*; use crate::test::*; use arrow::datatypes::DataType; - use datafusion_common::Column; use datafusion_expr::{ col, lit, logical_plan::builder::LogicalPlanBuilder, JoinType, }; @@ -182,12 +180,7 @@ mod tests { let t2 = test_table_scan_with_name("t2")?; let plan = LogicalPlanBuilder::from(t1) - .join( - t2, - JoinType::Left, - (Vec::::new(), Vec::::new()), - Some(col("t1.a").eq(col("t2.a"))), - )? + .join_on(t2, JoinType::Left, Some(col("t1.a").eq(col("t2.a"))))? .build()?; let expected = "Left Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\ \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ @@ -202,10 +195,9 @@ mod tests { let t2 = test_table_scan_with_name("t2")?; let plan = LogicalPlanBuilder::from(t1) - .join( + .join_on( t2, JoinType::Left, - (Vec::::new(), Vec::::new()), Some((col("t1.a") + lit(10i64)).eq(col("t2.a") * lit(2u32))), )? .build()?; @@ -222,10 +214,9 @@ mod tests { let t2 = test_table_scan_with_name("t2")?; let plan = LogicalPlanBuilder::from(t1) - .join( + .join_on( t2, JoinType::Left, - (Vec::::new(), Vec::::new()), Some( (col("t1.a") + lit(10i64)) .gt_eq(col("t2.a") * lit(2u32)) @@ -273,10 +264,9 @@ mod tests { let t2 = test_table_scan_with_name("t2")?; let plan = LogicalPlanBuilder::from(t1) - .join( + .join_on( t2, JoinType::Left, - (Vec::::new(), Vec::::new()), Some( col("t1.c") .eq(col("t2.c")) @@ -301,10 +291,9 @@ mod tests { let t3 = test_table_scan_with_name("t3")?; let input = LogicalPlanBuilder::from(t2) - .join( + .join_on( t3, JoinType::Left, - (Vec::::new(), Vec::::new()), Some( col("t2.a") .eq(col("t3.a")) @@ -313,10 +302,9 @@ mod tests { )? .build()?; let plan = LogicalPlanBuilder::from(t1) - .join( + .join_on( input, JoinType::Left, - (Vec::::new(), Vec::::new()), Some( col("t1.a") .eq(col("t2.a")) @@ -340,10 +328,9 @@ mod tests { let t3 = test_table_scan_with_name("t3")?; let input = LogicalPlanBuilder::from(t2) - .join( + .join_on( t3, JoinType::Left, - (Vec::::new(), Vec::::new()), Some( col("t2.a") .eq(col("t3.a")) @@ -352,10 +339,9 @@ mod tests { )? .build()?; let plan = LogicalPlanBuilder::from(t1) - .join( + .join_on( input, JoinType::Left, - (Vec::::new(), Vec::::new()), Some(col("t1.a").eq(col("t2.a")).and(col("t2.c").eq(col("t3.c")))), )? .build()?; @@ -383,12 +369,7 @@ mod tests { ) .alias("t1.a + 1 = t2.a + 2"); let plan = LogicalPlanBuilder::from(t1) - .join( - t2, - JoinType::Left, - (Vec::::new(), Vec::::new()), - Some(filter), - )? + .join_on(t2, JoinType::Left, Some(filter))? .build()?; let expected = "Left Join: t1.a + CAST(Int64(1) AS UInt32) = t2.a + CAST(Int32(2) AS UInt32) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\ \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index 1d12ca7e3950..b54facc5d682 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -24,11 +24,12 @@ pub mod eliminate_duplicated_expr; pub mod eliminate_filter; pub mod eliminate_join; pub mod eliminate_limit; +pub mod eliminate_nested_union; +pub mod eliminate_one_union; pub mod eliminate_outer_join; -pub mod eliminate_project; pub mod extract_equijoin_predicate; pub mod filter_null_join_keys; -pub mod merge_projection; +pub mod optimize_projections; pub mod optimizer; pub mod propagate_empty_relation; pub mod push_down_filter; diff --git a/datafusion/optimizer/src/merge_projection.rs b/datafusion/optimizer/src/merge_projection.rs deleted file mode 100644 index 408055b8e7d4..000000000000 --- a/datafusion/optimizer/src/merge_projection.rs +++ /dev/null @@ -1,168 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use crate::optimizer::ApplyOrder; -use datafusion_common::Result; -use datafusion_expr::{Expr, LogicalPlan, Projection}; -use std::collections::HashMap; - -use crate::push_down_filter::replace_cols_by_name; -use crate::{OptimizerConfig, OptimizerRule}; - -/// Optimization rule that merge [LogicalPlan::Projection]. -#[derive(Default)] -pub struct MergeProjection; - -impl MergeProjection { - #[allow(missing_docs)] - pub fn new() -> Self { - Self {} - } -} - -impl OptimizerRule for MergeProjection { - fn try_optimize( - &self, - plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - match plan { - LogicalPlan::Projection(parent_projection) => { - match parent_projection.input.as_ref() { - LogicalPlan::Projection(child_projection) => { - let new_plan = - merge_projection(parent_projection, child_projection)?; - Ok(Some( - self.try_optimize(&new_plan, _config)?.unwrap_or(new_plan), - )) - } - _ => Ok(None), - } - } - _ => Ok(None), - } - } - - fn name(&self) -> &str { - "merge_projection" - } - - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } -} - -pub(super) fn merge_projection( - parent_projection: &Projection, - child_projection: &Projection, -) -> Result { - let replace_map = collect_projection_expr(child_projection); - let new_exprs = parent_projection - .expr - .iter() - .map(|expr| replace_cols_by_name(expr.clone(), &replace_map)) - .enumerate() - .map(|(i, e)| match e { - Ok(e) => { - let parent_expr = parent_projection.schema.fields()[i].qualified_name(); - e.alias_if_changed(parent_expr) - } - Err(e) => Err(e), - }) - .collect::>>()?; - let new_plan = LogicalPlan::Projection(Projection::try_new_with_schema( - new_exprs, - child_projection.input.clone(), - parent_projection.schema.clone(), - )?); - Ok(new_plan) -} - -pub fn collect_projection_expr(projection: &Projection) -> HashMap { - projection - .schema - .fields() - .iter() - .enumerate() - .flat_map(|(i, field)| { - // strip alias - let expr = projection.expr[i].clone().unalias(); - // Convert both qualified and unqualified fields - [ - (field.name().clone(), expr.clone()), - (field.qualified_name(), expr), - ] - }) - .collect::>() -} - -#[cfg(test)] -mod tests { - use crate::merge_projection::MergeProjection; - use datafusion_common::Result; - use datafusion_expr::{ - binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, LogicalPlan, - Operator, - }; - use std::sync::Arc; - - use crate::test::*; - - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq(Arc::new(MergeProjection::new()), plan, expected) - } - - #[test] - fn merge_two_projection() -> Result<()> { - let table_scan = test_table_scan()?; - let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![col("a")])? - .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])? - .build()?; - - let expected = "Projection: Int32(1) + test.a\ - \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) - } - - #[test] - fn merge_three_projection() -> Result<()> { - let table_scan = test_table_scan()?; - let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![col("a"), col("b")])? - .project(vec![col("a")])? - .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])? - .build()?; - - let expected = "Projection: Int32(1) + test.a\ - \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) - } - - #[test] - fn merge_alias() -> Result<()> { - let table_scan = test_table_scan()?; - let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![col("a")])? - .project(vec![col("a").alias("alias")])? - .build()?; - - let expected = "Projection: test.a AS alias\ - \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) - } -} diff --git a/datafusion/optimizer/src/optimize_projections.rs b/datafusion/optimizer/src/optimize_projections.rs new file mode 100644 index 000000000000..1d4eda0bd23e --- /dev/null +++ b/datafusion/optimizer/src/optimize_projections.rs @@ -0,0 +1,1254 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Optimizer rule to prune unnecessary columns from intermediate schemas +//! inside the [`LogicalPlan`]. This rule: +//! - Removes unnecessary columns that do not appear at the output and/or are +//! not used during any computation step. +//! - Adds projections to decrease table column size before operators that +//! benefit from a smaller memory footprint at its input. +//! - Removes unnecessary [`LogicalPlan::Projection`]s from the [`LogicalPlan`]. + +use std::collections::HashSet; +use std::sync::Arc; + +use crate::optimizer::ApplyOrder; +use crate::{OptimizerConfig, OptimizerRule}; + +use arrow::datatypes::SchemaRef; +use datafusion_common::{ + get_required_group_by_exprs_indices, Column, DFSchema, DFSchemaRef, JoinType, Result, +}; +use datafusion_expr::expr::{Alias, ScalarFunction, ScalarFunctionDefinition}; +use datafusion_expr::{ + logical_plan::LogicalPlan, projection_schema, Aggregate, BinaryExpr, Cast, Distinct, + Expr, GroupingSet, Projection, TableScan, Window, +}; + +use hashbrown::HashMap; +use itertools::{izip, Itertools}; + +/// A rule for optimizing logical plans by removing unused columns/fields. +/// +/// `OptimizeProjections` is an optimizer rule that identifies and eliminates +/// columns from a logical plan that are not used by downstream operations. +/// This can improve query performance and reduce unnecessary data processing. +/// +/// The rule analyzes the input logical plan, determines the necessary column +/// indices, and then removes any unnecessary columns. It also removes any +/// unnecessary projections from the plan tree. +#[derive(Default)] +pub struct OptimizeProjections {} + +impl OptimizeProjections { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for OptimizeProjections { + fn try_optimize( + &self, + plan: &LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result> { + // All output fields are necessary: + let indices = (0..plan.schema().fields().len()).collect::>(); + optimize_projections(plan, config, &indices) + } + + fn name(&self) -> &str { + "optimize_projections" + } + + fn apply_order(&self) -> Option { + None + } +} + +/// Removes unnecessary columns (e.g. columns that do not appear in the output +/// schema and/or are not used during any computation step such as expression +/// evaluation) from the logical plan and its inputs. +/// +/// # Parameters +/// +/// - `plan`: A reference to the input `LogicalPlan` to optimize. +/// - `config`: A reference to the optimizer configuration. +/// - `indices`: A slice of column indices that represent the necessary column +/// indices for downstream operations. +/// +/// # Returns +/// +/// A `Result` object with the following semantics: +/// +/// - `Ok(Some(LogicalPlan))`: An optimized `LogicalPlan` without unnecessary +/// columns. +/// - `Ok(None)`: Signal that the given logical plan did not require any change. +/// - `Err(error)`: An error occured during the optimization process. +fn optimize_projections( + plan: &LogicalPlan, + config: &dyn OptimizerConfig, + indices: &[usize], +) -> Result> { + // `child_required_indices` stores + // - indices of the columns required for each child + // - a flag indicating whether putting a projection above children is beneficial for the parent. + // As an example LogicalPlan::Filter benefits from small tables. Hence for filter child this flag would be `true`. + let child_required_indices: Vec<(Vec, bool)> = match plan { + LogicalPlan::Sort(_) + | LogicalPlan::Filter(_) + | LogicalPlan::Repartition(_) + | LogicalPlan::Unnest(_) + | LogicalPlan::Union(_) + | LogicalPlan::SubqueryAlias(_) + | LogicalPlan::Distinct(Distinct::On(_)) => { + // Pass index requirements from the parent as well as column indices + // that appear in this plan's expressions to its child. All these + // operators benefit from "small" inputs, so the projection_beneficial + // flag is `true`. + let exprs = plan.expressions(); + plan.inputs() + .into_iter() + .map(|input| { + get_all_required_indices(indices, input, exprs.iter()) + .map(|idxs| (idxs, true)) + }) + .collect::>()? + } + LogicalPlan::Limit(_) | LogicalPlan::Prepare(_) => { + // Pass index requirements from the parent as well as column indices + // that appear in this plan's expressions to its child. These operators + // do not benefit from "small" inputs, so the projection_beneficial + // flag is `false`. + let exprs = plan.expressions(); + plan.inputs() + .into_iter() + .map(|input| { + get_all_required_indices(indices, input, exprs.iter()) + .map(|idxs| (idxs, false)) + }) + .collect::>()? + } + LogicalPlan::Copy(_) + | LogicalPlan::Ddl(_) + | LogicalPlan::Dml(_) + | LogicalPlan::Explain(_) + | LogicalPlan::Analyze(_) + | LogicalPlan::Subquery(_) + | LogicalPlan::Distinct(Distinct::All(_)) => { + // These plans require all their fields, and their children should + // be treated as final plans -- otherwise, we may have schema a + // mismatch. + // TODO: For some subquery variants (e.g. a subquery arising from an + // EXISTS expression), we may not need to require all indices. + plan.inputs() + .iter() + .map(|input| ((0..input.schema().fields().len()).collect_vec(), false)) + .collect::>() + } + LogicalPlan::EmptyRelation(_) + | LogicalPlan::Statement(_) + | LogicalPlan::Values(_) + | LogicalPlan::Extension(_) + | LogicalPlan::DescribeTable(_) => { + // These operators have no inputs, so stop the optimization process. + // TODO: Add support for `LogicalPlan::Extension`. + return Ok(None); + } + LogicalPlan::Projection(proj) => { + return if let Some(proj) = merge_consecutive_projections(proj)? { + Ok(Some( + rewrite_projection_given_requirements(&proj, config, indices)? + // Even if we cannot optimize the projection, merge if possible: + .unwrap_or_else(|| LogicalPlan::Projection(proj)), + )) + } else { + rewrite_projection_given_requirements(proj, config, indices) + }; + } + LogicalPlan::Aggregate(aggregate) => { + // Split parent requirements to GROUP BY and aggregate sections: + let n_group_exprs = aggregate.group_expr_len()?; + let (group_by_reqs, mut aggregate_reqs): (Vec, Vec) = + indices.iter().partition(|&&idx| idx < n_group_exprs); + // Offset aggregate indices so that they point to valid indices at + // `aggregate.aggr_expr`: + for idx in aggregate_reqs.iter_mut() { + *idx -= n_group_exprs; + } + + // Get absolutely necessary GROUP BY fields: + let group_by_expr_existing = aggregate + .group_expr + .iter() + .map(|group_by_expr| group_by_expr.display_name()) + .collect::>>()?; + let new_group_bys = if let Some(simplest_groupby_indices) = + get_required_group_by_exprs_indices( + aggregate.input.schema(), + &group_by_expr_existing, + ) { + // Some of the fields in the GROUP BY may be required by the + // parent even if these fields are unnecessary in terms of + // functional dependency. + let required_indices = + merge_slices(&simplest_groupby_indices, &group_by_reqs); + get_at_indices(&aggregate.group_expr, &required_indices) + } else { + aggregate.group_expr.clone() + }; + + // Only use the absolutely necessary aggregate expressions required + // by the parent: + let mut new_aggr_expr = get_at_indices(&aggregate.aggr_expr, &aggregate_reqs); + let all_exprs_iter = new_group_bys.iter().chain(new_aggr_expr.iter()); + let schema = aggregate.input.schema(); + let necessary_indices = indices_referred_by_exprs(schema, all_exprs_iter)?; + + let aggregate_input = if let Some(input) = + optimize_projections(&aggregate.input, config, &necessary_indices)? + { + input + } else { + aggregate.input.as_ref().clone() + }; + + // Simplify the input of the aggregation by adding a projection so + // that its input only contains absolutely necessary columns for + // the aggregate expressions. Note that necessary_indices refer to + // fields in `aggregate.input.schema()`. + let necessary_exprs = get_required_exprs(schema, &necessary_indices); + let (aggregate_input, _) = + add_projection_on_top_if_helpful(aggregate_input, necessary_exprs)?; + + // Aggregations always need at least one aggregate expression. + // With a nested count, we don't require any column as input, but + // still need to create a correct aggregate, which may be optimized + // out later. As an example, consider the following query: + // + // SELECT COUNT(*) FROM (SELECT COUNT(*) FROM [...]) + // + // which always returns 1. + if new_aggr_expr.is_empty() + && new_group_bys.is_empty() + && !aggregate.aggr_expr.is_empty() + { + new_aggr_expr = vec![aggregate.aggr_expr[0].clone()]; + } + + // Create a new aggregate plan with the updated input and only the + // absolutely necessary fields: + return Aggregate::try_new( + Arc::new(aggregate_input), + new_group_bys, + new_aggr_expr, + ) + .map(|aggregate| Some(LogicalPlan::Aggregate(aggregate))); + } + LogicalPlan::Window(window) => { + // Split parent requirements to child and window expression sections: + let n_input_fields = window.input.schema().fields().len(); + let (child_reqs, mut window_reqs): (Vec, Vec) = + indices.iter().partition(|&&idx| idx < n_input_fields); + // Offset window expression indices so that they point to valid + // indices at `window.window_expr`: + for idx in window_reqs.iter_mut() { + *idx -= n_input_fields; + } + + // Only use window expressions that are absolutely necessary according + // to parent requirements: + let new_window_expr = get_at_indices(&window.window_expr, &window_reqs); + + // Get all the required column indices at the input, either by the + // parent or window expression requirements. + let required_indices = get_all_required_indices( + &child_reqs, + &window.input, + new_window_expr.iter(), + )?; + let window_child = if let Some(new_window_child) = + optimize_projections(&window.input, config, &required_indices)? + { + new_window_child + } else { + window.input.as_ref().clone() + }; + + return if new_window_expr.is_empty() { + // When no window expression is necessary, use the input directly: + Ok(Some(window_child)) + } else { + // Calculate required expressions at the input of the window. + // Please note that we use `old_child`, because `required_indices` + // refers to `old_child`. + let required_exprs = + get_required_exprs(window.input.schema(), &required_indices); + let (window_child, _) = + add_projection_on_top_if_helpful(window_child, required_exprs)?; + Window::try_new(new_window_expr, Arc::new(window_child)) + .map(|window| Some(LogicalPlan::Window(window))) + }; + } + LogicalPlan::Join(join) => { + let left_len = join.left.schema().fields().len(); + let (left_req_indices, right_req_indices) = + split_join_requirements(left_len, indices, &join.join_type); + let exprs = plan.expressions(); + let left_indices = + get_all_required_indices(&left_req_indices, &join.left, exprs.iter())?; + let right_indices = + get_all_required_indices(&right_req_indices, &join.right, exprs.iter())?; + // Joins benefit from "small" input tables (lower memory usage). + // Therefore, each child benefits from projection: + vec![(left_indices, true), (right_indices, true)] + } + LogicalPlan::CrossJoin(cross_join) => { + let left_len = cross_join.left.schema().fields().len(); + let (left_child_indices, right_child_indices) = + split_join_requirements(left_len, indices, &JoinType::Inner); + // Joins benefit from "small" input tables (lower memory usage). + // Therefore, each child benefits from projection: + vec![(left_child_indices, true), (right_child_indices, true)] + } + LogicalPlan::TableScan(table_scan) => { + let schema = table_scan.source.schema(); + // Get indices referred to in the original (schema with all fields) + // given projected indices. + let projection = with_indices(&table_scan.projection, schema, |map| { + indices.iter().map(|&idx| map[idx]).collect() + }); + + return TableScan::try_new( + table_scan.table_name.clone(), + table_scan.source.clone(), + Some(projection), + table_scan.filters.clone(), + table_scan.fetch, + ) + .map(|table| Some(LogicalPlan::TableScan(table))); + } + }; + + let new_inputs = izip!(child_required_indices, plan.inputs().into_iter()) + .map(|((required_indices, projection_beneficial), child)| { + let (input, is_changed) = if let Some(new_input) = + optimize_projections(child, config, &required_indices)? + { + (new_input, true) + } else { + (child.clone(), false) + }; + let project_exprs = get_required_exprs(child.schema(), &required_indices); + let (input, proj_added) = if projection_beneficial { + add_projection_on_top_if_helpful(input, project_exprs)? + } else { + (input, false) + }; + Ok((is_changed || proj_added).then_some(input)) + }) + .collect::>>()?; + if new_inputs.iter().all(|child| child.is_none()) { + // All children are the same in this case, no need to change the plan: + Ok(None) + } else { + // At least one of the children is changed: + let new_inputs = izip!(new_inputs, plan.inputs()) + // If new_input is `None`, this means child is not changed, so use + // `old_child` during construction: + .map(|(new_input, old_child)| new_input.unwrap_or_else(|| old_child.clone())) + .collect::>(); + plan.with_new_exprs(plan.expressions(), &new_inputs) + .map(Some) + } +} + +/// This function applies the given function `f` to the projection indices +/// `proj_indices` if they exist. Otherwise, applies `f` to a default set +/// of indices according to `schema`. +fn with_indices( + proj_indices: &Option>, + schema: SchemaRef, + mut f: F, +) -> Vec +where + F: FnMut(&[usize]) -> Vec, +{ + match proj_indices { + Some(indices) => f(indices.as_slice()), + None => { + let range: Vec = (0..schema.fields.len()).collect(); + f(range.as_slice()) + } + } +} + +/// Merges consecutive projections. +/// +/// Given a projection `proj`, this function attempts to merge it with a previous +/// projection if it exists and if merging is beneficial. Merging is considered +/// beneficial when expressions in the current projection are non-trivial and +/// appear more than once in its input fields. This can act as a caching mechanism +/// for non-trivial computations. +/// +/// # Parameters +/// +/// * `proj` - A reference to the `Projection` to be merged. +/// +/// # Returns +/// +/// A `Result` object with the following semantics: +/// +/// - `Ok(Some(Projection))`: Merge was beneficial and successful. Contains the +/// merged projection. +/// - `Ok(None)`: Signals that merge is not beneficial (and has not taken place). +/// - `Err(error)`: An error occured during the function call. +fn merge_consecutive_projections(proj: &Projection) -> Result> { + let LogicalPlan::Projection(prev_projection) = proj.input.as_ref() else { + return Ok(None); + }; + + // Count usages (referrals) of each projection expression in its input fields: + let mut column_referral_map = HashMap::::new(); + for columns in proj.expr.iter().flat_map(|expr| expr.to_columns()) { + for col in columns.into_iter() { + *column_referral_map.entry(col.clone()).or_default() += 1; + } + } + + // If an expression is non-trivial and appears more than once, consecutive + // projections will benefit from a compute-once approach. For details, see: + // https://github.com/apache/arrow-datafusion/issues/8296 + if column_referral_map.into_iter().any(|(col, usage)| { + usage > 1 + && !is_expr_trivial( + &prev_projection.expr + [prev_projection.schema.index_of_column(&col).unwrap()], + ) + }) { + return Ok(None); + } + + // If all the expression of the top projection can be rewritten, do so and + // create a new projection: + let new_exprs = proj + .expr + .iter() + .map(|expr| rewrite_expr(expr, prev_projection)) + .collect::>>>()?; + if let Some(new_exprs) = new_exprs { + let new_exprs = new_exprs + .into_iter() + .zip(proj.expr.iter()) + .map(|(new_expr, old_expr)| { + new_expr.alias_if_changed(old_expr.name_for_alias()?) + }) + .collect::>>()?; + Projection::try_new(new_exprs, prev_projection.input.clone()).map(Some) + } else { + Ok(None) + } +} + +/// Trim the given expression by removing any unnecessary layers of aliasing. +/// If the expression is an alias, the function returns the underlying expression. +/// Otherwise, it returns the given expression as is. +/// +/// Without trimming, we can end up with unnecessary indirections inside expressions +/// during projection merges. +/// +/// Consider: +/// +/// ```text +/// Projection(a1 + b1 as sum1) +/// --Projection(a as a1, b as b1) +/// ----Source(a, b) +/// ``` +/// +/// After merge, we want to produce: +/// +/// ```text +/// Projection(a + b as sum1) +/// --Source(a, b) +/// ``` +/// +/// Without trimming, we would end up with: +/// +/// ```text +/// Projection((a as a1 + b as b1) as sum1) +/// --Source(a, b) +/// ``` +fn trim_expr(expr: Expr) -> Expr { + match expr { + Expr::Alias(alias) => trim_expr(*alias.expr), + _ => expr, + } +} + +// Check whether `expr` is trivial; i.e. it doesn't imply any computation. +fn is_expr_trivial(expr: &Expr) -> bool { + matches!(expr, Expr::Column(_) | Expr::Literal(_)) +} + +// Exit early when there is no rewrite to do. +macro_rules! rewrite_expr_with_check { + ($expr:expr, $input:expr) => { + if let Some(value) = rewrite_expr($expr, $input)? { + value + } else { + return Ok(None); + } + }; +} + +/// Rewrites a projection expression using the projection before it (i.e. its input) +/// This is a subroutine to the `merge_consecutive_projections` function. +/// +/// # Parameters +/// +/// * `expr` - A reference to the expression to rewrite. +/// * `input` - A reference to the input of the projection expression (itself +/// a projection). +/// +/// # Returns +/// +/// A `Result` object with the following semantics: +/// +/// - `Ok(Some(Expr))`: Rewrite was successful. Contains the rewritten result. +/// - `Ok(None)`: Signals that `expr` can not be rewritten. +/// - `Err(error)`: An error occured during the function call. +fn rewrite_expr(expr: &Expr, input: &Projection) -> Result> { + let result = match expr { + Expr::Column(col) => { + // Find index of column: + let idx = input.schema.index_of_column(col)?; + input.expr[idx].clone() + } + Expr::BinaryExpr(binary) => Expr::BinaryExpr(BinaryExpr::new( + Box::new(trim_expr(rewrite_expr_with_check!(&binary.left, input))), + binary.op, + Box::new(trim_expr(rewrite_expr_with_check!(&binary.right, input))), + )), + Expr::Alias(alias) => Expr::Alias(Alias::new( + trim_expr(rewrite_expr_with_check!(&alias.expr, input)), + alias.relation.clone(), + alias.name.clone(), + )), + Expr::Literal(_) => expr.clone(), + Expr::Cast(cast) => { + let new_expr = rewrite_expr_with_check!(&cast.expr, input); + Expr::Cast(Cast::new(Box::new(new_expr), cast.data_type.clone())) + } + Expr::ScalarFunction(scalar_fn) => { + // TODO: Support UDFs. + let ScalarFunctionDefinition::BuiltIn(fun) = scalar_fn.func_def else { + return Ok(None); + }; + return Ok(scalar_fn + .args + .iter() + .map(|expr| rewrite_expr(expr, input)) + .collect::>>()? + .map(|new_args| { + Expr::ScalarFunction(ScalarFunction::new(fun, new_args)) + })); + } + // Unsupported type for consecutive projection merge analysis. + _ => return Ok(None), + }; + Ok(Some(result)) +} + +/// Retrieves a set of outer-referenced columns by the given expression, `expr`. +/// Note that the `Expr::to_columns()` function doesn't return these columns. +/// +/// # Parameters +/// +/// * `expr` - The expression to analyze for outer-referenced columns. +/// +/// # Returns +/// +/// returns a `HashSet` containing all outer-referenced columns. +fn outer_columns(expr: &Expr) -> HashSet { + let mut columns = HashSet::new(); + outer_columns_helper(expr, &mut columns); + columns +} + +/// A recursive subroutine that accumulates outer-referenced columns by the +/// given expression, `expr`. +/// +/// # Parameters +/// +/// * `expr` - The expression to analyze for outer-referenced columns. +/// * `columns` - A mutable reference to a `HashSet` where detected +/// columns are collected. +fn outer_columns_helper(expr: &Expr, columns: &mut HashSet) { + match expr { + Expr::OuterReferenceColumn(_, col) => { + columns.insert(col.clone()); + } + Expr::BinaryExpr(binary_expr) => { + outer_columns_helper(&binary_expr.left, columns); + outer_columns_helper(&binary_expr.right, columns); + } + Expr::ScalarSubquery(subquery) => { + let exprs = subquery.outer_ref_columns.iter(); + outer_columns_helper_multi(exprs, columns); + } + Expr::Exists(exists) => { + let exprs = exists.subquery.outer_ref_columns.iter(); + outer_columns_helper_multi(exprs, columns); + } + Expr::Alias(alias) => outer_columns_helper(&alias.expr, columns), + Expr::InSubquery(insubquery) => { + let exprs = insubquery.subquery.outer_ref_columns.iter(); + outer_columns_helper_multi(exprs, columns); + } + Expr::Cast(cast) => outer_columns_helper(&cast.expr, columns), + Expr::Sort(sort) => outer_columns_helper(&sort.expr, columns), + Expr::AggregateFunction(aggregate_fn) => { + outer_columns_helper_multi(aggregate_fn.args.iter(), columns); + if let Some(filter) = aggregate_fn.filter.as_ref() { + outer_columns_helper(filter, columns); + } + if let Some(obs) = aggregate_fn.order_by.as_ref() { + outer_columns_helper_multi(obs.iter(), columns); + } + } + Expr::WindowFunction(window_fn) => { + outer_columns_helper_multi(window_fn.args.iter(), columns); + outer_columns_helper_multi(window_fn.order_by.iter(), columns); + outer_columns_helper_multi(window_fn.partition_by.iter(), columns); + } + Expr::GroupingSet(groupingset) => match groupingset { + GroupingSet::GroupingSets(multi_exprs) => { + multi_exprs + .iter() + .for_each(|e| outer_columns_helper_multi(e.iter(), columns)); + } + GroupingSet::Cube(exprs) | GroupingSet::Rollup(exprs) => { + outer_columns_helper_multi(exprs.iter(), columns); + } + }, + Expr::ScalarFunction(scalar_fn) => { + outer_columns_helper_multi(scalar_fn.args.iter(), columns); + } + Expr::Like(like) => { + outer_columns_helper(&like.expr, columns); + outer_columns_helper(&like.pattern, columns); + } + Expr::InList(in_list) => { + outer_columns_helper(&in_list.expr, columns); + outer_columns_helper_multi(in_list.list.iter(), columns); + } + Expr::Case(case) => { + let when_then_exprs = case + .when_then_expr + .iter() + .flat_map(|(first, second)| [first.as_ref(), second.as_ref()]); + outer_columns_helper_multi(when_then_exprs, columns); + if let Some(expr) = case.expr.as_ref() { + outer_columns_helper(expr, columns); + } + if let Some(expr) = case.else_expr.as_ref() { + outer_columns_helper(expr, columns); + } + } + Expr::SimilarTo(similar_to) => { + outer_columns_helper(&similar_to.expr, columns); + outer_columns_helper(&similar_to.pattern, columns); + } + Expr::TryCast(try_cast) => outer_columns_helper(&try_cast.expr, columns), + Expr::GetIndexedField(index) => outer_columns_helper(&index.expr, columns), + Expr::Between(between) => { + outer_columns_helper(&between.expr, columns); + outer_columns_helper(&between.low, columns); + outer_columns_helper(&between.high, columns); + } + Expr::Not(expr) + | Expr::IsNotFalse(expr) + | Expr::IsFalse(expr) + | Expr::IsTrue(expr) + | Expr::IsNotTrue(expr) + | Expr::IsUnknown(expr) + | Expr::IsNotUnknown(expr) + | Expr::IsNotNull(expr) + | Expr::IsNull(expr) + | Expr::Negative(expr) => outer_columns_helper(expr, columns), + Expr::Column(_) + | Expr::Literal(_) + | Expr::Wildcard { .. } + | Expr::ScalarVariable { .. } + | Expr::Placeholder(_) => (), + } +} + +/// A recursive subroutine that accumulates outer-referenced columns by the +/// given expressions (`exprs`). +/// +/// # Parameters +/// +/// * `exprs` - The expressions to analyze for outer-referenced columns. +/// * `columns` - A mutable reference to a `HashSet` where detected +/// columns are collected. +fn outer_columns_helper_multi<'a>( + exprs: impl Iterator, + columns: &mut HashSet, +) { + exprs.for_each(|e| outer_columns_helper(e, columns)); +} + +/// Generates the required expressions (columns) that reside at `indices` of +/// the given `input_schema`. +/// +/// # Arguments +/// +/// * `input_schema` - A reference to the input schema. +/// * `indices` - A slice of `usize` indices specifying required columns. +/// +/// # Returns +/// +/// A vector of `Expr::Column` expressions residing at `indices` of the `input_schema`. +fn get_required_exprs(input_schema: &Arc, indices: &[usize]) -> Vec { + let fields = input_schema.fields(); + indices + .iter() + .map(|&idx| Expr::Column(fields[idx].qualified_column())) + .collect() +} + +/// Get indices of the fields referred to by any expression in `exprs` within +/// the given schema (`input_schema`). +/// +/// # Arguments +/// +/// * `input_schema`: The input schema to analyze for index requirements. +/// * `exprs`: An iterator of expressions for which we want to find necessary +/// field indices. +/// +/// # Returns +/// +/// A [`Result`] object containing the indices of all required fields in +/// `input_schema` to calculate all `exprs` successfully. +fn indices_referred_by_exprs<'a>( + input_schema: &DFSchemaRef, + exprs: impl Iterator, +) -> Result> { + let indices = exprs + .map(|expr| indices_referred_by_expr(input_schema, expr)) + .collect::>>()?; + Ok(indices + .into_iter() + .flatten() + // Make sure no duplicate entries exist and indices are ordered: + .sorted() + .dedup() + .collect()) +} + +/// Get indices of the fields referred to by the given expression `expr` within +/// the given schema (`input_schema`). +/// +/// # Parameters +/// +/// * `input_schema`: The input schema to analyze for index requirements. +/// * `expr`: An expression for which we want to find necessary field indices. +/// +/// # Returns +/// +/// A [`Result`] object containing the indices of all required fields in +/// `input_schema` to calculate `expr` successfully. +fn indices_referred_by_expr( + input_schema: &DFSchemaRef, + expr: &Expr, +) -> Result> { + let mut cols = expr.to_columns()?; + // Get outer-referenced columns: + cols.extend(outer_columns(expr)); + Ok(cols + .iter() + .flat_map(|col| input_schema.index_of_column(col)) + .collect()) +} + +/// Gets all required indices for the input; i.e. those required by the parent +/// and those referred to by `exprs`. +/// +/// # Parameters +/// +/// * `parent_required_indices` - A slice of indices required by the parent plan. +/// * `input` - The input logical plan to analyze for index requirements. +/// * `exprs` - An iterator of expressions used to determine required indices. +/// +/// # Returns +/// +/// A `Result` containing a vector of `usize` indices containing all the required +/// indices. +fn get_all_required_indices<'a>( + parent_required_indices: &[usize], + input: &LogicalPlan, + exprs: impl Iterator, +) -> Result> { + indices_referred_by_exprs(input.schema(), exprs) + .map(|indices| merge_slices(parent_required_indices, &indices)) +} + +/// Retrieves the expressions at specified indices within the given slice. Ignores +/// any invalid indices. +/// +/// # Parameters +/// +/// * `exprs` - A slice of expressions to index into. +/// * `indices` - A slice of indices specifying the positions of expressions sought. +/// +/// # Returns +/// +/// A vector of expressions corresponding to specified indices. +fn get_at_indices(exprs: &[Expr], indices: &[usize]) -> Vec { + indices + .iter() + // Indices may point to further places than `exprs` len. + .filter_map(|&idx| exprs.get(idx).cloned()) + .collect() +} + +/// Merges two slices into a single vector with sorted (ascending) and +/// deduplicated elements. For example, merging `[3, 2, 4]` and `[3, 6, 1]` +/// will produce `[1, 2, 3, 6]`. +fn merge_slices(left: &[T], right: &[T]) -> Vec { + // Make sure to sort before deduping, which removes the duplicates: + left.iter() + .cloned() + .chain(right.iter().cloned()) + .sorted() + .dedup() + .collect() +} + +/// Splits requirement indices for a join into left and right children based on +/// the join type. +/// +/// This function takes the length of the left child, a slice of requirement +/// indices, and the type of join (e.g. `INNER`, `LEFT`, `RIGHT`) as arguments. +/// Depending on the join type, it divides the requirement indices into those +/// that apply to the left child and those that apply to the right child. +/// +/// - For `INNER`, `LEFT`, `RIGHT` and `FULL` joins, the requirements are split +/// between left and right children. The right child indices are adjusted to +/// point to valid positions within the right child by subtracting the length +/// of the left child. +/// +/// - For `LEFT ANTI`, `LEFT SEMI`, `RIGHT SEMI` and `RIGHT ANTI` joins, all +/// requirements are re-routed to either the left child or the right child +/// directly, depending on the join type. +/// +/// # Parameters +/// +/// * `left_len` - The length of the left child. +/// * `indices` - A slice of requirement indices. +/// * `join_type` - The type of join (e.g. `INNER`, `LEFT`, `RIGHT`). +/// +/// # Returns +/// +/// A tuple containing two vectors of `usize` indices: The first vector represents +/// the requirements for the left child, and the second vector represents the +/// requirements for the right child. The indices are appropriately split and +/// adjusted based on the join type. +fn split_join_requirements( + left_len: usize, + indices: &[usize], + join_type: &JoinType, +) -> (Vec, Vec) { + match join_type { + // In these cases requirements are split between left/right children: + JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { + let (left_reqs, mut right_reqs): (Vec, Vec) = + indices.iter().partition(|&&idx| idx < left_len); + // Decrease right side indices by `left_len` so that they point to valid + // positions within the right child: + for idx in right_reqs.iter_mut() { + *idx -= left_len; + } + (left_reqs, right_reqs) + } + // All requirements can be re-routed to left child directly. + JoinType::LeftAnti | JoinType::LeftSemi => (indices.to_vec(), vec![]), + // All requirements can be re-routed to right side directly. + // No need to change index, join schema is right child schema. + JoinType::RightSemi | JoinType::RightAnti => (vec![], indices.to_vec()), + } +} + +/// Adds a projection on top of a logical plan if doing so reduces the number +/// of columns for the parent operator. +/// +/// This function takes a `LogicalPlan` and a list of projection expressions. +/// If the projection is beneficial (it reduces the number of columns in the +/// plan) a new `LogicalPlan` with the projection is created and returned, along +/// with a `true` flag. If the projection doesn't reduce the number of columns, +/// the original plan is returned with a `false` flag. +/// +/// # Parameters +/// +/// * `plan` - The input `LogicalPlan` to potentially add a projection to. +/// * `project_exprs` - A list of expressions for the projection. +/// +/// # Returns +/// +/// A `Result` containing a tuple with two values: The resulting `LogicalPlan` +/// (with or without the added projection) and a `bool` flag indicating if a +/// projection was added (`true`) or not (`false`). +fn add_projection_on_top_if_helpful( + plan: LogicalPlan, + project_exprs: Vec, +) -> Result<(LogicalPlan, bool)> { + // Make sure projection decreases the number of columns, otherwise it is unnecessary. + if project_exprs.len() >= plan.schema().fields().len() { + Ok((plan, false)) + } else { + Projection::try_new(project_exprs, Arc::new(plan)) + .map(|proj| (LogicalPlan::Projection(proj), true)) + } +} + +/// Rewrite the given projection according to the fields required by its +/// ancestors. +/// +/// # Parameters +/// +/// * `proj` - A reference to the original projection to rewrite. +/// * `config` - A reference to the optimizer configuration. +/// * `indices` - A slice of indices representing the columns required by the +/// ancestors of the given projection. +/// +/// # Returns +/// +/// A `Result` object with the following semantics: +/// +/// - `Ok(Some(LogicalPlan))`: Contains the rewritten projection +/// - `Ok(None)`: No rewrite necessary. +/// - `Err(error)`: An error occured during the function call. +fn rewrite_projection_given_requirements( + proj: &Projection, + config: &dyn OptimizerConfig, + indices: &[usize], +) -> Result> { + let exprs_used = get_at_indices(&proj.expr, indices); + let required_indices = + indices_referred_by_exprs(proj.input.schema(), exprs_used.iter())?; + return if let Some(input) = + optimize_projections(&proj.input, config, &required_indices)? + { + if &projection_schema(&input, &exprs_used)? == input.schema() { + Ok(Some(input)) + } else { + Projection::try_new(exprs_used, Arc::new(input)) + .map(|proj| Some(LogicalPlan::Projection(proj))) + } + } else if exprs_used.len() < proj.expr.len() { + // Projection expression used is different than the existing projection. + // In this case, even if the child doesn't change, we should update the + // projection to use fewer columns: + if &projection_schema(&proj.input, &exprs_used)? == proj.input.schema() { + Ok(Some(proj.input.as_ref().clone())) + } else { + Projection::try_new(exprs_used, proj.input.clone()) + .map(|proj| Some(LogicalPlan::Projection(proj))) + } + } else { + // Projection doesn't change. + Ok(None) + }; +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::optimize_projections::OptimizeProjections; + use crate::test::{assert_optimized_plan_eq, test_table_scan}; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::{Result, TableReference}; + use datafusion_expr::{ + binary_expr, col, count, lit, logical_plan::builder::LogicalPlanBuilder, not, + table_scan, try_cast, Expr, Like, LogicalPlan, Operator, + }; + + fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + assert_optimized_plan_eq(Arc::new(OptimizeProjections::new()), plan, expected) + } + + #[test] + fn merge_two_projection() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a")])? + .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])? + .build()?; + + let expected = "Projection: Int32(1) + test.a\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn merge_three_projection() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), col("b")])? + .project(vec![col("a")])? + .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])? + .build()?; + + let expected = "Projection: Int32(1) + test.a\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn merge_alias() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a")])? + .project(vec![col("a").alias("alias")])? + .build()?; + + let expected = "Projection: test.a AS alias\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn merge_nested_alias() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").alias("alias1").alias("alias2")])? + .project(vec![col("alias2").alias("alias")])? + .build()?; + + let expected = "Projection: test.a AS alias\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_nested_count() -> Result<()> { + let schema = Schema::new(vec![Field::new("foo", DataType::Int32, false)]); + + let groups: Vec = vec![]; + + let plan = table_scan(TableReference::none(), &schema, None) + .unwrap() + .aggregate(groups.clone(), vec![count(lit(1))]) + .unwrap() + .aggregate(groups, vec![count(lit(1))]) + .unwrap() + .build() + .unwrap(); + + let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(Int32(1))]]\ + \n Projection: \ + \n Aggregate: groupBy=[[]], aggr=[[COUNT(Int32(1))]]\ + \n TableScan: ?table? projection=[]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_struct_field_push_down() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new_struct( + "s", + vec![ + Field::new("x", DataType::Int64, false), + Field::new("y", DataType::Int64, false), + ], + false, + ), + ])); + + let table_scan = table_scan(TableReference::none(), &schema, None)?.build()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("s").field("x")])? + .build()?; + let expected = "Projection: (?table?.s)[x]\ + \n TableScan: ?table? projection=[s]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_neg_push_down() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![-col("a")])? + .build()?; + + let expected = "Projection: (- test.a)\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_is_null() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").is_null()])? + .build()?; + + let expected = "Projection: test.a IS NULL\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_is_not_null() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").is_not_null()])? + .build()?; + + let expected = "Projection: test.a IS NOT NULL\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_is_true() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").is_true()])? + .build()?; + + let expected = "Projection: test.a IS TRUE\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_is_not_true() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").is_not_true()])? + .build()?; + + let expected = "Projection: test.a IS NOT TRUE\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_is_false() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").is_false()])? + .build()?; + + let expected = "Projection: test.a IS FALSE\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_is_not_false() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").is_not_false()])? + .build()?; + + let expected = "Projection: test.a IS NOT FALSE\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_is_unknown() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").is_unknown()])? + .build()?; + + let expected = "Projection: test.a IS UNKNOWN\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_is_not_unknown() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").is_not_unknown()])? + .build()?; + + let expected = "Projection: test.a IS NOT UNKNOWN\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_not() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![not(col("a"))])? + .build()?; + + let expected = "Projection: NOT test.a\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_try_cast() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![try_cast(col("a"), DataType::Float64)])? + .build()?; + + let expected = "Projection: TRY_CAST(test.a AS Float64)\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_similar_to() -> Result<()> { + let table_scan = test_table_scan()?; + let expr = Box::new(col("a")); + let pattern = Box::new(lit("[0-9]")); + let similar_to_expr = + Expr::SimilarTo(Like::new(false, expr, pattern, None, false)); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![similar_to_expr])? + .build()?; + + let expected = "Projection: test.a SIMILAR TO Utf8(\"[0-9]\")\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_between() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").between(lit(1), lit(3))])? + .build()?; + + let expected = "Projection: test.a BETWEEN Int32(1) AND Int32(3)\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } +} diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index d3bdd47c5cb3..2cb59d511ccf 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -17,6 +17,10 @@ //! Query optimizer traits +use std::collections::HashSet; +use std::sync::Arc; +use std::time::Instant; + use crate::common_subexpr_eliminate::CommonSubexprEliminate; use crate::decorrelate_predicate_subquery::DecorrelatePredicateSubquery; use crate::eliminate_cross_join::EliminateCrossJoin; @@ -24,16 +28,16 @@ use crate::eliminate_duplicated_expr::EliminateDuplicatedExpr; use crate::eliminate_filter::EliminateFilter; use crate::eliminate_join::EliminateJoin; use crate::eliminate_limit::EliminateLimit; +use crate::eliminate_nested_union::EliminateNestedUnion; +use crate::eliminate_one_union::EliminateOneUnion; use crate::eliminate_outer_join::EliminateOuterJoin; -use crate::eliminate_project::EliminateProjection; use crate::extract_equijoin_predicate::ExtractEquijoinPredicate; use crate::filter_null_join_keys::FilterNullJoinKeys; -use crate::merge_projection::MergeProjection; +use crate::optimize_projections::OptimizeProjections; use crate::plan_signature::LogicalPlanSignature; use crate::propagate_empty_relation::PropagateEmptyRelation; use crate::push_down_filter::PushDownFilter; use crate::push_down_limit::PushDownLimit; -use crate::push_down_projection::PushDownProjection; use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate; use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; use crate::scalar_subquery_to_join::ScalarSubqueryToJoin; @@ -41,15 +45,14 @@ use crate::simplify_expressions::SimplifyExpressions; use crate::single_distinct_to_groupby::SingleDistinctToGroupBy; use crate::unwrap_cast_in_comparison::UnwrapCastInComparison; use crate::utils::log_plan; -use chrono::{DateTime, Utc}; + use datafusion_common::alias::AliasGenerator; use datafusion_common::config::ConfigOptions; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::logical_plan::LogicalPlan; + +use chrono::{DateTime, Utc}; use log::{debug, warn}; -use std::collections::HashSet; -use std::sync::Arc; -use std::time::Instant; /// `OptimizerRule` transforms one [`LogicalPlan`] into another which /// computes the same results, but in a potentially more efficient @@ -220,6 +223,7 @@ impl Optimizer { /// Create a new optimizer using the recommended list of rules pub fn new() -> Self { let rules: Vec> = vec![ + Arc::new(EliminateNestedUnion::new()), Arc::new(SimplifyExpressions::new()), Arc::new(UnwrapCastInComparison::new()), Arc::new(ReplaceDistinctWithAggregate::new()), @@ -231,7 +235,6 @@ impl Optimizer { // run it again after running the optimizations that potentially converted // subqueries to joins Arc::new(SimplifyExpressions::new()), - Arc::new(MergeProjection::new()), Arc::new(RewriteDisjunctivePredicate::new()), Arc::new(EliminateDuplicatedExpr::new()), Arc::new(EliminateFilter::new()), @@ -239,6 +242,8 @@ impl Optimizer { Arc::new(CommonSubexprEliminate::new()), Arc::new(EliminateLimit::new()), Arc::new(PropagateEmptyRelation::new()), + // Must be after PropagateEmptyRelation + Arc::new(EliminateOneUnion::new()), Arc::new(FilterNullJoinKeys::default()), Arc::new(EliminateOuterJoin::new()), // Filters can't be pushed down past Limits, we should do PushDownFilter after PushDownLimit @@ -250,10 +255,7 @@ impl Optimizer { Arc::new(SimplifyExpressions::new()), Arc::new(UnwrapCastInComparison::new()), Arc::new(CommonSubexprEliminate::new()), - Arc::new(PushDownProjection::new()), - Arc::new(EliminateProjection::new()), - // PushDownProjection can pushdown Projections through Limits, do PushDownLimit again. - Arc::new(PushDownLimit::new()), + Arc::new(OptimizeProjections::new()), ]; Self::with_rules(rules) @@ -380,7 +382,7 @@ impl Optimizer { }) .collect::>(); - Ok(Some(plan.with_new_inputs(&new_inputs)?)) + Ok(Some(plan.with_new_exprs(plan.expressions(), &new_inputs)?)) } /// Use a rule to optimize the whole plan. @@ -422,7 +424,7 @@ impl Optimizer { /// Returns an error if plans have different schemas. /// /// It ignores metadata and nullability. -fn assert_schema_is_the_same( +pub(crate) fn assert_schema_is_the_same( rule_name: &str, prev_plan: &LogicalPlan, new_plan: &LogicalPlan, @@ -433,7 +435,7 @@ fn assert_schema_is_the_same( if !equivalent { let e = DataFusionError::Internal(format!( - "Failed due to generate a different schema, original schema: {:?}, new schema: {:?}", + "Failed due to a difference in schemas, original schema: {:?}, new schema: {:?}", prev_plan.schema(), new_plan.schema() )); @@ -448,17 +450,18 @@ fn assert_schema_is_the_same( #[cfg(test)] mod tests { + use std::sync::{Arc, Mutex}; + + use super::ApplyOrder; use crate::optimizer::Optimizer; use crate::test::test_table_scan; use crate::{OptimizerConfig, OptimizerContext, OptimizerRule}; + use datafusion_common::{ plan_err, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, }; use datafusion_expr::logical_plan::EmptyRelation; use datafusion_expr::{col, lit, LogicalPlan, LogicalPlanBuilder, Projection}; - use std::sync::{Arc, Mutex}; - - use super::ApplyOrder; #[test] fn skip_failing_rule() { @@ -498,7 +501,7 @@ mod tests { let err = opt.optimize(&plan, &config, &observe).unwrap_err(); assert_eq!( "Optimizer rule 'get table_scan rule' failed\ncaused by\nget table_scan rule\ncaused by\n\ - Internal error: Failed due to generate a different schema, \ + Internal error: Failed due to a difference in schemas, \ original schema: DFSchema { fields: [], metadata: {}, functional_dependencies: FunctionalDependencies { deps: [] } }, \ new schema: DFSchema { fields: [\ DFField { qualifier: Some(Bare { table: \"test\" }), field: Field { name: \"a\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, \ diff --git a/datafusion/optimizer/src/propagate_empty_relation.rs b/datafusion/optimizer/src/propagate_empty_relation.rs index 4de7596b329c..040b69fc8bf3 100644 --- a/datafusion/optimizer/src/propagate_empty_relation.rs +++ b/datafusion/optimizer/src/propagate_empty_relation.rs @@ -182,12 +182,11 @@ fn empty_child(plan: &LogicalPlan) -> Result> { #[cfg(test)] mod tests { use crate::eliminate_filter::EliminateFilter; - use crate::optimizer::Optimizer; + use crate::eliminate_nested_union::EliminateNestedUnion; use crate::test::{ - assert_optimized_plan_eq, test_table_scan, test_table_scan_fields, - test_table_scan_with_name, + assert_optimized_plan_eq, assert_optimized_plan_eq_with_rules, test_table_scan, + test_table_scan_fields, test_table_scan_with_name, }; - use crate::OptimizerContext; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{Column, DFField, DFSchema, ScalarValue}; use datafusion_expr::logical_plan::table_scan; @@ -206,21 +205,15 @@ mod tests { plan: &LogicalPlan, expected: &str, ) -> Result<()> { - fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} - let optimizer = Optimizer::with_rules(vec![ - Arc::new(EliminateFilter::new()), - Arc::new(PropagateEmptyRelation::new()), - ]); - let config = &mut OptimizerContext::new() - .with_max_passes(1) - .with_skip_failing_rules(false); - let optimized_plan = optimizer - .optimize(plan, config, observe) - .expect("failed to optimize plan"); - let formatted_plan = format!("{optimized_plan:?}"); - assert_eq!(formatted_plan, expected); - assert_eq!(plan.schema(), optimized_plan.schema()); - Ok(()) + assert_optimized_plan_eq_with_rules( + vec![ + Arc::new(EliminateFilter::new()), + Arc::new(EliminateNestedUnion::new()), + Arc::new(PropagateEmptyRelation::new()), + ], + plan, + expected, + ) } #[test] diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index fe726d5d7783..4eb925ac0629 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -12,49 +12,120 @@ // specific language governing permissions and limitations // under the License. -//! Push Down Filter optimizer rule ensures that filters are applied as early as possible in the plan +//! [`PushDownFilter`] Moves filters so they are applied as early as possible in +//! the plan. + +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; use crate::optimizer::ApplyOrder; -use crate::utils::{conjunction, split_conjunction}; -use crate::{utils, OptimizerConfig, OptimizerRule}; +use crate::{OptimizerConfig, OptimizerRule}; + use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; -use datafusion_common::{internal_err, Column, DFSchema, DataFusionError, Result}; +use datafusion_common::{ + internal_err, plan_datafusion_err, Column, DFSchema, DFSchemaRef, DataFusionError, + JoinConstraint, Result, +}; use datafusion_expr::expr::Alias; +use datafusion_expr::expr_rewriter::replace_col; +use datafusion_expr::logical_plan::{ + CrossJoin, Join, JoinType, LogicalPlan, TableScan, Union, +}; +use datafusion_expr::utils::{conjunction, split_conjunction, split_conjunction_owned}; use datafusion_expr::{ - and, - expr_rewriter::replace_col, - logical_plan::{CrossJoin, Join, JoinType, LogicalPlan, TableScan, Union}, - or, BinaryExpr, Expr, Filter, Operator, TableProviderFilterPushDown, + and, build_join_schema, or, BinaryExpr, Expr, Filter, LogicalPlanBuilder, Operator, + ScalarFunctionDefinition, TableProviderFilterPushDown, Volatility, }; + use itertools::Itertools; -use std::collections::{HashMap, HashSet}; -use std::sync::Arc; -/// Push Down Filter optimizer rule pushes filter clauses down the plan +/// Optimizer rule for pushing (moving) filter expressions down in a plan so +/// they are applied as early as possible. +/// /// # Introduction -/// A filter-commutative operation is an operation whose result of filter(op(data)) = op(filter(data)). -/// An example of a filter-commutative operation is a projection; a counter-example is `limit`. /// -/// The filter-commutative property is column-specific. An aggregate grouped by A on SUM(B) -/// can commute with a filter that depends on A only, but does not commute with a filter that depends -/// on SUM(B). +/// The goal of this rule is to improve query performance by eliminating +/// redundant work. +/// +/// For example, given a plan that sorts all values where `a > 10`: +/// +/// ```text +/// Filter (a > 10) +/// Sort (a, b) +/// ``` +/// +/// A better plan is to filter the data *before* the Sort, which sorts fewer +/// rows and therefore does less work overall: +/// +/// ```text +/// Sort (a, b) +/// Filter (a > 10) <-- Filter is moved before the sort +/// ``` +/// +/// However it is not always possible to push filters down. For example, given a +/// plan that finds the top 3 values and then keeps only those that are greater +/// than 10, if the filter is pushed below the limit it would produce a +/// different result. /// -/// This optimizer commutes filters with filter-commutative operations to push the filters -/// the closest possible to the scans, re-writing the filter expressions by every -/// projection that changes the filter's expression. +/// ```text +/// Filter (a > 10) <-- can not move this Filter before the limit +/// Limit (fetch=3) +/// Sort (a, b) +/// ``` /// -/// Filter: b Gt Int64(10) -/// Projection: a AS b /// -/// is optimized to +/// More formally, a filter-commutative operation is an operation `op` that +/// satisfies `filter(op(data)) = op(filter(data))`. /// -/// Projection: a AS b -/// Filter: a Gt Int64(10) <--- changed from b to a +/// The filter-commutative property is plan and column-specific. A filter on `a` +/// can be pushed through a `Aggregate(group_by = [a], agg=[SUM(b))`. However, a +/// filter on `SUM(b)` can not be pushed through the same aggregate. /// -/// This performs a single pass through the plan. When it passes through a filter, it stores that filter, -/// and when it reaches a node that does not commute with it, it adds the filter to that place. -/// When it passes through a projection, it re-writes the filter's expression taking into account that projection. -/// When multiple filters would have been written, it `AND` their expressions into a single expression. +/// # Handling Conjunctions +/// +/// It is possible to only push down **part** of a filter expression if is +/// connected with `AND`s (more formally if it is a "conjunction"). +/// +/// For example, given the following plan: +/// +/// ```text +/// Filter(a > 10 AND SUM(b) < 5) +/// Aggregate(group_by = [a], agg = [SUM(b)) +/// ``` +/// +/// The `a > 10` is commutative with the `Aggregate` but `SUM(b) < 5` is not. +/// Therefore it is possible to only push part of the expression, resulting in: +/// +/// ```text +/// Filter(SUM(b) < 5) +/// Aggregate(group_by = [a], agg = [SUM(b)) +/// Filter(a > 10) +/// ``` +/// +/// # Handling Column Aliases +/// +/// This optimizer must sometimes handle re-writing filter expressions when they +/// pushed, for example if there is a projection that aliases `a+1` to `"b"`: +/// +/// ```text +/// Filter (b > 10) +/// Projection: [a+1 AS "b"] <-- changes the name of `a+1` to `b` +/// ``` +/// +/// To apply the filter prior to the `Projection`, all references to `b` must be +/// rewritten to `a+1`: +/// +/// ```text +/// Projection: a AS "b" +/// Filter: (a + 1 > 10) <--- changed from b to a + 1 +/// ``` +/// # Implementation Notes +/// +/// This implementation performs a single pass through the plan, "pushing" down +/// filters. When it passes through a filter, it stores that filter, and when it +/// reaches a plan node that does not commute with that filter, it adds the +/// filter to that place. When it passes through a projection, it re-writes the +/// filter's expression taking into account that projection. #[derive(Default)] pub struct PushDownFilter {} @@ -155,7 +226,10 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { | Expr::InSubquery(_) | Expr::ScalarSubquery(_) | Expr::OuterReferenceColumn(_, _) - | Expr::ScalarUDF(..) => { + | Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction { + func_def: ScalarFunctionDefinition::UDF(_), + .. + }) => { is_evaluate = false; Ok(VisitRecursion::Stop) } @@ -183,9 +257,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { Expr::Sort(_) | Expr::AggregateFunction(_) | Expr::WindowFunction(_) - | Expr::AggregateUDF { .. } - | Expr::Wildcard - | Expr::QualifiedWildcard { .. } + | Expr::Wildcard { .. } | Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"), })?; Ok(is_evaluate) @@ -477,9 +549,7 @@ fn push_down_join( parent_predicate: Option<&Expr>, ) -> Result> { let predicates = match parent_predicate { - Some(parent_predicate) => { - utils::split_conjunction_owned(parent_predicate.clone()) - } + Some(parent_predicate) => split_conjunction_owned(parent_predicate.clone()), None => vec![], }; @@ -487,12 +557,21 @@ fn push_down_join( let on_filters = join .filter .as_ref() - .map(|e| utils::split_conjunction_owned(e.clone())) - .unwrap_or_else(Vec::new); + .map(|e| split_conjunction_owned(e.clone())) + .unwrap_or_default(); let mut is_inner_join = false; let infer_predicates = if join.join_type == JoinType::Inner { is_inner_join = true; + // Only allow both side key is column. + let join_col_keys = join + .on + .iter() + .flat_map(|(l, r)| match (l.try_into_col(), r.try_into_col()) { + (Ok(l_col), Ok(r_col)) => Some((l_col, r_col)), + _ => None, + }) + .collect::>(); // TODO refine the logic, introduce EquivalenceProperties to logical plan and infer additional filters to push down // For inner joins, duplicate filters for joined columns so filters can be pushed down // to both sides. Take the following query as an example: @@ -517,16 +596,6 @@ fn push_down_join( Err(e) => return Some(Err(e)), }; - // Only allow both side key is column. - let join_col_keys = join - .on - .iter() - .flat_map(|(l, r)| match (l.try_into_col(), r.try_into_col()) { - (Ok(l_col), Ok(r_col)) => Some((l_col, r_col)), - _ => None, - }) - .collect::>(); - for col in columns.iter() { for (l, r) in join_col_keys.iter() { if col == l { @@ -609,7 +678,7 @@ impl OptimizerRule for PushDownFilter { .map(|e| (*e).clone()) .collect::>(); let new_predicate = conjunction(new_predicates).ok_or_else(|| { - DataFusionError::Plan("at least one expression exists".to_string()) + plan_datafusion_err!("at least one expression exists") })?; let new_filter = LogicalPlan::Filter(Filter::try_new( new_predicate, @@ -622,9 +691,11 @@ impl OptimizerRule for PushDownFilter { | LogicalPlan::Distinct(_) | LogicalPlan::Sort(_) => { // commutable - let new_filter = - plan.with_new_inputs(&[child_plan.inputs()[0].clone()])?; - child_plan.with_new_inputs(&[new_filter])? + let new_filter = plan.with_new_exprs( + plan.expressions(), + &[child_plan.inputs()[0].clone()], + )?; + child_plan.with_new_exprs(child_plan.expressions(), &[new_filter])? } LogicalPlan::SubqueryAlias(subquery_alias) => { let mut replace_map = HashMap::new(); @@ -647,35 +718,68 @@ impl OptimizerRule for PushDownFilter { new_predicate, subquery_alias.input.clone(), )?); - child_plan.with_new_inputs(&[new_filter])? + child_plan.with_new_exprs(child_plan.expressions(), &[new_filter])? } LogicalPlan::Projection(projection) => { - // A projection is filter-commutable, but re-writes all predicate expressions + // A projection is filter-commutable if it do not contain volatile predicates or contain volatile + // predicates that are not used in the filter. However, we should re-writes all predicate expressions. // collect projection. - let replace_map = projection - .schema - .fields() - .iter() - .enumerate() - .map(|(i, field)| { - // strip alias, as they should not be part of filters - let expr = match &projection.expr[i] { - Expr::Alias(Alias { expr, .. }) => expr.as_ref().clone(), - expr => expr.clone(), - }; - - (field.qualified_name(), expr) - }) - .collect::>(); - - // re-write all filters based on this projection - // E.g. in `Filter: b\n Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1" - let new_filter = LogicalPlan::Filter(Filter::try_new( - replace_cols_by_name(filter.predicate.clone(), &replace_map)?, - projection.input.clone(), - )?); + let (volatile_map, non_volatile_map): (HashMap<_, _>, HashMap<_, _>) = + projection + .schema + .fields() + .iter() + .enumerate() + .map(|(i, field)| { + // strip alias, as they should not be part of filters + let expr = match &projection.expr[i] { + Expr::Alias(Alias { expr, .. }) => expr.as_ref().clone(), + expr => expr.clone(), + }; + + (field.qualified_name(), expr) + }) + .partition(|(_, value)| is_volatile_expression(value)); + + let mut push_predicates = vec![]; + let mut keep_predicates = vec![]; + for expr in split_conjunction_owned(filter.predicate.clone()).into_iter() + { + if contain(&expr, &volatile_map) { + keep_predicates.push(expr); + } else { + push_predicates.push(expr); + } + } - child_plan.with_new_inputs(&[new_filter])? + match conjunction(push_predicates) { + Some(expr) => { + // re-write all filters based on this projection + // E.g. in `Filter: b\n Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1" + let new_filter = LogicalPlan::Filter(Filter::try_new( + replace_cols_by_name(expr, &non_volatile_map)?, + projection.input.clone(), + )?); + + match conjunction(keep_predicates) { + None => child_plan.with_new_exprs( + child_plan.expressions(), + &[new_filter], + )?, + Some(keep_predicate) => { + let child_plan = child_plan.with_new_exprs( + child_plan.expressions(), + &[new_filter], + )?; + LogicalPlan::Filter(Filter::try_new( + keep_predicate, + Arc::new(child_plan), + )?) + } + } + } + None => return Ok(None), + } } LogicalPlan::Union(union) => { let mut inputs = Vec::with_capacity(union.inputs.len()); @@ -708,7 +812,7 @@ impl OptimizerRule for PushDownFilter { .map(|e| Ok(Column::from_qualified_name(e.display_name()?))) .collect::>>()?; - let predicates = utils::split_conjunction_owned(filter.predicate.clone()); + let predicates = split_conjunction_owned(filter.predicate.clone()); let mut keep_predicates = vec![]; let mut push_predicates = vec![]; @@ -740,7 +844,9 @@ impl OptimizerRule for PushDownFilter { )?), None => (*agg.input).clone(), }; - let new_agg = filter.input.with_new_inputs(&vec![child])?; + let new_agg = filter + .input + .with_new_exprs(filter.input.expressions(), &vec![child])?; match conjunction(keep_predicates) { Some(predicate) => LogicalPlan::Filter(Filter::try_new( predicate, @@ -755,17 +861,23 @@ impl OptimizerRule for PushDownFilter { None => return Ok(None), } } - LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => { - let predicates = utils::split_conjunction_owned(filter.predicate.clone()); - push_down_all_join( + LogicalPlan::CrossJoin(cross_join) => { + let predicates = split_conjunction_owned(filter.predicate.clone()); + let join = convert_cross_join_to_inner_join(cross_join.clone())?; + let join_plan = LogicalPlan::Join(join); + let inputs = join_plan.inputs(); + let left = inputs[0]; + let right = inputs[1]; + let plan = push_down_all_join( predicates, vec![], - &filter.input, + &join_plan, left, right, vec![], - false, - )? + true, + )?; + convert_to_cross_join_if_beneficial(plan)? } LogicalPlan::TableScan(scan) => { let filter_predicates = split_conjunction(&filter.predicate); @@ -811,7 +923,7 @@ impl OptimizerRule for PushDownFilter { let prevent_cols = extension_plan.node.prevent_predicate_push_down_columns(); - let predicates = utils::split_conjunction_owned(filter.predicate.clone()); + let predicates = split_conjunction_owned(filter.predicate.clone()); let mut keep_predicates = vec![]; let mut push_predicates = vec![]; @@ -839,7 +951,8 @@ impl OptimizerRule for PushDownFilter { None => extension_plan.node.inputs().into_iter().cloned().collect(), }; // extension with new inputs. - let new_extension = child_plan.with_new_inputs(&new_children)?; + let new_extension = + child_plan.with_new_exprs(child_plan.expressions(), &new_children)?; match conjunction(keep_predicates) { Some(predicate) => LogicalPlan::Filter(Filter::try_new( @@ -862,6 +975,42 @@ impl PushDownFilter { } } +/// Converts the given cross join to an inner join with an empty equality +/// predicate and an empty filter condition. +fn convert_cross_join_to_inner_join(cross_join: CrossJoin) -> Result { + let CrossJoin { left, right, .. } = cross_join; + let join_schema = build_join_schema(left.schema(), right.schema(), &JoinType::Inner)?; + Ok(Join { + left, + right, + join_type: JoinType::Inner, + join_constraint: JoinConstraint::On, + on: vec![], + filter: None, + schema: DFSchemaRef::new(join_schema), + null_equals_null: true, + }) +} + +/// Converts the given inner join with an empty equality predicate and an +/// empty filter condition to a cross join. +fn convert_to_cross_join_if_beneficial(plan: LogicalPlan) -> Result { + if let LogicalPlan::Join(join) = &plan { + // Can be converted back to cross join + if join.on.is_empty() && join.filter.is_none() { + return LogicalPlanBuilder::from(join.left.as_ref().clone()) + .cross_join(join.right.as_ref().clone())? + .build(); + } + } else if let LogicalPlan::Filter(filter) = &plan { + let new_input = + convert_to_cross_join_if_beneficial(filter.input.as_ref().clone())?; + return Filter::try_new(filter.predicate.clone(), Arc::new(new_input)) + .map(LogicalPlan::Filter); + } + Ok(plan) +} + /// replaces columns by its name on the projection. pub fn replace_cols_by_name( e: Expr, @@ -879,24 +1028,79 @@ pub fn replace_cols_by_name( }) } +/// check whether the expression is volatile predicates +fn is_volatile_expression(e: &Expr) -> bool { + let mut is_volatile = false; + e.apply(&mut |expr| { + Ok(match expr { + Expr::ScalarFunction(f) => match &f.func_def { + ScalarFunctionDefinition::BuiltIn(fun) + if fun.volatility() == Volatility::Volatile => + { + is_volatile = true; + VisitRecursion::Stop + } + ScalarFunctionDefinition::UDF(fun) + if fun.signature().volatility == Volatility::Volatile => + { + is_volatile = true; + VisitRecursion::Stop + } + ScalarFunctionDefinition::Name(_) => { + return internal_err!( + "Function `Expr` with name should be resolved." + ); + } + _ => VisitRecursion::Continue, + }, + _ => VisitRecursion::Continue, + }) + }) + .unwrap(); + is_volatile +} + +/// check whether the expression uses the columns in `check_map`. +fn contain(e: &Expr, check_map: &HashMap) -> bool { + let mut is_contain = false; + e.apply(&mut |expr| { + Ok(if let Expr::Column(c) = &expr { + match check_map.get(&c.flat_name()) { + Some(_) => { + is_contain = true; + VisitRecursion::Stop + } + None => VisitRecursion::Continue, + } + } else { + VisitRecursion::Continue + }) + }) + .unwrap(); + is_contain +} + #[cfg(test)] mod tests { + use std::fmt::{Debug, Formatter}; + use std::sync::Arc; + use super::*; use crate::optimizer::Optimizer; use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; use crate::test::*; use crate::OptimizerContext; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; - use async_trait::async_trait; use datafusion_common::{DFSchema, DFSchemaRef}; use datafusion_expr::logical_plan::table_scan; use datafusion_expr::{ - and, col, in_list, in_subquery, lit, logical_plan::JoinType, or, sum, BinaryExpr, - Expr, Extension, LogicalPlanBuilder, Operator, TableSource, TableType, - UserDefinedLogicalNodeCore, + and, col, in_list, in_subquery, lit, logical_plan::JoinType, or, random, sum, + BinaryExpr, Expr, Extension, LogicalPlanBuilder, Operator, TableSource, + TableType, UserDefinedLogicalNodeCore, }; - use std::fmt::{Debug, Formatter}; - use std::sync::Arc; + + use async_trait::async_trait; fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { crate::test::assert_optimized_plan_eq( @@ -916,7 +1120,7 @@ mod tests { ]); let mut optimized_plan = optimizer .optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), )? @@ -2520,14 +2724,12 @@ Projection: a, b .cross_join(right)? .filter(filter)? .build()?; - let expected = "\ - Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)\ - \n CrossJoin:\ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]\ - \n Projection: test1.a AS d, test1.a AS e\ - \n TableScan: test1"; + Inner Join: Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)\ + \n Projection: test.a, test.b, test.c\ + \n TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]\ + \n Projection: test1.a AS d, test1.a AS e\ + \n TableScan: test1"; assert_optimized_plan_eq_with_rewrite_predicate(&plan, expected)?; // Originally global state which can help to avoid duplicate Filters been generated and pushed down. @@ -2710,4 +2912,79 @@ Projection: a, b \n TableScan: test2"; assert_optimized_plan_eq(&plan, expected) } + + #[test] + fn test_push_down_volatile_function_in_aggregate() -> Result<()> { + // SELECT t.a, t.r FROM (SELECT a, SUM(b), random()+1 AS r FROM test1 GROUP BY a) AS t WHERE t.a > 5 AND t.r > 0.5; + let table_scan = test_table_scan_with_name("test1")?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("a")], vec![sum(col("b"))])? + .project(vec![ + col("a"), + sum(col("b")), + add(random(), lit(1)).alias("r"), + ])? + .alias("t")? + .filter(col("t.a").gt(lit(5)).and(col("t.r").gt(lit(0.5))))? + .project(vec![col("t.a"), col("t.r")])? + .build()?; + + let expected_before = "Projection: t.a, t.r\ + \n Filter: t.a > Int32(5) AND t.r > Float64(0.5)\ + \n SubqueryAlias: t\ + \n Projection: test1.a, SUM(test1.b), random() + Int32(1) AS r\ + \n Aggregate: groupBy=[[test1.a]], aggr=[[SUM(test1.b)]]\ + \n TableScan: test1"; + assert_eq!(format!("{plan:?}"), expected_before); + + let expected_after = "Projection: t.a, t.r\ + \n SubqueryAlias: t\ + \n Filter: r > Float64(0.5)\ + \n Projection: test1.a, SUM(test1.b), random() + Int32(1) AS r\ + \n Aggregate: groupBy=[[test1.a]], aggr=[[SUM(test1.b)]]\ + \n TableScan: test1, full_filters=[test1.a > Int32(5)]"; + assert_optimized_plan_eq(&plan, expected_after) + } + + #[test] + fn test_push_down_volatile_function_in_join() -> Result<()> { + // SELECT t.a, t.r FROM (SELECT test1.a AS a, random() AS r FROM test1 join test2 ON test1.a = test2.a) AS t WHERE t.r > 0.5; + let table_scan = test_table_scan_with_name("test1")?; + let left = LogicalPlanBuilder::from(table_scan).build()?; + let right_table_scan = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(right_table_scan).build()?; + let plan = LogicalPlanBuilder::from(left) + .join( + right, + JoinType::Inner, + ( + vec![Column::from_qualified_name("test1.a")], + vec![Column::from_qualified_name("test2.a")], + ), + None, + )? + .project(vec![col("test1.a").alias("a"), random().alias("r")])? + .alias("t")? + .filter(col("t.r").gt(lit(0.8)))? + .project(vec![col("t.a"), col("t.r")])? + .build()?; + + let expected_before = "Projection: t.a, t.r\ + \n Filter: t.r > Float64(0.8)\ + \n SubqueryAlias: t\ + \n Projection: test1.a AS a, random() AS r\ + \n Inner Join: test1.a = test2.a\ + \n TableScan: test1\ + \n TableScan: test2"; + assert_eq!(format!("{plan:?}"), expected_before); + + let expected = "Projection: t.a, t.r\ + \n SubqueryAlias: t\ + \n Filter: r > Float64(0.8)\ + \n Projection: test1.a AS a, random() AS r\ + \n Inner Join: test1.a = test2.a\ + \n TableScan: test1\ + \n TableScan: test2"; + assert_optimized_plan_eq(&plan, expected) + } } diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 6703a1d787a7..c2f35a790616 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -126,7 +126,7 @@ impl OptimizerRule for PushDownLimit { fetch: scan.fetch.map(|x| min(x, limit)).or(Some(limit)), projected_schema: scan.projected_schema.clone(), }); - Some(plan.with_new_inputs(&[new_input])?) + Some(plan.with_new_exprs(plan.expressions(), &[new_input])?) } } LogicalPlan::Union(union) => { @@ -145,7 +145,7 @@ impl OptimizerRule for PushDownLimit { inputs: new_inputs, schema: union.schema.clone(), }); - Some(plan.with_new_inputs(&[union])?) + Some(plan.with_new_exprs(plan.expressions(), &[union])?) } LogicalPlan::CrossJoin(cross_join) => { @@ -166,15 +166,16 @@ impl OptimizerRule for PushDownLimit { right: Arc::new(new_right), schema: plan.schema().clone(), }); - Some(plan.with_new_inputs(&[new_cross_join])?) + Some(plan.with_new_exprs(plan.expressions(), &[new_cross_join])?) } LogicalPlan::Join(join) => { let new_join = push_down_join(join, fetch + skip); match new_join { - Some(new_join) => { - Some(plan.with_new_inputs(&[LogicalPlan::Join(new_join)])?) - } + Some(new_join) => Some(plan.with_new_exprs( + plan.expressions(), + &[LogicalPlan::Join(new_join)], + )?), None => None, } } @@ -192,14 +193,16 @@ impl OptimizerRule for PushDownLimit { input: Arc::new((*sort.input).clone()), fetch: new_fetch, }); - Some(plan.with_new_inputs(&[new_sort])?) + Some(plan.with_new_exprs(plan.expressions(), &[new_sort])?) } } LogicalPlan::Projection(_) | LogicalPlan::SubqueryAlias(_) => { // commute - let new_limit = - plan.with_new_inputs(&[child_plan.inputs()[0].clone()])?; - Some(child_plan.with_new_inputs(&[new_limit])?) + let new_limit = plan.with_new_exprs( + plan.expressions(), + &[child_plan.inputs()[0].clone()], + )?; + Some(child_plan.with_new_exprs(child_plan.expressions(), &[new_limit])?) } _ => None, }; diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs index 6db4bb9ba405..4ee4f7e417a6 100644 --- a/datafusion/optimizer/src/push_down_projection.rs +++ b/datafusion/optimizer/src/push_down_projection.rs @@ -18,564 +18,27 @@ //! Projection Push Down optimizer rule ensures that only referenced columns are //! loaded into memory -use crate::eliminate_project::can_eliminate; -use crate::merge_projection::merge_projection; -use crate::optimizer::ApplyOrder; -use crate::push_down_filter::replace_cols_by_name; -use crate::{OptimizerConfig, OptimizerRule}; -use arrow::error::Result as ArrowResult; -use datafusion_common::ScalarValue::UInt8; -use datafusion_common::{ - plan_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ToDFSchema, -}; -use datafusion_expr::expr::{AggregateFunction, Alias}; -use datafusion_expr::utils::exprlist_to_fields; -use datafusion_expr::{ - logical_plan::{Aggregate, LogicalPlan, Projection, TableScan, Union}, - utils::{expr_to_columns, exprlist_to_columns}, - Expr, LogicalPlanBuilder, SubqueryAlias, -}; -use std::collections::HashMap; -use std::{ - collections::{BTreeSet, HashSet}, - sync::Arc, -}; - -// if projection is empty return projection-new_plan, else return new_plan. -#[macro_export] -macro_rules! generate_plan { - ($projection_is_empty:expr, $plan:expr, $new_plan:expr) => { - if $projection_is_empty { - $new_plan - } else { - $plan.with_new_inputs(&[$new_plan])? - } - }; -} - -/// Optimizer that removes unused projections and aggregations from plans -/// This reduces both scans and -#[derive(Default)] -pub struct PushDownProjection {} - -impl OptimizerRule for PushDownProjection { - fn try_optimize( - &self, - plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - let projection = match plan { - LogicalPlan::Projection(projection) => projection, - LogicalPlan::Aggregate(agg) => { - let mut required_columns = HashSet::new(); - for e in agg.aggr_expr.iter().chain(agg.group_expr.iter()) { - expr_to_columns(e, &mut required_columns)? - } - let new_expr = get_expr(&required_columns, agg.input.schema())?; - let projection = LogicalPlan::Projection(Projection::try_new( - new_expr, - agg.input.clone(), - )?); - let optimized_child = self - .try_optimize(&projection, _config)? - .unwrap_or(projection); - return Ok(Some(plan.with_new_inputs(&[optimized_child])?)); - } - LogicalPlan::TableScan(scan) if scan.projection.is_none() => { - return Ok(Some(push_down_scan(&HashSet::new(), scan, false)?)); - } - _ => return Ok(None), - }; - - let child_plan = &*projection.input; - let projection_is_empty = projection.expr.is_empty(); - - let new_plan = match child_plan { - LogicalPlan::Projection(child_projection) => { - let new_plan = merge_projection(projection, child_projection)?; - self.try_optimize(&new_plan, _config)?.unwrap_or(new_plan) - } - LogicalPlan::Join(join) => { - // collect column in on/filter in join and projection. - let mut push_columns: HashSet = HashSet::new(); - for e in projection.expr.iter() { - expr_to_columns(e, &mut push_columns)?; - } - for (l, r) in join.on.iter() { - expr_to_columns(l, &mut push_columns)?; - expr_to_columns(r, &mut push_columns)?; - } - if let Some(expr) = &join.filter { - expr_to_columns(expr, &mut push_columns)?; - } - - let new_left = generate_projection( - &push_columns, - join.left.schema(), - join.left.clone(), - )?; - let new_right = generate_projection( - &push_columns, - join.right.schema(), - join.right.clone(), - )?; - let new_join = child_plan.with_new_inputs(&[new_left, new_right])?; - - generate_plan!(projection_is_empty, plan, new_join) - } - LogicalPlan::CrossJoin(join) => { - // collect column in on/filter in join and projection. - let mut push_columns: HashSet = HashSet::new(); - for e in projection.expr.iter() { - expr_to_columns(e, &mut push_columns)?; - } - let new_left = generate_projection( - &push_columns, - join.left.schema(), - join.left.clone(), - )?; - let new_right = generate_projection( - &push_columns, - join.right.schema(), - join.right.clone(), - )?; - let new_join = child_plan.with_new_inputs(&[new_left, new_right])?; - - generate_plan!(projection_is_empty, plan, new_join) - } - LogicalPlan::TableScan(scan) - if !scan.projected_schema.fields().is_empty() => - { - let mut used_columns: HashSet = HashSet::new(); - if projection_is_empty { - used_columns - .insert(scan.projected_schema.fields()[0].qualified_column()); - push_down_scan(&used_columns, scan, true)? - } else { - for expr in projection.expr.iter() { - expr_to_columns(expr, &mut used_columns)?; - } - let new_scan = push_down_scan(&used_columns, scan, true)?; - - plan.with_new_inputs(&[new_scan])? - } - } - LogicalPlan::Values(values) if projection_is_empty => { - let first_col = - Expr::Column(values.schema.fields()[0].qualified_column()); - LogicalPlan::Projection(Projection::try_new( - vec![first_col], - Arc::new(child_plan.clone()), - )?) - } - LogicalPlan::Union(union) => { - let mut required_columns = HashSet::new(); - exprlist_to_columns(&projection.expr, &mut required_columns)?; - // When there is no projection, we need to add the first column to the projection - // Because if push empty down, children may output different columns. - if required_columns.is_empty() { - required_columns.insert(union.schema.fields()[0].qualified_column()); - } - // we don't push down projection expr, we just prune columns, so we just push column - // because push expr may cause more cost. - let projection_column_exprs = get_expr(&required_columns, &union.schema)?; - let mut inputs = Vec::with_capacity(union.inputs.len()); - for input in &union.inputs { - let mut replace_map = HashMap::new(); - for (i, field) in input.schema().fields().iter().enumerate() { - replace_map.insert( - union.schema.fields()[i].qualified_name(), - Expr::Column(field.qualified_column()), - ); - } - - let exprs = projection_column_exprs - .iter() - .map(|expr| replace_cols_by_name(expr.clone(), &replace_map)) - .collect::>>()?; - - inputs.push(Arc::new(LogicalPlan::Projection(Projection::try_new( - exprs, - input.clone(), - )?))) - } - // create schema of all used columns - let schema = DFSchema::new_with_metadata( - exprlist_to_fields(&projection_column_exprs, child_plan)?, - union.schema.metadata().clone(), - )?; - let new_union = LogicalPlan::Union(Union { - inputs, - schema: Arc::new(schema), - }); - - generate_plan!(projection_is_empty, plan, new_union) - } - LogicalPlan::SubqueryAlias(subquery_alias) => { - let replace_map = generate_column_replace_map(subquery_alias); - let mut required_columns = HashSet::new(); - exprlist_to_columns(&projection.expr, &mut required_columns)?; - - let new_required_columns = required_columns - .iter() - .map(|c| { - replace_map.get(c).cloned().ok_or_else(|| { - DataFusionError::Internal("replace column failed".to_string()) - }) - }) - .collect::>>()?; - - let new_expr = - get_expr(&new_required_columns, subquery_alias.input.schema())?; - let new_projection = LogicalPlan::Projection(Projection::try_new( - new_expr, - subquery_alias.input.clone(), - )?); - let new_alias = child_plan.with_new_inputs(&[new_projection])?; - - generate_plan!(projection_is_empty, plan, new_alias) - } - LogicalPlan::Aggregate(agg) => { - let mut required_columns = HashSet::new(); - exprlist_to_columns(&projection.expr, &mut required_columns)?; - // Gather all columns needed for expressions in this Aggregate - let mut new_aggr_expr = vec![]; - for e in agg.aggr_expr.iter() { - let column = Column::from_name(e.display_name()?); - if required_columns.contains(&column) { - new_aggr_expr.push(e.clone()); - } - } - - // if new_aggr_expr emtpy and aggr is COUNT(UInt8(1)), push it - if new_aggr_expr.is_empty() && agg.aggr_expr.len() == 1 { - if let Expr::AggregateFunction(AggregateFunction { - fun, args, .. - }) = &agg.aggr_expr[0] - { - if matches!(fun, datafusion_expr::AggregateFunction::Count) - && args.len() == 1 - && args[0] == Expr::Literal(UInt8(Some(1))) - { - new_aggr_expr.push(agg.aggr_expr[0].clone()); - } - } - } - - let new_agg = LogicalPlan::Aggregate(Aggregate::try_new( - agg.input.clone(), - agg.group_expr.clone(), - new_aggr_expr, - )?); - - generate_plan!(projection_is_empty, plan, new_agg) - } - LogicalPlan::Window(window) => { - let mut required_columns = HashSet::new(); - exprlist_to_columns(&projection.expr, &mut required_columns)?; - // Gather all columns needed for expressions in this Window - let mut new_window_expr = vec![]; - for e in window.window_expr.iter() { - let column = Column::from_name(e.display_name()?); - if required_columns.contains(&column) { - new_window_expr.push(e.clone()); - } - } - - if new_window_expr.is_empty() { - // none columns in window expr are needed, remove the window expr - let input = window.input.clone(); - let new_window = restrict_outputs(input.clone(), &required_columns)? - .unwrap_or((*input).clone()); - - generate_plan!(projection_is_empty, plan, new_window) - } else { - let mut referenced_inputs = HashSet::new(); - exprlist_to_columns(&new_window_expr, &mut referenced_inputs)?; - window - .input - .schema() - .fields() - .iter() - .filter(|f| required_columns.contains(&f.qualified_column())) - .for_each(|f| { - referenced_inputs.insert(f.qualified_column()); - }); - - let input = window.input.clone(); - let new_input = restrict_outputs(input.clone(), &referenced_inputs)? - .unwrap_or((*input).clone()); - let new_window = LogicalPlanBuilder::from(new_input) - .window(new_window_expr)? - .build()?; - - generate_plan!(projection_is_empty, plan, new_window) - } - } - LogicalPlan::Filter(filter) => { - if can_eliminate(projection, child_plan.schema()) { - // when projection schema == filter schema, we can commute directly. - let new_proj = - plan.with_new_inputs(&[filter.input.as_ref().clone()])?; - child_plan.with_new_inputs(&[new_proj])? - } else { - let mut required_columns = HashSet::new(); - exprlist_to_columns(&projection.expr, &mut required_columns)?; - exprlist_to_columns( - &[filter.predicate.clone()], - &mut required_columns, - )?; - - let new_expr = get_expr(&required_columns, filter.input.schema())?; - let new_projection = LogicalPlan::Projection(Projection::try_new( - new_expr, - filter.input.clone(), - )?); - let new_filter = child_plan.with_new_inputs(&[new_projection])?; - - generate_plan!(projection_is_empty, plan, new_filter) - } - } - LogicalPlan::Sort(sort) => { - if can_eliminate(projection, child_plan.schema()) { - // can commute - let new_proj = plan.with_new_inputs(&[(*sort.input).clone()])?; - child_plan.with_new_inputs(&[new_proj])? - } else { - let mut required_columns = HashSet::new(); - exprlist_to_columns(&projection.expr, &mut required_columns)?; - exprlist_to_columns(&sort.expr, &mut required_columns)?; - - let new_expr = get_expr(&required_columns, sort.input.schema())?; - let new_projection = LogicalPlan::Projection(Projection::try_new( - new_expr, - sort.input.clone(), - )?); - let new_sort = child_plan.with_new_inputs(&[new_projection])?; - - generate_plan!(projection_is_empty, plan, new_sort) - } - } - LogicalPlan::Limit(limit) => { - // can commute - let new_proj = plan.with_new_inputs(&[limit.input.as_ref().clone()])?; - child_plan.with_new_inputs(&[new_proj])? - } - _ => return Ok(None), - }; - - Ok(Some(new_plan)) - } - - fn name(&self) -> &str { - "push_down_projection" - } - - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } -} - -impl PushDownProjection { - #[allow(missing_docs)] - pub fn new() -> Self { - Self {} - } -} - -fn generate_column_replace_map( - subquery_alias: &SubqueryAlias, -) -> HashMap { - subquery_alias - .input - .schema() - .fields() - .iter() - .enumerate() - .map(|(i, field)| { - ( - subquery_alias.schema.fields()[i].qualified_column(), - field.qualified_column(), - ) - }) - .collect() -} - -pub fn collect_projection_expr(projection: &Projection) -> HashMap { - projection - .schema - .fields() - .iter() - .enumerate() - .flat_map(|(i, field)| { - // strip alias, as they should not be part of filters - let expr = match &projection.expr[i] { - Expr::Alias(Alias { expr, .. }) => expr.as_ref().clone(), - expr => expr.clone(), - }; - - // Convert both qualified and unqualified fields - [ - (field.name().clone(), expr.clone()), - (field.qualified_name(), expr), - ] - }) - .collect::>() -} - -// Get the projection exprs from columns in the order of the schema -fn get_expr(columns: &HashSet, schema: &DFSchemaRef) -> Result> { - let expr = schema - .fields() - .iter() - .flat_map(|field| { - let qc = field.qualified_column(); - let uqc = field.unqualified_column(); - if columns.contains(&qc) || columns.contains(&uqc) { - Some(Expr::Column(qc)) - } else { - None - } - }) - .collect::>(); - if columns.len() != expr.len() { - plan_err!("required columns can't push down, columns: {columns:?}") - } else { - Ok(expr) - } -} - -fn generate_projection( - used_columns: &HashSet, - schema: &DFSchemaRef, - input: Arc, -) -> Result { - let expr = schema - .fields() - .iter() - .flat_map(|field| { - let column = field.qualified_column(); - if used_columns.contains(&column) { - Some(Expr::Column(column)) - } else { - None - } - }) - .collect::>(); - - Ok(LogicalPlan::Projection(Projection::try_new(expr, input)?)) -} - -fn push_down_scan( - used_columns: &HashSet, - scan: &TableScan, - has_projection: bool, -) -> Result { - // once we reach the table scan, we can use the accumulated set of column - // names to construct the set of column indexes in the scan - // - // we discard non-existing columns because some column names are not part of the schema, - // e.g. when the column derives from an aggregation - // - // Use BTreeSet to remove potential duplicates (e.g. union) as - // well as to sort the projection to ensure deterministic behavior - let schema = scan.source.schema(); - let mut projection: BTreeSet = used_columns - .iter() - .filter(|c| { - c.relation.is_none() || c.relation.as_ref().unwrap() == &scan.table_name - }) - .map(|c| schema.index_of(&c.name)) - .filter_map(ArrowResult::ok) - .collect(); - - if projection.is_empty() { - if has_projection && !schema.fields().is_empty() { - // Ensure that we are reading at least one column from the table in case the query - // does not reference any columns directly such as "SELECT COUNT(1) FROM table", - // except when the table is empty (no column) - projection.insert(0); - } else { - // for table scan without projection, we default to return all columns - projection = scan - .source - .schema() - .fields() - .iter() - .enumerate() - .map(|(i, _)| i) - .collect::>(); - } - } - - // Building new projection from BTreeSet - // preserving source projection order if it exists - let projection = if let Some(original_projection) = &scan.projection { - original_projection - .clone() - .into_iter() - .filter(|idx| projection.contains(idx)) - .collect::>() - } else { - projection.into_iter().collect::>() - }; - - // create the projected schema - let projected_fields: Vec = projection - .iter() - .map(|i| { - DFField::from_qualified(scan.table_name.clone(), schema.fields()[*i].clone()) - }) - .collect(); - - let projected_schema = projected_fields.to_dfschema_ref()?; - - Ok(LogicalPlan::TableScan(TableScan { - table_name: scan.table_name.clone(), - source: scan.source.clone(), - projection: Some(projection), - projected_schema, - filters: scan.filters.clone(), - fetch: scan.fetch, - })) -} - -fn restrict_outputs( - plan: Arc, - permitted_outputs: &HashSet, -) -> Result> { - let schema = plan.schema(); - if permitted_outputs.len() == schema.fields().len() { - return Ok(None); - } - Ok(Some(generate_projection( - permitted_outputs, - schema, - plan.clone(), - )?)) -} - #[cfg(test)] mod tests { - use super::*; - use crate::eliminate_project::EliminateProjection; + use std::collections::HashMap; + use std::sync::Arc; + use std::vec; + + use crate::optimize_projections::OptimizeProjections; use crate::optimizer::Optimizer; use crate::test::*; use crate::OptimizerContext; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::DFSchema; + use datafusion_common::{Column, DFField, DFSchema, Result}; use datafusion_expr::builder::table_scan_with_filters; - use datafusion_expr::expr; - use datafusion_expr::expr::Cast; - use datafusion_expr::WindowFrame; - use datafusion_expr::WindowFunction; + use datafusion_expr::expr::{self, Cast}; + use datafusion_expr::logical_plan::{ + builder::LogicalPlanBuilder, table_scan, JoinType, + }; use datafusion_expr::{ - col, count, lit, - logical_plan::{builder::LogicalPlanBuilder, table_scan, JoinType}, - max, min, AggregateFunction, Expr, + col, count, lit, max, min, AggregateFunction, Expr, LogicalPlan, Projection, + WindowFrame, WindowFunctionDefinition, }; - use std::collections::HashMap; - use std::vec; #[test] fn aggregate_no_group_by() -> Result<()> { @@ -638,6 +101,31 @@ mod tests { assert_optimized_plan_eq(&plan, expected) } + #[test] + fn aggregate_with_periods() -> Result<()> { + let schema = Schema::new(vec![Field::new("tag.one", DataType::Utf8, false)]); + + // Build a plan that looks as follows (note "tag.one" is a column named + // "tag.one", not a column named "one" in a table named "tag"): + // + // Projection: tag.one + // Aggregate: groupBy=[], aggr=[MAX("tag.one") AS "tag.one"] + // TableScan + let plan = table_scan(Some("m4"), &schema, None)? + .aggregate( + Vec::::new(), + vec![max(col(Column::new_unqualified("tag.one"))).alias("tag.one")], + )? + .project([col(Column::new_unqualified("tag.one"))])? + .build()?; + + let expected = "\ + Aggregate: groupBy=[[]], aggr=[[MAX(m4.tag.one) AS tag.one]]\ + \n TableScan: m4 projection=[tag.one]"; + + assert_optimized_plan_eq(&plan, expected) + } + #[test] fn redundant_project() -> Result<()> { let table_scan = test_table_scan()?; @@ -875,7 +363,7 @@ mod tests { // Build the LogicalPlan directly (don't use PlanBuilder), so // that the Column references are unqualified (e.g. their // relation is `None`). PlanBuilder resolves the expressions - let expr = vec![col("a"), col("b")]; + let expr = vec![col("test.a"), col("test.b")]; let plan = LogicalPlan::Projection(Projection::try_new(expr, Arc::new(table_scan))?); @@ -922,7 +410,7 @@ mod tests { .project(vec![lit(1_i64), lit(2_i64)])? .build()?; let expected = "Projection: Int64(1), Int64(2)\ - \n TableScan: test projection=[a]"; + \n TableScan: test projection=[]"; assert_optimized_plan_eq(&plan, expected) } @@ -969,7 +457,7 @@ mod tests { let expected = "\ Projection: Int32(1) AS a\ - \n TableScan: test projection=[a]"; + \n TableScan: test projection=[]"; assert_optimized_plan_eq(&plan, expected) } @@ -998,7 +486,7 @@ mod tests { let expected = "\ Projection: Int32(1) AS a\ - \n TableScan: test projection=[a], full_filters=[b = Int32(1)]"; + \n TableScan: test projection=[], full_filters=[b = Int32(1)]"; assert_optimized_plan_eq(&plan, expected) } @@ -1094,7 +582,7 @@ mod tests { let table_scan = test_table_scan()?; let max1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("test.a")], vec![col("test.b")], vec![], @@ -1102,7 +590,7 @@ mod tests { )); let max2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("test.b")], vec![], vec![], @@ -1134,24 +622,14 @@ mod tests { } fn optimize(plan: &LogicalPlan) -> Result { - let optimizer = Optimizer::with_rules(vec![ - Arc::new(PushDownProjection::new()), - Arc::new(EliminateProjection::new()), - ]); - let mut optimized_plan = optimizer + let optimizer = Optimizer::with_rules(vec![Arc::new(OptimizeProjections::new())]); + let optimized_plan = optimizer .optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), )? .unwrap_or_else(|| plan.clone()); - optimized_plan = optimizer - .optimize_recursively( - optimizer.rules.get(1).unwrap(), - &optimized_plan, - &OptimizerContext::new(), - )? - .unwrap_or(optimized_plan); Ok(optimized_plan) } } diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index f58d4b159745..187e510e557d 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -15,13 +15,16 @@ // specific language governing permissions and limitations // under the License. -use crate::optimizer::ApplyOrder; +use crate::optimizer::{ApplyOrder, ApplyOrder::BottomUp}; use crate::{OptimizerConfig, OptimizerRule}; + use datafusion_common::Result; use datafusion_expr::utils::expand_wildcard; -use datafusion_expr::Distinct; -use datafusion_expr::{Aggregate, LogicalPlan}; -use ApplyOrder::BottomUp; +use datafusion_expr::{ + aggregate_function::AggregateFunction as AggregateFunctionFunc, col, + expr::AggregateFunction, LogicalPlanBuilder, +}; +use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan}; /// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]] /// @@ -33,6 +36,22 @@ use ApplyOrder::BottomUp; /// ```text /// SELECT a, b FROM tab GROUP BY a, b /// ``` +/// +/// On the other hand, for a `DISTINCT ON` query the replacement is +/// a bit more involved and effectively converts +/// ```text +/// SELECT DISTINCT ON (a) b FROM tab ORDER BY a DESC, c +/// ``` +/// +/// into +/// ```text +/// SELECT b FROM ( +/// SELECT a, FIRST_VALUE(b ORDER BY a DESC, c) AS b +/// FROM tab +/// GROUP BY a +/// ) +/// ORDER BY a DESC +/// ``` /// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]] #[derive(Default)] @@ -52,16 +71,74 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { _config: &dyn OptimizerConfig, ) -> Result> { match plan { - LogicalPlan::Distinct(Distinct { input }) => { + LogicalPlan::Distinct(Distinct::All(input)) => { let group_expr = expand_wildcard(input.schema(), input, None)?; - let aggregate = LogicalPlan::Aggregate(Aggregate::try_new_with_schema( + let aggregate = LogicalPlan::Aggregate(Aggregate::try_new( input.clone(), group_expr, vec![], - input.schema().clone(), // input schema and aggregate schema are the same in this case )?); Ok(Some(aggregate)) } + LogicalPlan::Distinct(Distinct::On(DistinctOn { + select_expr, + on_expr, + sort_expr, + input, + schema, + })) => { + // Construct the aggregation expression to be used to fetch the selected expressions. + let aggr_expr = select_expr + .iter() + .map(|e| { + Expr::AggregateFunction(AggregateFunction::new( + AggregateFunctionFunc::FirstValue, + vec![e.clone()], + false, + None, + sort_expr.clone(), + )) + }) + .collect::>(); + + // Build the aggregation plan + let plan = LogicalPlanBuilder::from(input.as_ref().clone()) + .aggregate(on_expr.clone(), aggr_expr.to_vec())? + .build()?; + + let plan = if let Some(sort_expr) = sort_expr { + // While sort expressions were used in the `FIRST_VALUE` aggregation itself above, + // this on it's own isn't enough to guarantee the proper output order of the grouping + // (`ON`) expression, so we need to sort those as well. + LogicalPlanBuilder::from(plan) + .sort(sort_expr[..on_expr.len()].to_vec())? + .build()? + } else { + plan + }; + + // Whereas the aggregation plan by default outputs both the grouping and the aggregation + // expressions, for `DISTINCT ON` we only need to emit the original selection expressions. + let project_exprs = plan + .schema() + .fields() + .iter() + .skip(on_expr.len()) + .zip(schema.fields().iter()) + .map(|(new_field, old_field)| { + Ok(col(new_field.qualified_column()).alias_qualified( + old_field.qualifier().cloned(), + old_field.name(), + )) + }) + .collect::>>()?; + + let plan = LogicalPlanBuilder::from(plan) + .project(project_exprs)? + .build()?; + + Ok(Some(plan)) + } _ => Ok(None), } } @@ -100,4 +177,27 @@ mod tests { expected, ) } + + #[test] + fn replace_distinct_on() -> datafusion_common::Result<()> { + let table_scan = test_table_scan().unwrap(); + let plan = LogicalPlanBuilder::from(table_scan) + .distinct_on( + vec![col("a")], + vec![col("b")], + Some(vec![col("a").sort(false, true), col("c").sort(true, false)]), + )? + .build()?; + + let expected = "Projection: FIRST_VALUE(test.b) ORDER BY [test.a DESC NULLS FIRST, test.c ASC NULLS LAST] AS b\ + \n Sort: test.a DESC NULLS FIRST\ + \n Aggregate: groupBy=[[test.a]], aggr=[[FIRST_VALUE(test.b) ORDER BY [test.a DESC NULLS FIRST, test.c ASC NULLS LAST]]]\ + \n TableScan: test"; + + assert_optimized_plan_eq( + Arc::new(ReplaceDistinctWithAggregate::new()), + &plan, + expected, + ) + } } diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 96d2f45d808e..34ed4a9475cb 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -17,7 +17,7 @@ use crate::decorrelate::{PullUpCorrelatedExpr, UN_MATCHED_ROW_INDICATOR}; use crate::optimizer::ApplyOrder; -use crate::utils::{conjunction, replace_qualified_name}; +use crate::utils::replace_qualified_name; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::{ @@ -26,6 +26,7 @@ use datafusion_common::tree_node::{ use datafusion_common::{plan_err, Column, DataFusionError, Result, ScalarValue}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; use datafusion_expr::logical_plan::{JoinType, Subquery}; +use datafusion_expr::utils::conjunction; use datafusion_expr::{expr, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder}; use std::collections::{BTreeSet, HashMap}; use std::sync::Arc; @@ -315,24 +316,14 @@ fn build_join( _ => { // if not correlated, group down to 1 row and left join on that (preserving row count) LogicalPlanBuilder::from(filter_input.clone()) - .join( - sub_query_alias, - JoinType::Left, - (Vec::::new(), Vec::::new()), - None, - )? + .join_on(sub_query_alias, JoinType::Left, None)? .build()? } } } else { // left join if correlated, grouping by the join keys so we don't change row count LogicalPlanBuilder::from(filter_input.clone()) - .join( - sub_query_alias, - JoinType::Left, - (Vec::::new(), Vec::::new()), - join_filter_opt, - )? + .join_on(sub_query_alias, JoinType::Left, join_filter_opt)? .build()? }; let mut computation_project_expr = HashMap::new(); diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index f5a6860299ab..7d09aec7e748 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -21,31 +21,33 @@ use std::ops::Not; use super::or_in_list_simplifier::OrInListSimplifier; use super::utils::*; - use crate::analyzer::type_coercion::TypeCoercionRewriter; +use crate::simplify_expressions::guarantees::GuaranteeRewriter; use crate::simplify_expressions::regex::simplify_regex_expr; +use crate::simplify_expressions::SimplifyInfo; + use arrow::{ array::new_null_array, datatypes::{DataType, Field, Schema}, - error::ArrowError, record_batch::RecordBatch, }; -use datafusion_common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}; +use datafusion_common::{ + cast::{as_large_list_array, as_list_array}, + plan_err, + tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}, +}; use datafusion_common::{ exec_err, internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::expr::{InList, InSubquery, ScalarFunction}; use datafusion_expr::{ - and, expr, lit, or, BinaryExpr, BuiltinScalarFunction, Case, ColumnarValue, Expr, - Like, Volatility, + and, lit, or, BinaryExpr, BuiltinScalarFunction, Case, ColumnarValue, Expr, Like, + ScalarFunctionDefinition, Volatility, }; -use datafusion_physical_expr::{ - create_physical_expr, execution_props::ExecutionProps, intervals::NullableInterval, +use datafusion_expr::{ + expr::{InList, InSubquery, ScalarFunction}, + interval_arithmetic::NullableInterval, }; - -use crate::simplify_expressions::SimplifyInfo; - -use crate::simplify_expressions::guarantees::GuaranteeRewriter; +use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps}; /// This structure handles API for expression simplification pub struct ExprSimplifier { @@ -175,9 +177,9 @@ impl ExprSimplifier { /// ```rust /// use arrow::datatypes::{DataType, Field, Schema}; /// use datafusion_expr::{col, lit, Expr}; + /// use datafusion_expr::interval_arithmetic::{Interval, NullableInterval}; /// use datafusion_common::{Result, ScalarValue, ToDFSchema}; /// use datafusion_physical_expr::execution_props::ExecutionProps; - /// use datafusion_physical_expr::intervals::{Interval, NullableInterval}; /// use datafusion_optimizer::simplify_expressions::{ /// ExprSimplifier, SimplifyContext}; /// @@ -204,7 +206,7 @@ impl ExprSimplifier { /// ( /// col("x"), /// NullableInterval::NotNull { - /// values: Interval::make(Some(3_i64), Some(5_i64), (false, false)), + /// values: Interval::make(Some(3_i64), Some(5_i64)).unwrap() /// } /// ), /// // y = 3 @@ -330,7 +332,6 @@ impl<'a> ConstEvaluator<'a> { // Has no runtime cost, but needed during planning Expr::Alias(..) | Expr::AggregateFunction { .. } - | Expr::AggregateUDF { .. } | Expr::ScalarVariable(_, _) | Expr::Column(_) | Expr::OuterReferenceColumn(_, _) @@ -340,15 +341,17 @@ impl<'a> ConstEvaluator<'a> { | Expr::WindowFunction { .. } | Expr::Sort { .. } | Expr::GroupingSet(_) - | Expr::Wildcard - | Expr::QualifiedWildcard { .. } + | Expr::Wildcard { .. } | Expr::Placeholder(_) => false, - Expr::ScalarFunction(ScalarFunction { fun, .. }) => { - Self::volatility_ok(fun.volatility()) - } - Expr::ScalarUDF(expr::ScalarUDF { fun, .. }) => { - Self::volatility_ok(fun.signature.volatility) - } + Expr::ScalarFunction(ScalarFunction { func_def, .. }) => match func_def { + ScalarFunctionDefinition::BuiltIn(fun) => { + Self::volatility_ok(fun.volatility()) + } + ScalarFunctionDefinition::UDF(fun) => { + Self::volatility_ok(fun.signature().volatility) + } + ScalarFunctionDefinition::Name(_) => false, + }, Expr::Literal(_) | Expr::BinaryExpr { .. } | Expr::Not(_) @@ -392,8 +395,11 @@ impl<'a> ConstEvaluator<'a> { "Could not evaluate the expression, found a result of length {}", a.len() ) + } else if as_list_array(&a).is_ok() || as_large_list_array(&a).is_ok() { + Ok(ScalarValue::List(a)) } else { - Ok(ScalarValue::try_from_array(&a, 0)?) + // Non-ListArray + ScalarValue::try_from_array(&a, 0) } } ColumnarValue::Scalar(s) => Ok(s), @@ -475,6 +481,14 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { lit(negated) } + // null in (x, y, z) --> null + // null not in (x, y, z) --> null + Expr::InList(InList { + expr, + list: _, + negated: _, + }) if is_null(&expr) => lit_bool_null(), + // expr IN ((subquery)) -> expr IN (subquery), see ##5529 Expr::InList(InList { expr, @@ -786,7 +800,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: Divide, right, }) if is_null(&right) => *right, - // A / 0 -> DivideByZero Error if A is not null and not floating + // A / 0 -> Divide by zero error if A is not null and not floating // (float / 0 -> inf | -inf | NAN) Expr::BinaryExpr(BinaryExpr { left, @@ -796,7 +810,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { && !info.get_data_type(&left)?.is_floating() && is_zero(&right) => { - return Err(DataFusionError::ArrowError(ArrowError::DivideByZero)); + return plan_err!("Divide by zero"); } // @@ -826,7 +840,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { { lit(0) } - // A % 0 --> DivideByZero Error (if A is not floating and not null) + // A % 0 --> Divide by zero Error (if A is not floating and not null) // A % 0 --> NAN (if A is floating and not null) Expr::BinaryExpr(BinaryExpr { left, @@ -837,9 +851,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { DataType::Float32 => lit(f32::NAN), DataType::Float64 => lit(f64::NAN), _ => { - return Err(DataFusionError::ArrowError( - ArrowError::DivideByZero, - )); + return plan_err!("Divide by zero"); } } } @@ -1196,25 +1208,28 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // log Expr::ScalarFunction(ScalarFunction { - fun: BuiltinScalarFunction::Log, + func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Log), args, }) => simpl_log(args, <&S>::clone(&info))?, // power Expr::ScalarFunction(ScalarFunction { - fun: BuiltinScalarFunction::Power, + func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Power), args, }) => simpl_power(args, <&S>::clone(&info))?, // concat Expr::ScalarFunction(ScalarFunction { - fun: BuiltinScalarFunction::Concat, + func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Concat), args, }) => simpl_concat(args)?, // concat_ws Expr::ScalarFunction(ScalarFunction { - fun: BuiltinScalarFunction::ConcatWithSeparator, + func_def: + ScalarFunctionDefinition::BuiltIn( + BuiltinScalarFunction::ConcatWithSeparator, + ), args, }) => match &args[..] { [delimiter, vals @ ..] => simpl_concat_ws(delimiter, vals)?, @@ -1295,26 +1310,27 @@ mod tests { sync::Arc, }; + use super::*; use crate::simplify_expressions::{ utils::for_test::{cast_to_int64_expr, now_expr, to_timestamp_expr}, SimplifyContext, }; - - use super::*; use crate::test::test_table_scan_with_name; + use arrow::{ array::{ArrayRef, Int32Array}, datatypes::{DataType, Field, Schema}, }; - use chrono::{DateTime, TimeZone, Utc}; - use datafusion_common::{assert_contains, cast::as_int32_array, DFField, ToDFSchema}; - use datafusion_expr::*; + use datafusion_common::{ + assert_contains, cast::as_int32_array, plan_datafusion_err, DFField, ToDFSchema, + }; + use datafusion_expr::{interval_arithmetic::Interval, *}; use datafusion_physical_expr::{ - execution_props::ExecutionProps, - functions::make_scalar_function, - intervals::{Interval, NullableInterval}, + execution_props::ExecutionProps, functions::make_scalar_function, }; + use chrono::{DateTime, TimeZone, Utc}; + // ------------------------------ // --- ExprSimplifier tests ----- // ------------------------------ @@ -1495,7 +1511,7 @@ mod tests { test_evaluate(expr, lit("foobarbaz")); // Check non string arguments - // to_timestamp("2020-09-08T12:00:00+00:00") --> timestamp(1599566400000000000i64) + // to_timestamp("2020-09-08T12:00:00+00:00") --> timestamp(1599566400i64) let expr = call_fn("to_timestamp", vec![lit("2020-09-08T12:00:00+00:00")]).unwrap(); test_evaluate(expr, lit_timestamp_nano(1599566400000000000i64)); @@ -1547,7 +1563,7 @@ mod tests { // immutable UDF should get folded // udf_add(1+2, 30+40) --> 73 - let expr = Expr::ScalarUDF(expr::ScalarUDF::new( + let expr = Expr::ScalarFunction(expr::ScalarFunction::new_udf( make_udf_add(Volatility::Immutable), args.clone(), )); @@ -1556,15 +1572,21 @@ mod tests { // stable UDF should be entirely folded // udf_add(1+2, 30+40) --> 73 let fun = make_udf_add(Volatility::Stable); - let expr = Expr::ScalarUDF(expr::ScalarUDF::new(Arc::clone(&fun), args.clone())); + let expr = Expr::ScalarFunction(expr::ScalarFunction::new_udf( + Arc::clone(&fun), + args.clone(), + )); test_evaluate(expr, lit(73)); // volatile UDF should have args folded // udf_add(1+2, 30+40) --> udf_add(3, 70) let fun = make_udf_add(Volatility::Volatile); - let expr = Expr::ScalarUDF(expr::ScalarUDF::new(Arc::clone(&fun), args)); - let expected_expr = - Expr::ScalarUDF(expr::ScalarUDF::new(Arc::clone(&fun), folded_args)); + let expr = + Expr::ScalarFunction(expr::ScalarFunction::new_udf(Arc::clone(&fun), args)); + let expected_expr = Expr::ScalarFunction(expr::ScalarFunction::new_udf( + Arc::clone(&fun), + folded_args, + )); test_evaluate(expr, expected_expr); } @@ -1757,25 +1779,23 @@ mod tests { #[test] fn test_simplify_divide_zero_by_zero() { - // 0 / 0 -> DivideByZero + // 0 / 0 -> Divide by zero let expr = lit(0) / lit(0); let err = try_simplify(expr).unwrap_err(); - assert!( - matches!(err, DataFusionError::ArrowError(ArrowError::DivideByZero)), - "{err}" - ); + let _expected = plan_datafusion_err!("Divide by zero"); + + assert!(matches!(err, ref _expected), "{err}"); } #[test] - #[should_panic( - expected = "called `Result::unwrap()` on an `Err` value: ArrowError(DivideByZero)" - )] fn test_simplify_divide_by_zero() { // A / 0 -> DivideByZeroError let expr = col("c2_non_null") / lit(0); - - simplify(expr); + assert_eq!( + try_simplify(expr).unwrap_err().strip_backtrace(), + "Error during planning: Divide by zero" + ); } #[test] @@ -2195,12 +2215,12 @@ mod tests { } #[test] - #[should_panic( - expected = "called `Result::unwrap()` on an `Err` value: ArrowError(DivideByZero)" - )] fn test_simplify_modulo_by_zero_non_null() { let expr = col("c2_non_null") % lit(0); - simplify(expr); + assert_eq!( + try_simplify(expr).unwrap_err().strip_backtrace(), + "Error during planning: Divide by zero" + ); } #[test] @@ -3084,6 +3104,18 @@ mod tests { assert_eq!(simplify(in_list(col("c1"), vec![], false)), lit(false)); assert_eq!(simplify(in_list(col("c1"), vec![], true)), lit(true)); + // null in (...) --> null + assert_eq!( + simplify(in_list(lit_bool_null(), vec![col("c1"), lit(1)], false)), + lit_bool_null() + ); + + // null not in (...) --> null + assert_eq!( + simplify(in_list(lit_bool_null(), vec![col("c1"), lit(1)], true)), + lit_bool_null() + ); + assert_eq!( simplify(in_list(col("c1"), vec![lit(1)], false)), col("c1").eq(lit(1)) @@ -3276,17 +3308,14 @@ mod tests { ( col("c3"), NullableInterval::NotNull { - values: Interval::make(Some(0_i64), Some(2_i64), (false, false)), + values: Interval::make(Some(0_i64), Some(2_i64)).unwrap(), }, ), ( col("c4"), NullableInterval::from(ScalarValue::UInt32(Some(9))), ), - ( - col("c1"), - NullableInterval::from(ScalarValue::Utf8(Some("a".to_string()))), - ), + (col("c1"), NullableInterval::from(ScalarValue::from("a"))), ]; let output = simplify_with_guarantee(expr.clone(), guarantees); assert_eq!(output, lit(false)); @@ -3296,19 +3325,23 @@ mod tests { ( col("c3"), NullableInterval::MaybeNull { - values: Interval::make(Some(0_i64), Some(2_i64), (false, false)), + values: Interval::make(Some(0_i64), Some(2_i64)).unwrap(), }, ), ( col("c4"), NullableInterval::MaybeNull { - values: Interval::make(Some(9_u32), Some(9_u32), (false, false)), + values: Interval::make(Some(9_u32), Some(9_u32)).unwrap(), }, ), ( col("c1"), NullableInterval::NotNull { - values: Interval::make(Some("d"), Some("f"), (false, false)), + values: Interval::try_new( + ScalarValue::from("d"), + ScalarValue::from("f"), + ) + .unwrap(), }, ), ]; diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index 5504d7d76e35..aa7bb4f78a93 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -18,11 +18,12 @@ //! Simplifier implementation for [`ExprSimplifier::with_guarantees()`] //! //! [`ExprSimplifier::with_guarantees()`]: crate::simplify_expressions::expr_simplifier::ExprSimplifier::with_guarantees + +use std::{borrow::Cow, collections::HashMap}; + use datafusion_common::{tree_node::TreeNodeRewriter, DataFusionError, Result}; +use datafusion_expr::interval_arithmetic::{Interval, NullableInterval}; use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr}; -use std::collections::HashMap; - -use datafusion_physical_expr::intervals::{Interval, IntervalBound, NullableInterval}; /// Rewrite expressions to incorporate guarantees. /// @@ -46,6 +47,10 @@ impl<'a> GuaranteeRewriter<'a> { guarantees: impl IntoIterator, ) -> Self { Self { + // TODO: Clippy wants the "map" call removed, but doing so generates + // a compilation error. Remove the clippy directive once this + // issue is fixed. + #[allow(clippy::map_identity)] guarantees: guarantees.into_iter().map(|(k, v)| (k, v)).collect(), } } @@ -82,10 +87,7 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { high.as_ref(), ) { let expr_interval = NullableInterval::NotNull { - values: Interval::new( - IntervalBound::new(low.clone(), false), - IntervalBound::new(high.clone(), false), - ), + values: Interval::try_new(low.clone(), high.clone())?, }; let contains = expr_interval.contains(*interval)?; @@ -103,48 +105,51 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - // We only support comparisons for now - if !op.is_comparison_operator() { - return Ok(expr); - }; - - // Check if this is a comparison between a column and literal - let (col, op, value) = match (left.as_ref(), right.as_ref()) { - (Expr::Column(_), Expr::Literal(value)) => (left, *op, value), - (Expr::Literal(value), Expr::Column(_)) => { - // If we can swap the op, we can simplify the expression - if let Some(op) = op.swap() { - (right, op, value) + // The left or right side of expression might either have a guarantee + // or be a literal. Either way, we can resolve them to a NullableInterval. + let left_interval = self + .guarantees + .get(left.as_ref()) + .map(|interval| Cow::Borrowed(*interval)) + .or_else(|| { + if let Expr::Literal(value) = left.as_ref() { + Some(Cow::Owned(value.clone().into())) } else { - return Ok(expr); + None + } + }); + let right_interval = self + .guarantees + .get(right.as_ref()) + .map(|interval| Cow::Borrowed(*interval)) + .or_else(|| { + if let Expr::Literal(value) = right.as_ref() { + Some(Cow::Owned(value.clone().into())) + } else { + None + } + }); + + match (left_interval, right_interval) { + (Some(left_interval), Some(right_interval)) => { + let result = + left_interval.apply_operator(op, right_interval.as_ref())?; + if result.is_certainly_true() { + Ok(lit(true)) + } else if result.is_certainly_false() { + Ok(lit(false)) + } else { + Ok(expr) } } - _ => return Ok(expr), - }; - - if let Some(col_interval) = self.guarantees.get(col.as_ref()) { - let result = - col_interval.apply_operator(&op, &value.clone().into())?; - if result.is_certainly_true() { - Ok(lit(true)) - } else if result.is_certainly_false() { - Ok(lit(false)) - } else { - Ok(expr) - } - } else { - Ok(expr) + _ => Ok(expr), } } // Columns (if interval is collapsed to a single value) Expr::Column(_) => { - if let Some(col_interval) = self.guarantees.get(&expr) { - if let Some(value) = col_interval.single_value() { - Ok(lit(value)) - } else { - Ok(expr) - } + if let Some(interval) = self.guarantees.get(&expr) { + Ok(interval.single_value().map_or(expr, lit)) } else { Ok(expr) } @@ -208,7 +213,7 @@ mod tests { ( col("x"), NullableInterval::NotNull { - values: Default::default(), + values: Interval::make_unbounded(&DataType::Boolean).unwrap(), }, ), ]; @@ -255,11 +260,18 @@ mod tests { #[test] fn test_inequalities_non_null_bounded() { let guarantees = vec![ - // x ∈ (1, 3] (not null) + // x ∈ [1, 3] (not null) ( col("x"), NullableInterval::NotNull { - values: Interval::make(Some(1_i32), Some(3_i32), (true, false)), + values: Interval::make(Some(1_i32), Some(3_i32)).unwrap(), + }, + ), + // s.y ∈ [1, 3] (not null) + ( + col("s").field("y"), + NullableInterval::NotNull { + values: Interval::make(Some(1_i32), Some(3_i32)).unwrap(), }, ), ]; @@ -268,17 +280,16 @@ mod tests { // (original_expr, expected_simplification) let simplified_cases = &[ - (col("x").lt_eq(lit(1)), false), + (col("x").lt(lit(0)), false), + (col("s").field("y").lt(lit(0)), false), (col("x").lt_eq(lit(3)), true), (col("x").gt(lit(3)), false), - (col("x").gt(lit(1)), true), + (col("x").gt(lit(0)), true), (col("x").eq(lit(0)), false), (col("x").not_eq(lit(0)), true), - (col("x").between(lit(2), lit(5)), true), - (col("x").between(lit(2), lit(3)), true), + (col("x").between(lit(0), lit(5)), true), (col("x").between(lit(5), lit(10)), false), - (col("x").not_between(lit(2), lit(5)), false), - (col("x").not_between(lit(2), lit(3)), false), + (col("x").not_between(lit(0), lit(5)), false), (col("x").not_between(lit(5), lit(10)), true), ( Expr::BinaryExpr(BinaryExpr { @@ -319,10 +330,11 @@ mod tests { ( col("x"), NullableInterval::NotNull { - values: Interval::new( - IntervalBound::new(ScalarValue::Date32(Some(18628)), false), - IntervalBound::make_unbounded(DataType::Date32).unwrap(), - ), + values: Interval::try_new( + ScalarValue::Date32(Some(18628)), + ScalarValue::Date32(None), + ) + .unwrap(), }, ), ]; @@ -397,7 +409,11 @@ mod tests { ( col("x"), NullableInterval::MaybeNull { - values: Interval::make(Some("abc"), Some("def"), (true, false)), + values: Interval::try_new( + ScalarValue::from("abc"), + ScalarValue::from("def"), + ) + .unwrap(), }, ), ]; @@ -451,7 +467,7 @@ mod tests { ScalarValue::Int32(Some(1)), ScalarValue::Boolean(Some(true)), ScalarValue::Boolean(None), - ScalarValue::Utf8(Some("abc".to_string())), + ScalarValue::from("abc"), ScalarValue::LargeUtf8(Some("def".to_string())), ScalarValue::Date32(Some(18628)), ScalarValue::Date32(None), @@ -470,11 +486,15 @@ mod tests { #[test] fn test_in_list() { let guarantees = vec![ - // x ∈ [1, 10) (not null) + // x ∈ [1, 10] (not null) ( col("x"), NullableInterval::NotNull { - values: Interval::make(Some(1_i32), Some(10_i32), (false, true)), + values: Interval::try_new( + ScalarValue::Int32(Some(1)), + ScalarValue::Int32(Some(10)), + ) + .unwrap(), }, ), ]; @@ -486,8 +506,8 @@ mod tests { let cases = &[ // x IN (9, 11) => x IN (9) ("x", vec![9, 11], false, vec![9]), - // x IN (10, 2) => x IN (2) - ("x", vec![10, 2], false, vec![2]), + // x IN (10, 2) => x IN (10, 2) + ("x", vec![10, 2], false, vec![10, 2]), // x NOT IN (9, 11) => x NOT IN (9) ("x", vec![9, 11], true, vec![9]), // x NOT IN (0, 22) => x NOT IN () diff --git a/datafusion/optimizer/src/simplify_expressions/regex.rs b/datafusion/optimizer/src/simplify_expressions/regex.rs index b9d9821b43f0..175b70f2b10e 100644 --- a/datafusion/optimizer/src/simplify_expressions/regex.rs +++ b/datafusion/optimizer/src/simplify_expressions/regex.rs @@ -84,7 +84,7 @@ impl OperatorMode { let like = Like { negated: self.not, expr, - pattern: Box::new(Expr::Literal(ScalarValue::Utf8(Some(pattern)))), + pattern: Box::new(Expr::Literal(ScalarValue::from(pattern))), escape_char: None, case_insensitive: self.i, }; diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs index 35a698b709ac..43a41b1185a3 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs @@ -20,10 +20,10 @@ use std::sync::Arc; use super::{ExprSimplifier, SimplifyContext}; -use crate::utils::merge_schema; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{DFSchema, DFSchemaRef, Result}; use datafusion_expr::logical_plan::LogicalPlan; +use datafusion_expr::utils::merge_schema; use datafusion_physical_expr::execution_props::ExecutionProps; /// Optimizer Pass that simplifies [`LogicalPlan`]s by rewriting @@ -65,10 +65,21 @@ impl SimplifyExpressions { ) -> Result { let schema = if !plan.inputs().is_empty() { DFSchemaRef::new(merge_schema(plan.inputs())) - } else if let LogicalPlan::TableScan(_) = plan { - // When predicates are pushed into a table scan, there needs to be - // a schema to resolve the fields against. - Arc::clone(plan.schema()) + } else if let LogicalPlan::TableScan(scan) = plan { + // When predicates are pushed into a table scan, there is no input + // schema to resolve predicates against, so it must be handled specially + // + // Note that this is not `plan.schema()` which is the *output* + // schema, and reflects any pushed down projection. The output schema + // will not contain columns that *only* appear in pushed down predicates + // (and no where else) in the plan. + // + // Thus, use the full schema of the inner provider without any + // projection applied for simplification + Arc::new(DFSchema::try_from_qualified_schema( + &scan.table_name, + &scan.source.schema(), + )?) } else { Arc::new(DFSchema::empty()) }; @@ -111,7 +122,7 @@ mod tests { use crate::simplify_expressions::utils::for_test::{ cast_to_int64_expr, now_expr, to_timestamp_expr, }; - use crate::test::test_table_scan_with_name; + use crate::test::{assert_fields_eq, test_table_scan_with_name}; use super::*; use arrow::datatypes::{DataType, Field, Schema}; @@ -174,6 +185,48 @@ mod tests { Ok(()) } + #[test] + fn test_simplify_table_full_filter_in_scan() -> Result<()> { + let fields = vec![ + Field::new("a", DataType::UInt32, false), + Field::new("b", DataType::UInt32, false), + Field::new("c", DataType::UInt32, false), + ]; + + let schema = Schema::new(fields); + + let table_scan = table_scan_with_filters( + Some("test"), + &schema, + Some(vec![0]), + vec![col("b").is_not_null()], + )? + .build()?; + assert_eq!(1, table_scan.schema().fields().len()); + assert_fields_eq(&table_scan, vec!["a"]); + + let expected = "TableScan: test projection=[a], full_filters=[Boolean(true) AS b IS NOT NULL]"; + + assert_optimized_plan_eq(&table_scan, expected) + } + + #[test] + fn test_simplify_filter_pushdown() -> Result<()> { + let table_scan = test_table_scan(); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a")])? + .filter(and(col("b").gt(lit(1)), col("b").gt(lit(1))))? + .build()?; + + assert_optimized_plan_eq( + &plan, + "\ + Filter: test.b > Int32(1)\ + \n Projection: test.a\ + \n TableScan: test", + ) + } + #[test] fn test_simplify_optimized_plan() -> Result<()> { let table_scan = test_table_scan(); diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs index 28c61427c5ef..fa91a3ace2a2 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -23,7 +23,7 @@ use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ expr::{Between, BinaryExpr, InList}, expr_fn::{and, bitwise_and, bitwise_or, concat_ws, or}, - lit, BuiltinScalarFunction, Expr, Like, Operator, + lit, BuiltinScalarFunction, Expr, Like, Operator, ScalarFunctionDefinition, }; pub static POWS_OF_TEN: [i128; 38] = [ @@ -365,7 +365,7 @@ pub fn simpl_log(current_args: Vec, info: &dyn SimplifyInfo) -> Result Ok(args[1].clone()), _ => { @@ -405,7 +405,7 @@ pub fn simpl_power(current_args: Vec, info: &dyn SimplifyInfo) -> Result Ok(args[1].clone()), _ => Ok(Expr::ScalarFunction(ScalarFunction::new( @@ -525,8 +525,8 @@ pub fn simpl_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result { d => Ok(concat_ws( d.clone(), args.iter() + .filter(|&x| !is_null(x)) .cloned() - .filter(|x| !is_null(x)) .collect::>(), )), } diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index a9e65b3e7c77..7e6fb6b355ab 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -17,32 +17,39 @@ //! single distinct to group by optimizer rule +use std::sync::Arc; + use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; + use datafusion_common::{DFSchema, Result}; +use datafusion_expr::expr::AggregateFunctionDefinition; use datafusion_expr::{ + aggregate_function::AggregateFunction::{Max, Min, Sum}, col, expr::AggregateFunction, logical_plan::{Aggregate, LogicalPlan, Projection}, utils::columnize_expr, Expr, ExprSchemable, }; + use hashbrown::HashSet; -use std::sync::Arc; /// single distinct to group by optimizer rule /// ```text -/// SELECT F1(DISTINCT s),F2(DISTINCT s) -/// ... -/// GROUP BY k -/// -/// Into +/// Before: +/// SELECT a, COUNT(DINSTINCT b), SUM(c) +/// FROM t +/// GROUP BY a /// -/// SELECT F1(alias1),F2(alias1) +/// After: +/// SELECT a, COUNT(alias1), SUM(alias2) /// FROM ( -/// SELECT s as alias1, k ... GROUP BY s, k +/// SELECT a, b as alias1, SUM(c) as alias2 +/// FROM t +/// GROUP BY a, b /// ) -/// GROUP BY k +/// GROUP BY a /// ``` #[derive(Default)] pub struct SingleDistinctToGroupBy {} @@ -61,22 +68,30 @@ fn is_single_distinct_agg(plan: &LogicalPlan) -> Result { match plan { LogicalPlan::Aggregate(Aggregate { aggr_expr, .. }) => { let mut fields_set = HashSet::new(); - let mut distinct_count = 0; + let mut aggregate_count = 0; for expr in aggr_expr { if let Expr::AggregateFunction(AggregateFunction { - distinct, args, .. + func_def: AggregateFunctionDefinition::BuiltIn(fun), + distinct, + args, + filter, + order_by, }) = expr { - if *distinct { - distinct_count += 1; + if filter.is_some() || order_by.is_some() { + return Ok(false); } - for e in args { - fields_set.insert(e.display_name()?); + aggregate_count += 1; + if *distinct { + for e in args { + fields_set.insert(e.canonical_name()); + } + } else if !matches!(fun, Sum | Min | Max) { + return Ok(false); } } } - let res = distinct_count == aggr_expr.len() && fields_set.len() == 1; - Ok(res) + Ok(aggregate_count == aggr_expr.len() && fields_set.len() == 1) } _ => Ok(false), } @@ -102,51 +117,104 @@ impl OptimizerRule for SingleDistinctToGroupBy { .. }) => { if is_single_distinct_agg(plan)? && !contains_grouping_set(group_expr) { + let fields = schema.fields(); // alias all original group_by exprs - let mut group_expr_alias = Vec::with_capacity(group_expr.len()); - let mut inner_group_exprs = group_expr + let (mut inner_group_exprs, out_group_expr_with_alias): ( + Vec, + Vec<(Expr, Option)>, + ) = group_expr .iter() .enumerate() .map(|(i, group_expr)| { - let alias_str = format!("group_alias_{i}"); - let alias_expr = group_expr.clone().alias(&alias_str); - group_expr_alias - .push((alias_str, schema.fields()[i].clone())); - alias_expr + if let Expr::Column(_) = group_expr { + // For Column expressions we can use existing expression as is. + (group_expr.clone(), (group_expr.clone(), None)) + } else { + // For complex expression write is as alias, to be able to refer + // if from parent operators successfully. + // Consider plan below. + // + // Aggregate: groupBy=[[group_alias_0]], aggr=[[COUNT(alias1)]] [group_alias_0:Int32, COUNT(alias1):Int64;N]\ + // --Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\ + // ----TableScan: test [a:UInt32, b:UInt32, c:UInt32] + // + // First aggregate(from bottom) refers to `test.a` column. + // Second aggregate refers to the `group_alias_0` column, Which is a valid field in the first aggregate. + // If we were to write plan above as below without alias + // + // Aggregate: groupBy=[[test.a + Int32(1)]], aggr=[[COUNT(alias1)]] [group_alias_0:Int32, COUNT(alias1):Int64;N]\ + // --Aggregate: groupBy=[[test.a + Int32(1), test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\ + // ----TableScan: test [a:UInt32, b:UInt32, c:UInt32] + // + // Second aggregate refers to the `test.a + Int32(1)` expression However, its input do not have `test.a` expression in it. + let alias_str = format!("group_alias_{i}"); + let alias_expr = group_expr.clone().alias(&alias_str); + ( + alias_expr, + (col(alias_str), Some(fields[i].qualified_name())), + ) + } }) - .collect::>(); + .unzip(); // and they can be referenced by the alias in the outer aggr plan - let outer_group_exprs = group_expr_alias + let outer_group_exprs = out_group_expr_with_alias .iter() - .map(|(alias, _)| col(alias)) + .map(|(out_group_expr, _)| out_group_expr.clone()) .collect::>(); // replace the distinct arg with alias + let mut index = 1; let mut group_fields_set = HashSet::new(); - let new_aggr_exprs = aggr_expr + let mut inner_aggr_exprs = vec![]; + let outer_aggr_exprs = aggr_expr .iter() .map(|aggr_expr| match aggr_expr { Expr::AggregateFunction(AggregateFunction { - fun, + func_def: AggregateFunctionDefinition::BuiltIn(fun), args, - filter, - order_by, + distinct, .. }) => { // is_single_distinct_agg ensure args.len=1 - if group_fields_set.insert(args[0].display_name()?) { + if *distinct + && group_fields_set.insert(args[0].display_name()?) + { inner_group_exprs.push( args[0].clone().alias(SINGLE_DISTINCT_ALIAS), ); } - Ok(Expr::AggregateFunction(AggregateFunction::new( - fun.clone(), - vec![col(SINGLE_DISTINCT_ALIAS)], - false, // intentional to remove distinct here - filter.clone(), - order_by.clone(), - ))) + + // if the aggregate function is not distinct, we need to rewrite it like two phase aggregation + if !(*distinct) { + index += 1; + let alias_str = format!("alias{}", index); + inner_aggr_exprs.push( + Expr::AggregateFunction(AggregateFunction::new( + fun.clone(), + args.clone(), + false, + None, + None, + )) + .alias(&alias_str), + ); + Ok(Expr::AggregateFunction(AggregateFunction::new( + fun.clone(), + vec![col(&alias_str)], + false, + None, + None, + ))) + } else { + Ok(Expr::AggregateFunction(AggregateFunction::new( + fun.clone(), + vec![col(SINGLE_DISTINCT_ALIAS)], + false, // intentional to remove distinct here + None, + None, + ))) + } } _ => Ok(aggr_expr.clone()), }) @@ -155,6 +223,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { // construct the inner AggrPlan let inner_fields = inner_group_exprs .iter() + .chain(inner_aggr_exprs.iter()) .map(|expr| expr.to_field(input.schema())) .collect::>>()?; let inner_schema = DFSchema::new_with_metadata( @@ -164,12 +233,12 @@ impl OptimizerRule for SingleDistinctToGroupBy { let inner_agg = LogicalPlan::Aggregate(Aggregate::try_new( input.clone(), inner_group_exprs, - Vec::new(), + inner_aggr_exprs, )?); let outer_fields = outer_group_exprs .iter() - .chain(new_aggr_exprs.iter()) + .chain(outer_aggr_exprs.iter()) .map(|expr| expr.to_field(&inner_schema)) .collect::>>()?; let outer_aggr_schema = Arc::new(DFSchema::new_with_metadata( @@ -181,34 +250,33 @@ impl OptimizerRule for SingleDistinctToGroupBy { // this optimizer has two kinds of alias: // - group_by aggr // - aggr expr - let mut alias_expr: Vec = Vec::new(); - for (alias, original_field) in group_expr_alias { - alias_expr - .push(col(alias).alias(original_field.qualified_name())); - } - for (i, expr) in new_aggr_exprs.iter().enumerate() { - alias_expr.push(columnize_expr( - expr.clone().alias( - schema.clone().fields()[i + group_expr.len()] - .qualified_name(), - ), - &outer_aggr_schema, - )); - } + let group_size = group_expr.len(); + let alias_expr = out_group_expr_with_alias + .into_iter() + .map(|(group_expr, original_field)| { + if let Some(name) = original_field { + group_expr.alias(name) + } else { + group_expr + } + }) + .chain(outer_aggr_exprs.iter().enumerate().map(|(idx, expr)| { + let idx = idx + group_size; + let name = fields[idx].qualified_name(); + columnize_expr(expr.clone().alias(name), &outer_aggr_schema) + })) + .collect(); let outer_aggr = LogicalPlan::Aggregate(Aggregate::try_new( Arc::new(inner_agg), outer_group_exprs, - new_aggr_exprs, + outer_aggr_exprs, )?); - Ok(Some(LogicalPlan::Projection( - Projection::try_new_with_schema( - alias_expr, - Arc::new(outer_aggr), - schema.clone(), - )?, - ))) + Ok(Some(LogicalPlan::Projection(Projection::try_new( + alias_expr, + Arc::new(outer_aggr), + )?))) } else { Ok(None) } @@ -234,7 +302,7 @@ mod tests { use datafusion_expr::expr::GroupingSet; use datafusion_expr::{ col, count, count_distinct, lit, logical_plan::builder::LogicalPlanBuilder, max, - AggregateFunction, + min, sum, AggregateFunction, }; fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { @@ -294,7 +362,7 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32, b:UInt32, COUNT(DISTINCT test.c):Int64;N]\ + let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(&plan, expected) @@ -312,7 +380,7 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32, b:UInt32, COUNT(DISTINCT test.c):Int64;N]\ + let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(&plan, expected) @@ -331,7 +399,7 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32, b:UInt32, COUNT(DISTINCT test.c):Int64;N]\ + let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(&plan, expected) @@ -362,9 +430,9 @@ mod tests { .build()?; // Should work - let expected = "Projection: group_alias_0 AS test.a, COUNT(alias1) AS COUNT(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N]\ - \n Aggregate: groupBy=[[group_alias_0]], aggr=[[COUNT(alias1)]] [group_alias_0:UInt32, COUNT(alias1):Int64;N]\ - \n Aggregate: groupBy=[[test.a AS group_alias_0, test.b AS alias1]], aggr=[[]] [group_alias_0:UInt32, alias1:UInt32]\ + let expected = "Projection: test.a, COUNT(alias1) AS COUNT(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[COUNT(alias1)]] [a:UInt32, COUNT(alias1):Int64;N]\ + \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(&plan, expected) @@ -408,9 +476,9 @@ mod tests { )? .build()?; // Should work - let expected = "Projection: group_alias_0 AS test.a, COUNT(alias1) AS COUNT(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\ - \n Aggregate: groupBy=[[group_alias_0]], aggr=[[COUNT(alias1), MAX(alias1)]] [group_alias_0:UInt32, COUNT(alias1):Int64;N, MAX(alias1):UInt32;N]\ - \n Aggregate: groupBy=[[test.a AS group_alias_0, test.b AS alias1]], aggr=[[]] [group_alias_0:UInt32, alias1:UInt32]\ + let expected = "Projection: test.a, COUNT(alias1) AS COUNT(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[COUNT(alias1), MAX(alias1)]] [a:UInt32, COUNT(alias1):Int64;N, MAX(alias1):UInt32;N]\ + \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(&plan, expected) @@ -450,4 +518,181 @@ mod tests { assert_optimized_plan_equal(&plan, expected) } + + #[test] + fn two_distinct_and_one_common() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("a")], + vec![ + sum(col("c")), + count_distinct(col("b")), + Expr::AggregateFunction(expr::AggregateFunction::new( + AggregateFunction::Max, + vec![col("b")], + true, + None, + None, + )), + ], + )? + .build()?; + // Should work + let expected = "Projection: test.a, SUM(alias2) AS SUM(test.c), COUNT(alias1) AS COUNT(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, SUM(test.c):UInt64;N, COUNT(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(alias2), COUNT(alias1), MAX(alias1)]] [a:UInt32, SUM(alias2):UInt64;N, COUNT(alias1):Int64;N, MAX(alias1):UInt32;N]\ + \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[SUM(test.c) AS alias2]] [a:UInt32, alias1:UInt32, alias2:UInt64;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn one_distinctand_and_two_common() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("a")], + vec![sum(col("c")), max(col("c")), count_distinct(col("b"))], + )? + .build()?; + // Should work + let expected = "Projection: test.a, SUM(alias2) AS SUM(test.c), MAX(alias3) AS MAX(test.c), COUNT(alias1) AS COUNT(DISTINCT test.b) [a:UInt32, SUM(test.c):UInt64;N, MAX(test.c):UInt32;N, COUNT(DISTINCT test.b):Int64;N]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(alias2), MAX(alias3), COUNT(alias1)]] [a:UInt32, SUM(alias2):UInt64;N, MAX(alias3):UInt32;N, COUNT(alias1):Int64;N]\ + \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[SUM(test.c) AS alias2, MAX(test.c) AS alias3]] [a:UInt32, alias1:UInt32, alias2:UInt64;N, alias3:UInt32;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn one_distinct_and_one_common() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("c")], + vec![min(col("a")), count_distinct(col("b"))], + )? + .build()?; + // Should work + let expected = "Projection: test.c, MIN(alias2) AS MIN(test.a), COUNT(alias1) AS COUNT(DISTINCT test.b) [c:UInt32, MIN(test.a):UInt32;N, COUNT(DISTINCT test.b):Int64;N]\ + \n Aggregate: groupBy=[[test.c]], aggr=[[MIN(alias2), COUNT(alias1)]] [c:UInt32, MIN(alias2):UInt32;N, COUNT(alias1):Int64;N]\ + \n Aggregate: groupBy=[[test.c, test.b AS alias1]], aggr=[[MIN(test.a) AS alias2]] [c:UInt32, alias1:UInt32, alias2:UInt32;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn common_with_filter() -> Result<()> { + let table_scan = test_table_scan()?; + + // SUM(a) FILTER (WHERE a > 5) + let expr = Expr::AggregateFunction(expr::AggregateFunction::new( + AggregateFunction::Sum, + vec![col("a")], + false, + Some(Box::new(col("a").gt(lit(5)))), + None, + )); + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])? + .build()?; + // Do nothing + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a) FILTER (WHERE test.a > Int32(5)), COUNT(DISTINCT test.b)]] [c:UInt32, SUM(test.a) FILTER (WHERE test.a > Int32(5)):UInt64;N, COUNT(DISTINCT test.b):Int64;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn distinct_with_filter() -> Result<()> { + let table_scan = test_table_scan()?; + + // COUNT(DISTINCT a) FILTER (WHERE a > 5) + let expr = Expr::AggregateFunction(expr::AggregateFunction::new( + AggregateFunction::Count, + vec![col("a")], + true, + Some(Box::new(col("a").gt(lit(5)))), + None, + )); + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("c")], vec![sum(col("a")), expr])? + .build()?; + // Do nothing + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5))]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)):Int64;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn common_with_order_by() -> Result<()> { + let table_scan = test_table_scan()?; + + // SUM(a ORDER BY a) + let expr = Expr::AggregateFunction(expr::AggregateFunction::new( + AggregateFunction::Sum, + vec![col("a")], + false, + None, + Some(vec![col("a")]), + )); + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])? + .build()?; + // Do nothing + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a) ORDER BY [test.a], COUNT(DISTINCT test.b)]] [c:UInt32, SUM(test.a) ORDER BY [test.a]:UInt64;N, COUNT(DISTINCT test.b):Int64;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn distinct_with_order_by() -> Result<()> { + let table_scan = test_table_scan()?; + + // COUNT(DISTINCT a ORDER BY a) + let expr = Expr::AggregateFunction(expr::AggregateFunction::new( + AggregateFunction::Count, + vec![col("a")], + true, + None, + Some(vec![col("a")]), + )); + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("c")], vec![sum(col("a")), expr])? + .build()?; + // Do nothing + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) ORDER BY [test.a]]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) ORDER BY [test.a]:Int64;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn aggregate_with_filter_and_order_by() -> Result<()> { + let table_scan = test_table_scan()?; + + // COUNT(DISTINCT a ORDER BY a) FILTER (WHERE a > 5) + let expr = Expr::AggregateFunction(expr::AggregateFunction::new( + AggregateFunction::Count, + vec![col("a")], + true, + Some(Box::new(col("a").gt(lit(5)))), + Some(vec![col("a")]), + )); + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("c")], vec![sum(col("a")), expr])? + .build()?; + // Do nothing + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]:Int64;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_equal(&plan, expected) + } } diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index 7d334a80b682..e691fe9a5351 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -16,7 +16,7 @@ // under the License. use crate::analyzer::{Analyzer, AnalyzerRule}; -use crate::optimizer::Optimizer; +use crate::optimizer::{assert_schema_is_the_same, Optimizer}; use crate::{OptimizerContext, OptimizerRule}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::config::ConfigOptions; @@ -155,20 +155,42 @@ pub fn assert_optimized_plan_eq( plan: &LogicalPlan, expected: &str, ) -> Result<()> { - let optimizer = Optimizer::with_rules(vec![rule]); + let optimizer = Optimizer::with_rules(vec![rule.clone()]); let optimized_plan = optimizer .optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), )? .unwrap_or_else(|| plan.clone()); + + // Ensure schemas always match after an optimization + assert_schema_is_the_same(rule.name(), plan, &optimized_plan)?; let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); Ok(()) } +pub fn assert_optimized_plan_eq_with_rules( + rules: Vec>, + plan: &LogicalPlan, + expected: &str, +) -> Result<()> { + fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} + let config = &mut OptimizerContext::new() + .with_max_passes(1) + .with_skip_failing_rules(false); + let optimizer = Optimizer::with_rules(rules); + let optimized_plan = optimizer + .optimize(plan, config, observe) + .expect("failed to optimize plan"); + let formatted_plan = format!("{optimized_plan:?}"); + assert_eq!(formatted_plan, expected); + assert_eq!(plan.schema(), optimized_plan.schema()); + Ok(()) +} + pub fn assert_optimized_plan_eq_display_indent( rule: Arc, plan: &LogicalPlan, @@ -177,7 +199,7 @@ pub fn assert_optimized_plan_eq_display_indent( let optimizer = Optimizer::with_rules(vec![rule]); let optimized_plan = optimizer .optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), ) @@ -211,7 +233,7 @@ pub fn assert_optimizer_err( ) { let optimizer = Optimizer::with_rules(vec![rule]); let res = optimizer.optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), ); @@ -233,7 +255,7 @@ pub fn assert_optimization_skipped( let optimizer = Optimizer::with_rules(vec![rule]); let new_plan = optimizer .optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), )? diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 468981a5fb0c..91603e82a54f 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -19,7 +19,6 @@ //! of expr can be added if needed. //! This rule can reduce adding the `Expr::Cast` the expr instead of adding the `Expr::Cast` to literal expr. use crate::optimizer::ApplyOrder; -use crate::utils::merge_schema; use crate::{OptimizerConfig, OptimizerRule}; use arrow::datatypes::{ DataType, TimeUnit, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, @@ -31,6 +30,7 @@ use datafusion_common::{ }; use datafusion_expr::expr::{BinaryExpr, Cast, InList, TryCast}; use datafusion_expr::expr_rewriter::rewrite_preserving_name; +use datafusion_expr::utils::merge_schema; use datafusion_expr::{ binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator, }; @@ -1089,8 +1089,12 @@ mod tests { // Verify that calling the arrow // cast kernel yields the same results // input array - let literal_array = literal.to_array_of_size(1); - let expected_array = expected_value.to_array_of_size(1); + let literal_array = literal + .to_array_of_size(1) + .expect("Failed to convert to array of size"); + let expected_array = expected_value + .to_array_of_size(1) + .expect("Failed to convert to array of size"); let cast_array = cast_with_options( &literal_array, &target_type, diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index a3e7e42875d7..44f2404afade 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -18,19 +18,13 @@ //! Collection of utility functions that are leveraged by the query optimizer rules use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::DataFusionError; -use datafusion_common::{plan_err, Column, DFSchemaRef}; +use datafusion_common::{Column, DFSchemaRef}; use datafusion_common::{DFSchema, Result}; -use datafusion_expr::expr::{Alias, BinaryExpr}; -use datafusion_expr::expr_rewriter::{replace_col, strip_outer_reference}; -use datafusion_expr::{ - and, - logical_plan::{Filter, LogicalPlan}, - Expr, Operator, -}; +use datafusion_expr::expr_rewriter::replace_col; +use datafusion_expr::utils as expr_utils; +use datafusion_expr::{logical_plan::LogicalPlan, Expr, Operator}; use log::{debug, trace}; use std::collections::{BTreeSet, HashMap}; -use std::sync::Arc; /// Convenience rule for writing optimizers: recursively invoke /// optimize on plan's children and then return a node of the same @@ -52,35 +46,61 @@ pub fn optimize_children( new_inputs.push(new_input.unwrap_or_else(|| input.clone())) } if plan_is_changed { - Ok(Some(plan.with_new_inputs(&new_inputs)?)) + Ok(Some(plan.with_new_exprs(plan.expressions(), &new_inputs)?)) } else { Ok(None) } } +pub(crate) fn collect_subquery_cols( + exprs: &[Expr], + subquery_schema: DFSchemaRef, +) -> Result> { + exprs.iter().try_fold(BTreeSet::new(), |mut cols, expr| { + let mut using_cols: Vec = vec![]; + for col in expr.to_columns()?.into_iter() { + if subquery_schema.has_column(&col) { + using_cols.push(col); + } + } + + cols.extend(using_cols); + Result::<_>::Ok(cols) + }) +} + +pub(crate) fn replace_qualified_name( + expr: Expr, + cols: &BTreeSet, + subquery_alias: &str, +) -> Result { + let alias_cols: Vec = cols + .iter() + .map(|col| { + Column::from_qualified_name(format!("{}.{}", subquery_alias, col.name)) + }) + .collect(); + let replace_map: HashMap<&Column, &Column> = + cols.iter().zip(alias_cols.iter()).collect(); + + replace_col(expr, &replace_map) +} + +/// Log the plan in debug/tracing mode after some part of the optimizer runs +pub fn log_plan(description: &str, plan: &LogicalPlan) { + debug!("{description}:\n{}\n", plan.display_indent()); + trace!("{description}::\n{}\n", plan.display_indent_schema()); +} + /// Splits a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` /// /// See [`split_conjunction_owned`] for more details and an example. +#[deprecated( + since = "34.0.0", + note = "use `datafusion_expr::utils::split_conjunction` instead" +)] pub fn split_conjunction(expr: &Expr) -> Vec<&Expr> { - split_conjunction_impl(expr, vec![]) -} - -fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<&'a Expr> { - match expr { - Expr::BinaryExpr(BinaryExpr { - right, - op: Operator::And, - left, - }) => { - let exprs = split_conjunction_impl(left, exprs); - split_conjunction_impl(right, exprs) - } - Expr::Alias(Alias { expr, .. }) => split_conjunction_impl(expr, exprs), - other => { - exprs.push(other); - exprs - } - } + expr_utils::split_conjunction(expr) } /// Splits an owned conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` @@ -104,8 +124,12 @@ fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<& /// // use split_conjunction_owned to split them /// assert_eq!(split_conjunction_owned(expr), split); /// ``` +#[deprecated( + since = "34.0.0", + note = "use `datafusion_expr::utils::split_conjunction_owned` instead" +)] pub fn split_conjunction_owned(expr: Expr) -> Vec { - split_binary_owned(expr, Operator::And) + expr_utils::split_conjunction_owned(expr) } /// Splits an owned binary operator tree [`Expr`] such as `A B C` => `[A, B, C]` @@ -130,53 +154,23 @@ pub fn split_conjunction_owned(expr: Expr) -> Vec { /// // use split_binary_owned to split them /// assert_eq!(split_binary_owned(expr, Operator::Plus), split); /// ``` +#[deprecated( + since = "34.0.0", + note = "use `datafusion_expr::utils::split_binary_owned` instead" +)] pub fn split_binary_owned(expr: Expr, op: Operator) -> Vec { - split_binary_owned_impl(expr, op, vec![]) -} - -fn split_binary_owned_impl( - expr: Expr, - operator: Operator, - mut exprs: Vec, -) -> Vec { - match expr { - Expr::BinaryExpr(BinaryExpr { right, op, left }) if op == operator => { - let exprs = split_binary_owned_impl(*left, operator, exprs); - split_binary_owned_impl(*right, operator, exprs) - } - Expr::Alias(Alias { expr, .. }) => { - split_binary_owned_impl(*expr, operator, exprs) - } - other => { - exprs.push(other); - exprs - } - } + expr_utils::split_binary_owned(expr, op) } /// Splits an binary operator tree [`Expr`] such as `A B C` => `[A, B, C]` /// /// See [`split_binary_owned`] for more details and an example. +#[deprecated( + since = "34.0.0", + note = "use `datafusion_expr::utils::split_binary` instead" +)] pub fn split_binary(expr: &Expr, op: Operator) -> Vec<&Expr> { - split_binary_impl(expr, op, vec![]) -} - -fn split_binary_impl<'a>( - expr: &'a Expr, - operator: Operator, - mut exprs: Vec<&'a Expr>, -) -> Vec<&'a Expr> { - match expr { - Expr::BinaryExpr(BinaryExpr { right, op, left }) if *op == operator => { - let exprs = split_binary_impl(left, operator, exprs); - split_binary_impl(right, operator, exprs) - } - Expr::Alias(Alias { expr, .. }) => split_binary_impl(expr, operator, exprs), - other => { - exprs.push(other); - exprs - } - } + expr_utils::split_binary(expr, op) } /// Combines an array of filter expressions into a single filter @@ -201,8 +195,12 @@ fn split_binary_impl<'a>( /// // use conjunction to join them together with `AND` /// assert_eq!(conjunction(split), Some(expr)); /// ``` +#[deprecated( + since = "34.0.0", + note = "use `datafusion_expr::utils::conjunction` instead" +)] pub fn conjunction(filters: impl IntoIterator) -> Option { - filters.into_iter().reduce(|accum, expr| accum.and(expr)) + expr_utils::conjunction(filters) } /// Combines an array of filter expressions into a single filter @@ -210,25 +208,22 @@ pub fn conjunction(filters: impl IntoIterator) -> Option { /// logical OR. /// /// Returns None if the filters array is empty. +#[deprecated( + since = "34.0.0", + note = "use `datafusion_expr::utils::disjunction` instead" +)] pub fn disjunction(filters: impl IntoIterator) -> Option { - filters.into_iter().reduce(|accum, expr| accum.or(expr)) + expr_utils::disjunction(filters) } /// returns a new [LogicalPlan] that wraps `plan` in a [LogicalPlan::Filter] with /// its predicate be all `predicates` ANDed. +#[deprecated( + since = "34.0.0", + note = "use `datafusion_expr::utils::add_filter` instead" +)] pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result { - // reduce filters to a single filter with an AND - let predicate = predicates - .iter() - .skip(1) - .fold(predicates[0].clone(), |acc, predicate| { - and(acc, (*predicate).to_owned()) - }); - - Ok(LogicalPlan::Filter(Filter::try_new( - predicate, - Arc::new(plan), - )?)) + expr_utils::add_filter(plan, predicates) } /// Looks for correlating expressions: for example, a binary expression with one field from the subquery, and @@ -241,22 +236,12 @@ pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result) -> Result<(Vec, Vec)> { - let mut joins = vec![]; - let mut others = vec![]; - for filter in exprs.into_iter() { - // If the expression contains correlated predicates, add it to join filters - if filter.contains_outer() { - if !matches!(filter, Expr::BinaryExpr(BinaryExpr{ left, op: Operator::Eq, right }) if left.eq(right)) - { - joins.push(strip_outer_reference((*filter).clone())); - } - } else { - others.push((*filter).clone()); - } - } - - Ok((joins, others)) + expr_utils::find_join_exprs(exprs) } /// Returns the first (and only) element in a slice, or an error @@ -268,215 +253,19 @@ pub fn find_join_exprs(exprs: Vec<&Expr>) -> Result<(Vec, Vec)> { /// # Return value /// /// The first element, or an error +#[deprecated( + since = "34.0.0", + note = "use `datafusion_expr::utils::only_or_err` instead" +)] pub fn only_or_err(slice: &[T]) -> Result<&T> { - match slice { - [it] => Ok(it), - [] => plan_err!("No items found!"), - _ => plan_err!("More than one item found!"), - } + expr_utils::only_or_err(slice) } /// merge inputs schema into a single schema. +#[deprecated( + since = "34.0.0", + note = "use `datafusion_expr::utils::merge_schema` instead" +)] pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema { - if inputs.len() == 1 { - inputs[0].schema().clone().as_ref().clone() - } else { - inputs.iter().map(|input| input.schema()).fold( - DFSchema::empty(), - |mut lhs, rhs| { - lhs.merge(rhs); - lhs - }, - ) - } -} - -pub(crate) fn collect_subquery_cols( - exprs: &[Expr], - subquery_schema: DFSchemaRef, -) -> Result> { - exprs.iter().try_fold(BTreeSet::new(), |mut cols, expr| { - let mut using_cols: Vec = vec![]; - for col in expr.to_columns()?.into_iter() { - if subquery_schema.has_column(&col) { - using_cols.push(col); - } - } - - cols.extend(using_cols); - Result::<_>::Ok(cols) - }) -} - -pub(crate) fn replace_qualified_name( - expr: Expr, - cols: &BTreeSet, - subquery_alias: &str, -) -> Result { - let alias_cols: Vec = cols - .iter() - .map(|col| { - Column::from_qualified_name(format!("{}.{}", subquery_alias, col.name)) - }) - .collect(); - let replace_map: HashMap<&Column, &Column> = - cols.iter().zip(alias_cols.iter()).collect(); - - replace_col(expr, &replace_map) -} - -/// Log the plan in debug/tracing mode after some part of the optimizer runs -pub fn log_plan(description: &str, plan: &LogicalPlan) { - debug!("{description}:\n{}\n", plan.display_indent()); - trace!("{description}::\n{}\n", plan.display_indent_schema()); -} - -#[cfg(test)] -mod tests { - use super::*; - use arrow::datatypes::DataType; - use datafusion_common::Column; - use datafusion_expr::expr::Cast; - use datafusion_expr::{col, lit, utils::expr_to_columns}; - use std::collections::HashSet; - - #[test] - fn test_split_conjunction() { - let expr = col("a"); - let result = split_conjunction(&expr); - assert_eq!(result, vec![&expr]); - } - - #[test] - fn test_split_conjunction_two() { - let expr = col("a").eq(lit(5)).and(col("b")); - let expr1 = col("a").eq(lit(5)); - let expr2 = col("b"); - - let result = split_conjunction(&expr); - assert_eq!(result, vec![&expr1, &expr2]); - } - - #[test] - fn test_split_conjunction_alias() { - let expr = col("a").eq(lit(5)).and(col("b").alias("the_alias")); - let expr1 = col("a").eq(lit(5)); - let expr2 = col("b"); // has no alias - - let result = split_conjunction(&expr); - assert_eq!(result, vec![&expr1, &expr2]); - } - - #[test] - fn test_split_conjunction_or() { - let expr = col("a").eq(lit(5)).or(col("b")); - let result = split_conjunction(&expr); - assert_eq!(result, vec![&expr]); - } - - #[test] - fn test_split_binary_owned() { - let expr = col("a"); - assert_eq!(split_binary_owned(expr.clone(), Operator::And), vec![expr]); - } - - #[test] - fn test_split_binary_owned_two() { - assert_eq!( - split_binary_owned(col("a").eq(lit(5)).and(col("b")), Operator::And), - vec![col("a").eq(lit(5)), col("b")] - ); - } - - #[test] - fn test_split_binary_owned_different_op() { - let expr = col("a").eq(lit(5)).or(col("b")); - assert_eq!( - // expr is connected by OR, but pass in AND - split_binary_owned(expr.clone(), Operator::And), - vec![expr] - ); - } - - #[test] - fn test_split_conjunction_owned() { - let expr = col("a"); - assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]); - } - - #[test] - fn test_split_conjunction_owned_two() { - assert_eq!( - split_conjunction_owned(col("a").eq(lit(5)).and(col("b"))), - vec![col("a").eq(lit(5)), col("b")] - ); - } - - #[test] - fn test_split_conjunction_owned_alias() { - assert_eq!( - split_conjunction_owned(col("a").eq(lit(5)).and(col("b").alias("the_alias"))), - vec![ - col("a").eq(lit(5)), - // no alias on b - col("b"), - ] - ); - } - - #[test] - fn test_conjunction_empty() { - assert_eq!(conjunction(vec![]), None); - } - - #[test] - fn test_conjunction() { - // `[A, B, C]` - let expr = conjunction(vec![col("a"), col("b"), col("c")]); - - // --> `(A AND B) AND C` - assert_eq!(expr, Some(col("a").and(col("b")).and(col("c")))); - - // which is different than `A AND (B AND C)` - assert_ne!(expr, Some(col("a").and(col("b").and(col("c"))))); - } - - #[test] - fn test_disjunction_empty() { - assert_eq!(disjunction(vec![]), None); - } - - #[test] - fn test_disjunction() { - // `[A, B, C]` - let expr = disjunction(vec![col("a"), col("b"), col("c")]); - - // --> `(A OR B) OR C` - assert_eq!(expr, Some(col("a").or(col("b")).or(col("c")))); - - // which is different than `A OR (B OR C)` - assert_ne!(expr, Some(col("a").or(col("b").or(col("c"))))); - } - - #[test] - fn test_split_conjunction_owned_or() { - let expr = col("a").eq(lit(5)).or(col("b")); - assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]); - } - - #[test] - fn test_collect_expr() -> Result<()> { - let mut accum: HashSet = HashSet::new(); - expr_to_columns( - &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)), - &mut accum, - )?; - expr_to_columns( - &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)), - &mut accum, - )?; - assert_eq!(1, accum.len()); - assert!(accum.contains(&Column::from_name("a"))); - Ok(()) - } + expr_utils::merge_schema(inputs) } diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index 46023cfc30bc..d857c6154ea9 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -15,8 +15,11 @@ // specific language governing permissions and limitations // under the License. +use std::any::Any; +use std::collections::HashMap; +use std::sync::Arc; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; -use chrono::{DateTime, NaiveDateTime, Utc}; use datafusion_common::config::ConfigOptions; use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource, WindowUDF}; @@ -28,9 +31,8 @@ use datafusion_sql::sqlparser::ast::Statement; use datafusion_sql::sqlparser::dialect::GenericDialect; use datafusion_sql::sqlparser::parser::Parser; use datafusion_sql::TableReference; -use std::any::Any; -use std::collections::HashMap; -use std::sync::Arc; + +use chrono::{DateTime, NaiveDateTime, Utc}; #[cfg(test)] #[ctor::ctor] @@ -185,8 +187,9 @@ fn between_date32_plus_interval() -> Result<()> { let plan = test_sql(sql)?; let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]]\ - \n Filter: test.col_date32 >= Date32(\"10303\") AND test.col_date32 <= Date32(\"10393\")\ - \n TableScan: test projection=[col_date32]"; + \n Projection: \ + \n Filter: test.col_date32 >= Date32(\"10303\") AND test.col_date32 <= Date32(\"10393\")\ + \n TableScan: test projection=[col_date32]"; assert_eq!(expected, format!("{plan:?}")); Ok(()) } @@ -198,8 +201,9 @@ fn between_date64_plus_interval() -> Result<()> { let plan = test_sql(sql)?; let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]]\ - \n Filter: test.col_date64 >= Date64(\"890179200000\") AND test.col_date64 <= Date64(\"897955200000\")\ - \n TableScan: test projection=[col_date64]"; + \n Projection: \ + \n Filter: test.col_date64 >= Date64(\"890179200000\") AND test.col_date64 <= Date64(\"897955200000\")\ + \n TableScan: test projection=[col_date64]"; assert_eq!(expected, format!("{plan:?}")); Ok(()) } @@ -322,11 +326,10 @@ fn push_down_filter_groupby_expr_contains_alias() { fn test_same_name_but_not_ambiguous() { let sql = "SELECT t1.col_int32 AS col_int32 FROM test t1 intersect SELECT col_int32 FROM test t2"; let plan = test_sql(sql).unwrap(); - let expected = "LeftSemi Join: col_int32 = t2.col_int32\ - \n Aggregate: groupBy=[[col_int32]], aggr=[[]]\ - \n Projection: t1.col_int32 AS col_int32\ - \n SubqueryAlias: t1\ - \n TableScan: test projection=[col_int32]\ + let expected = "LeftSemi Join: t1.col_int32 = t2.col_int32\ + \n Aggregate: groupBy=[[t1.col_int32]], aggr=[[]]\ + \n SubqueryAlias: t1\ + \n TableScan: test projection=[col_int32]\ \n SubqueryAlias: t2\ \n TableScan: test projection=[col_int32]"; assert_eq!(expected, format!("{plan:?}")); @@ -339,8 +342,8 @@ fn test_sql(sql: &str) -> Result { let statement = &ast[0]; // create a logical query plan - let schema_provider = MySchemaProvider::default(); - let sql_to_rel = SqlToRel::new(&schema_provider); + let context_provider = MyContextProvider::default(); + let sql_to_rel = SqlToRel::new(&context_provider); let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap(); // hard code the return value of now() @@ -357,12 +360,12 @@ fn test_sql(sql: &str) -> Result { } #[derive(Default)] -struct MySchemaProvider { +struct MyContextProvider { options: ConfigOptions, } -impl ContextProvider for MySchemaProvider { - fn get_table_provider(&self, name: TableReference) -> Result> { +impl ContextProvider for MyContextProvider { + fn get_table_source(&self, name: TableReference) -> Result> { let table_name = name.table(); if table_name.starts_with("test") { let schema = Schema::new_with_metadata( diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index 2c0ddc692d28..d237c68657a1 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -19,9 +19,9 @@ name = "datafusion-physical-expr" description = "Physical expression implementation for DataFusion query engine" keywords = ["arrow", "query", "sql"] +readme = "README.md" version = { workspace = true } edition = { workspace = true } -readme = { workspace = true } homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } @@ -34,34 +34,37 @@ path = "src/lib.rs" [features] crypto_expressions = ["md-5", "sha2", "blake2", "blake3"] -default = ["crypto_expressions", "regex_expressions", "unicode_expressions", "encoding_expressions"] +default = ["crypto_expressions", "regex_expressions", "unicode_expressions", "encoding_expressions", +] encoding_expressions = ["base64", "hex"] regex_expressions = ["regex"] unicode_expressions = ["unicode-segmentation"] [dependencies] -ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] } +ahash = { version = "0.8", default-features = false, features = [ + "runtime-rng", +] } arrow = { workspace = true } arrow-array = { workspace = true } arrow-buffer = { workspace = true } +arrow-ord = { workspace = true } arrow-schema = { workspace = true } base64 = { version = "0.21", optional = true } blake2 = { version = "^0.10.2", optional = true } blake3 = { version = "1.0", optional = true } chrono = { workspace = true } -datafusion-common = { path = "../common", version = "31.0.0", default-features = false } -datafusion-expr = { path = "../expr", version = "31.0.0" } +datafusion-common = { workspace = true } +datafusion-expr = { workspace = true } half = { version = "2.1", default-features = false } hashbrown = { version = "0.14", features = ["raw"] } hex = { version = "0.4", optional = true } -indexmap = "2.0.0" -itertools = { version = "0.11", features = ["use_std"] } -libc = "0.2.140" -log = "^0.4" +indexmap = { workspace = true } +itertools = { version = "0.12", features = ["use_std"] } +log = { workspace = true } md-5 = { version = "^0.10.0", optional = true } paste = "^1.0" petgraph = "0.6.2" -rand = "0.8" +rand = { workspace = true } regex = { version = "1.8", optional = true } sha2 = { version = "^0.10.1", optional = true } unicode-segmentation = { version = "^1.7.1", optional = true } @@ -69,8 +72,8 @@ uuid = { version = "^1.2", features = ["v4"] } [dev-dependencies] criterion = "0.5" -rand = "0.8" -rstest = "0.18.0" +rand = { workspace = true } +rstest = { workspace = true } [[bench]] harness = false diff --git a/datafusion/physical-expr/README.md b/datafusion/physical-expr/README.md index a887d3eb29fe..424256c77e7e 100644 --- a/datafusion/physical-expr/README.md +++ b/datafusion/physical-expr/README.md @@ -19,7 +19,7 @@ # DataFusion Physical Expressions -[DataFusion](df) is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. This crate is a submodule of DataFusion that provides data types and utilities for physical expressions. diff --git a/datafusion/physical-expr/benches/in_list.rs b/datafusion/physical-expr/benches/in_list.rs index db017326083a..90bfc5efb61e 100644 --- a/datafusion/physical-expr/benches/in_list.rs +++ b/datafusion/physical-expr/benches/in_list.rs @@ -57,7 +57,7 @@ fn do_benches( .collect(); let in_list: Vec<_> = (0..in_list_length) - .map(|_| ScalarValue::Utf8(Some(random_string(&mut rng, string_length)))) + .map(|_| ScalarValue::from(random_string(&mut rng, string_length))) .collect(); do_bench( diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs index aa4749f64ae9..15c0fb3ace4d 100644 --- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs +++ b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs @@ -18,7 +18,7 @@ use crate::aggregate::tdigest::TryIntoF64; use crate::aggregate::tdigest::{TDigest, DEFAULT_MAX_SIZE}; use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::{format_state_name, Literal}; +use crate::expressions::format_state_name; use crate::{AggregateExpr, PhysicalExpr}; use arrow::{ array::{ @@ -27,11 +27,13 @@ use arrow::{ }, datatypes::{DataType, Field}, }; +use arrow_array::RecordBatch; +use arrow_schema::Schema; use datafusion_common::{ downcast_value, exec_err, internal_err, not_impl_err, plan_err, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::Accumulator; +use datafusion_expr::{Accumulator, ColumnarValue}; use std::{any::Any, iter, sync::Arc}; /// APPROX_PERCENTILE_CONT aggregate expression @@ -131,18 +133,22 @@ impl PartialEq for ApproxPercentileCont { } } +fn get_lit_value(expr: &Arc) -> Result { + let empty_schema = Schema::empty(); + let empty_batch = RecordBatch::new_empty(Arc::new(empty_schema)); + let result = expr.evaluate(&empty_batch)?; + match result { + ColumnarValue::Array(_) => Err(DataFusionError::Internal(format!( + "The expr {:?} can't be evaluated to scalar value", + expr + ))), + ColumnarValue::Scalar(scalar_value) => Ok(scalar_value), + } +} + fn validate_input_percentile_expr(expr: &Arc) -> Result { - // Extract the desired percentile literal - let lit = expr - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Internal( - "desired percentile argument must be float literal".to_string(), - ) - })? - .value(); - let percentile = match lit { + let lit = get_lit_value(expr)?; + let percentile = match &lit { ScalarValue::Float32(Some(q)) => *q as f64, ScalarValue::Float64(Some(q)) => *q, got => return not_impl_err!( @@ -161,17 +167,8 @@ fn validate_input_percentile_expr(expr: &Arc) -> Result { } fn validate_input_max_size_expr(expr: &Arc) -> Result { - // Extract the desired percentile literal - let lit = expr - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Internal( - "desired percentile argument must be float literal".to_string(), - ) - })? - .value(); - let max_size = match lit { + let lit = get_lit_value(expr)?; + let max_size = match &lit { ScalarValue::UInt8(Some(q)) => *q as usize, ScalarValue::UInt16(Some(q)) => *q as usize, ScalarValue::UInt32(Some(q)) => *q as usize, diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs index 0cf39888f133..91d5c867d312 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg.rs @@ -22,8 +22,11 @@ use crate::expressions::format_state_name; use crate::{AggregateExpr, PhysicalExpr}; use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Field}; +use arrow_array::Array; +use datafusion_common::cast::as_list_array; +use datafusion_common::utils::array_into_list_array; +use datafusion_common::Result; use datafusion_common::ScalarValue; -use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_expr::Accumulator; use std::any::Any; use std::sync::Arc; @@ -31,9 +34,14 @@ use std::sync::Arc; /// ARRAY_AGG aggregate expression #[derive(Debug)] pub struct ArrayAgg { + /// Column name name: String, + /// The DataType for the input expression input_data_type: DataType, + /// The input expression expr: Arc, + /// If the input expression can have NULLs + nullable: bool, } impl ArrayAgg { @@ -42,11 +50,13 @@ impl ArrayAgg { expr: Arc, name: impl Into, data_type: DataType, + nullable: bool, ) -> Self { Self { name: name.into(), - expr, input_data_type: data_type, + expr, + nullable, } } } @@ -59,8 +69,9 @@ impl AggregateExpr for ArrayAgg { fn field(&self) -> Result { Ok(Field::new_list( &self.name, + // This should be the same as return type of AggregateFunction::ArrayAgg Field::new("item", self.input_data_type.clone(), true), - false, + self.nullable, )) } @@ -74,7 +85,7 @@ impl AggregateExpr for ArrayAgg { Ok(vec![Field::new_list( format_state_name(&self.name, "array_agg"), Field::new("item", self.input_data_type.clone(), true), - false, + self.nullable, )]) } @@ -102,7 +113,7 @@ impl PartialEq for ArrayAgg { #[derive(Debug)] pub(crate) struct ArrayAggAccumulator { - values: Vec, + values: Vec, datatype: DataType, } @@ -117,34 +128,29 @@ impl ArrayAggAccumulator { } impl Accumulator for ArrayAggAccumulator { + // Append value like Int64Array(1,2,3) fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { if values.is_empty() { return Ok(()); } assert!(values.len() == 1, "array_agg can only take 1 param!"); - let arr = &values[0]; - (0..arr.len()).try_for_each(|index| { - let scalar = ScalarValue::try_from_array(arr, index)?; - self.values.push(scalar); - Ok(()) - }) + let val = values[0].clone(); + self.values.push(val); + Ok(()) } + // Append value like ListArray(Int64Array(1,2,3), Int64Array(4,5,6)) fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { if states.is_empty() { return Ok(()); } assert!(states.len() == 1, "array_agg states must be singleton!"); - let arr = &states[0]; - (0..arr.len()).try_for_each(|index| { - let scalar = ScalarValue::try_from_array(arr, index)?; - if let ScalarValue::List(Some(values), _) = scalar { - self.values.extend(values); - Ok(()) - } else { - internal_err!("array_agg state must be list!") - } - }) + + let list_arr = as_list_array(&states[0])?; + for arr in list_arr.iter().flatten() { + self.values.push(arr); + } + Ok(()) } fn state(&self) -> Result> { @@ -152,15 +158,30 @@ impl Accumulator for ArrayAggAccumulator { } fn evaluate(&self) -> Result { - Ok(ScalarValue::new_list( - Some(self.values.clone()), - self.datatype.clone(), - )) + // Transform Vec to ListArr + + let element_arrays: Vec<&dyn Array> = + self.values.iter().map(|a| a.as_ref()).collect(); + + if element_arrays.is_empty() { + let arr = ScalarValue::new_list(&[], &self.datatype); + return Ok(ScalarValue::List(arr)); + } + + let concated_array = arrow::compute::concat(&element_arrays)?; + let list_array = array_into_list_array(concated_array); + + Ok(ScalarValue::List(Arc::new(list_array))) } fn size(&self) -> usize { - std::mem::size_of_val(self) + ScalarValue::size_of_vec(&self.values) - - std::mem::size_of_val(&self.values) + std::mem::size_of_val(self) + + (std::mem::size_of::() * self.values.capacity()) + + self + .values + .iter() + .map(|arr| arr.get_array_memory_size()) + .sum::() + self.datatype.size() - std::mem::size_of_val(&self.datatype) } @@ -171,81 +192,110 @@ mod tests { use super::*; use crate::expressions::col; use crate::expressions::tests::aggregate; - use crate::generic_test_op; use arrow::array::ArrayRef; use arrow::array::Int32Array; use arrow::datatypes::*; use arrow::record_batch::RecordBatch; + use arrow_array::Array; + use arrow_array::ListArray; + use arrow_buffer::OffsetBuffer; + use datafusion_common::DataFusionError; use datafusion_common::Result; + macro_rules! test_op { + ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr) => { + test_op!($ARRAY, $DATATYPE, $OP, $EXPECTED, $EXPECTED.data_type()) + }; + ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{ + let schema = Schema::new(vec![Field::new("a", $DATATYPE, true)]); + + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?; + + let agg = Arc::new(<$OP>::new( + col("a", &schema)?, + "bla".to_string(), + $EXPECTED_DATATYPE, + true, + )); + let actual = aggregate(&batch, agg)?; + let expected = ScalarValue::from($EXPECTED); + + assert_eq!(expected, actual); + + Ok(()) as Result<(), DataFusionError> + }}; + } + #[test] fn array_agg_i32() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); - let list = ScalarValue::new_list( - Some(vec![ - ScalarValue::Int32(Some(1)), - ScalarValue::Int32(Some(2)), - ScalarValue::Int32(Some(3)), - ScalarValue::Int32(Some(4)), - ScalarValue::Int32(Some(5)), - ]), - DataType::Int32, - ); + let list = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + Some(4), + Some(5), + ])]); + let list = ScalarValue::List(Arc::new(list)); - generic_test_op!(a, DataType::Int32, ArrayAgg, list, DataType::Int32) + test_op!(a, DataType::Int32, ArrayAgg, list, DataType::Int32) } #[test] fn array_agg_nested() -> Result<()> { - let l1 = ScalarValue::new_list( - Some(vec![ - ScalarValue::new_list( - Some(vec![ - ScalarValue::from(1i32), - ScalarValue::from(2i32), - ScalarValue::from(3i32), - ]), - DataType::Int32, - ), - ScalarValue::new_list( - Some(vec![ScalarValue::from(4i32), ScalarValue::from(5i32)]), - DataType::Int32, - ), - ]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + let a1 = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + ])]); + let a2 = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(4), + Some(5), + ])]); + let l1 = ListArray::new( + Arc::new(Field::new("item", a1.data_type().to_owned(), true)), + OffsetBuffer::from_lengths([a1.len() + a2.len()]), + arrow::compute::concat(&[&a1, &a2])?, + None, ); - let l2 = ScalarValue::new_list( - Some(vec![ - ScalarValue::new_list( - Some(vec![ScalarValue::from(6i32)]), - DataType::Int32, - ), - ScalarValue::new_list( - Some(vec![ScalarValue::from(7i32), ScalarValue::from(8i32)]), - DataType::Int32, - ), - ]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + let a1 = + ListArray::from_iter_primitive::(vec![Some(vec![Some(6)])]); + let a2 = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(7), + Some(8), + ])]); + let l2 = ListArray::new( + Arc::new(Field::new("item", a1.data_type().to_owned(), true)), + OffsetBuffer::from_lengths([a1.len() + a2.len()]), + arrow::compute::concat(&[&a1, &a2])?, + None, ); - let l3 = ScalarValue::new_list( - Some(vec![ScalarValue::new_list( - Some(vec![ScalarValue::from(9i32)]), - DataType::Int32, - )]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + let a1 = + ListArray::from_iter_primitive::(vec![Some(vec![Some(9)])]); + let l3 = ListArray::new( + Arc::new(Field::new("item", a1.data_type().to_owned(), true)), + OffsetBuffer::from_lengths([a1.len()]), + arrow::compute::concat(&[&a1])?, + None, ); - let list = ScalarValue::new_list( - Some(vec![l1.clone(), l2.clone(), l3.clone()]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + let list = ListArray::new( + Arc::new(Field::new("item", l1.data_type().to_owned(), true)), + OffsetBuffer::from_lengths([l1.len() + l2.len() + l3.len()]), + arrow::compute::concat(&[&l1, &l2, &l3])?, + None, ); + let list = ScalarValue::List(Arc::new(list)); + let l1 = ScalarValue::List(Arc::new(l1)); + let l2 = ScalarValue::List(Arc::new(l2)); + let l3 = ScalarValue::List(Arc::new(l3)); let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap(); - generic_test_op!( + test_op!( array, DataType::List(Arc::new(Field::new_list( "item", diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs index 422eecd20155..1efae424cc69 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs @@ -22,13 +22,13 @@ use std::any::Any; use std::fmt::Debug; use std::sync::Arc; -use arrow::array::{Array, ArrayRef}; +use arrow::array::ArrayRef; use std::collections::HashSet; use crate::aggregate::utils::down_cast_any_ref; use crate::expressions::format_state_name; use crate::{AggregateExpr, PhysicalExpr}; -use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue}; use datafusion_expr::Accumulator; /// Expression for a ARRAY_AGG(DISTINCT) aggregation. @@ -40,6 +40,8 @@ pub struct DistinctArrayAgg { input_data_type: DataType, /// The input expression expr: Arc, + /// If the input expression can have NULLs + nullable: bool, } impl DistinctArrayAgg { @@ -48,12 +50,14 @@ impl DistinctArrayAgg { expr: Arc, name: impl Into, input_data_type: DataType, + nullable: bool, ) -> Self { let name = name.into(); Self { name, - expr, input_data_type, + expr, + nullable, } } } @@ -67,8 +71,9 @@ impl AggregateExpr for DistinctArrayAgg { fn field(&self) -> Result { Ok(Field::new_list( &self.name, + // This should be the same as return type of AggregateFunction::ArrayAgg Field::new("item", self.input_data_type.clone(), true), - false, + self.nullable, )) } @@ -82,7 +87,7 @@ impl AggregateExpr for DistinctArrayAgg { Ok(vec![Field::new_list( format_state_name(&self.name, "distinct_array_agg"), Field::new("item", self.input_data_type.clone(), true), - false, + self.nullable, )]) } @@ -125,22 +130,18 @@ impl DistinctArrayAggAccumulator { impl Accumulator for DistinctArrayAggAccumulator { fn state(&self) -> Result> { - Ok(vec![ScalarValue::new_list( - Some(self.values.clone().into_iter().collect()), - self.datatype.clone(), - )]) + Ok(vec![self.evaluate()?]) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { assert_eq!(values.len(), 1, "batch input should only include 1 column!"); let array = &values[0]; - (0..array.len()).try_for_each(|i| { - if !array.is_null(i) { - self.values.insert(ScalarValue::try_from_array(array, i)?); - } - Ok(()) - }) + let scalars = ScalarValue::convert_array_to_scalar_vec(array)?; + for scalar in scalars { + self.values.extend(scalar) + } + Ok(()) } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { @@ -154,25 +155,18 @@ impl Accumulator for DistinctArrayAggAccumulator { "array_agg_distinct states must contain single array" ); - let array = &states[0]; - (0..array.len()).try_for_each(|i| { - let scalar = ScalarValue::try_from_array(array, i)?; - if let ScalarValue::List(Some(values), _) = scalar { - self.values.extend(values); - Ok(()) - } else { - internal_err!("array_agg_distinct state must be list") - } - })?; + let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&states[0])?; + for scalars in scalar_vec { + self.values.extend(scalars) + } Ok(()) } fn evaluate(&self) -> Result { - Ok(ScalarValue::new_list( - Some(self.values.clone().into_iter().collect()), - self.datatype.clone(), - )) + let values: Vec = self.values.iter().cloned().collect(); + let arr = ScalarValue::new_list(&values, &self.datatype); + Ok(ScalarValue::List(arr)) } fn size(&self) -> usize { @@ -185,34 +179,56 @@ impl Accumulator for DistinctArrayAggAccumulator { #[cfg(test)] mod tests { + use super::*; - use crate::aggregate::utils::get_accum_scalar_values_as_arrays; use crate::expressions::col; use crate::expressions::tests::aggregate; use arrow::array::{ArrayRef, Int32Array}; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; + use arrow_array::cast::as_list_array; + use arrow_array::types::Int32Type; + use arrow_array::{Array, ListArray}; + use arrow_buffer::OffsetBuffer; + use datafusion_common::utils::array_into_list_array; + use datafusion_common::{internal_err, DataFusionError}; + + // arrow::compute::sort cann't sort ListArray directly, so we need to sort the inner primitive array and wrap it back into ListArray. + fn sort_list_inner(arr: ScalarValue) -> ScalarValue { + let arr = match arr { + ScalarValue::List(arr) => { + let list_arr = as_list_array(&arr); + list_arr.value(0) + } + _ => { + panic!("Expected ScalarValue::List, got {:?}", arr) + } + }; + + let arr = arrow::compute::sort(&arr, None).unwrap(); + let list_arr = array_into_list_array(arr); + ScalarValue::List(Arc::new(list_arr)) + } fn compare_list_contents(expected: ScalarValue, actual: ScalarValue) -> Result<()> { - match (expected, actual) { - (ScalarValue::List(Some(mut e), _), ScalarValue::List(Some(mut a), _)) => { - // workaround lack of Ord of ScalarValue - let cmp = |a: &ScalarValue, b: &ScalarValue| { - a.partial_cmp(b).expect("Can compare ScalarValues") - }; - - e.sort_by(cmp); - a.sort_by(cmp); - // Check that the inputs are the same - assert_eq!(e, a); + let actual = sort_list_inner(actual); + + match (&expected, &actual) { + (ScalarValue::List(arr1), ScalarValue::List(arr2)) => { + if arr1.eq(arr2) { + Ok(()) + } else { + internal_err!( + "Actual value {:?} not found in expected values {:?}", + actual, + expected + ) + } } _ => { - return Err(DataFusionError::Internal( - "Expected scalar lists as inputs".to_string(), - )); + internal_err!("Expected scalar lists as inputs") } } - Ok(()) } fn check_distinct_array_agg( @@ -227,6 +243,7 @@ mod tests { col("a", &schema)?, "bla".to_string(), datatype, + true, )); let actual = aggregate(&batch, agg)?; @@ -244,6 +261,7 @@ mod tests { col("a", &schema)?, "bla".to_string(), datatype, + true, )); let mut accum1 = agg.create_accumulator()?; @@ -252,8 +270,8 @@ mod tests { accum1.update_batch(&[input1])?; accum2.update_batch(&[input2])?; - let state = get_accum_scalar_values_as_arrays(accum2.as_ref())?; - accum1.merge_batch(&state)?; + let array = accum2.state()?[0].raw_data()?; + accum1.merge_batch(&[array])?; let actual = accum1.evaluate()?; @@ -263,19 +281,18 @@ mod tests { #[test] fn distinct_array_agg_i32() -> Result<()> { let col: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 4, 5, 2])); - - let out = ScalarValue::new_list( - Some(vec![ - ScalarValue::Int32(Some(1)), - ScalarValue::Int32(Some(2)), - ScalarValue::Int32(Some(7)), - ScalarValue::Int32(Some(4)), - ScalarValue::Int32(Some(5)), - ]), - DataType::Int32, - ); - - check_distinct_array_agg(col, out, DataType::Int32) + let expected = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(4), + Some(5), + Some(7), + ])]), + )); + + check_distinct_array_agg(col, expected, DataType::Int32) } #[test] @@ -283,78 +300,90 @@ mod tests { let col1: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 4, 5, 2])); let col2: ArrayRef = Arc::new(Int32Array::from(vec![1, 3, 7, 8, 4])); - let out = ScalarValue::new_list( - Some(vec![ - ScalarValue::Int32(Some(1)), - ScalarValue::Int32(Some(2)), - ScalarValue::Int32(Some(3)), - ScalarValue::Int32(Some(4)), - ScalarValue::Int32(Some(5)), - ScalarValue::Int32(Some(7)), - ScalarValue::Int32(Some(8)), - ]), - DataType::Int32, - ); - - check_merge_distinct_array_agg(col1, col2, out, DataType::Int32) + let expected = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + Some(4), + Some(5), + Some(7), + Some(8), + ])]), + )); + + check_merge_distinct_array_agg(col1, col2, expected, DataType::Int32) } #[test] fn distinct_array_agg_nested() -> Result<()> { // [[1, 2, 3], [4, 5]] - let l1 = ScalarValue::new_list( - Some(vec![ - ScalarValue::new_list( - Some(vec![ - ScalarValue::from(1i32), - ScalarValue::from(2i32), - ScalarValue::from(3i32), - ]), - DataType::Int32, - ), - ScalarValue::new_list( - Some(vec![ScalarValue::from(4i32), ScalarValue::from(5i32)]), - DataType::Int32, - ), - ]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + let a1 = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + ])]); + let a2 = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(4), + Some(5), + ])]); + let l1 = ListArray::new( + Arc::new(Field::new("item", a1.data_type().to_owned(), true)), + OffsetBuffer::from_lengths([2]), + arrow::compute::concat(&[&a1, &a2]).unwrap(), + None, ); // [[6], [7, 8]] - let l2 = ScalarValue::new_list( - Some(vec![ - ScalarValue::new_list( - Some(vec![ScalarValue::from(6i32)]), - DataType::Int32, - ), - ScalarValue::new_list( - Some(vec![ScalarValue::from(7i32), ScalarValue::from(8i32)]), - DataType::Int32, - ), - ]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + let a1 = + ListArray::from_iter_primitive::(vec![Some(vec![Some(6)])]); + let a2 = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(7), + Some(8), + ])]); + let l2 = ListArray::new( + Arc::new(Field::new("item", a1.data_type().to_owned(), true)), + OffsetBuffer::from_lengths([2]), + arrow::compute::concat(&[&a1, &a2]).unwrap(), + None, ); // [[9]] - let l3 = ScalarValue::new_list( - Some(vec![ScalarValue::new_list( - Some(vec![ScalarValue::from(9i32)]), - DataType::Int32, - )]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + let a1 = + ListArray::from_iter_primitive::(vec![Some(vec![Some(9)])]); + let l3 = ListArray::new( + Arc::new(Field::new("item", a1.data_type().to_owned(), true)), + OffsetBuffer::from_lengths([1]), + Arc::new(a1), + None, ); - let list = ScalarValue::new_list( - Some(vec![l1.clone(), l2.clone(), l3.clone()]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), - ); + let l1 = ScalarValue::List(Arc::new(l1)); + let l2 = ScalarValue::List(Arc::new(l2)); + let l3 = ScalarValue::List(Arc::new(l3)); // Duplicate l1 in the input array and check that it is deduped in the output. let array = ScalarValue::iter_to_array(vec![l1.clone(), l2, l3, l1]).unwrap(); + let expected = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + Some(4), + Some(5), + Some(6), + Some(7), + Some(8), + Some(9), + ])]), + )); + check_distinct_array_agg( array, - list, + expected, DataType::List(Arc::new(Field::new_list( "item", Field::new("item", DataType::Int32, true), @@ -366,62 +395,66 @@ mod tests { #[test] fn merge_distinct_array_agg_nested() -> Result<()> { // [[1, 2], [3, 4]] - let l1 = ScalarValue::new_list( - Some(vec![ - ScalarValue::new_list( - Some(vec![ScalarValue::from(1i32), ScalarValue::from(2i32)]), - DataType::Int32, - ), - ScalarValue::new_list( - Some(vec![ScalarValue::from(3i32), ScalarValue::from(4i32)]), - DataType::Int32, - ), - ]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + let a1 = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + ])]); + let a2 = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(3), + Some(4), + ])]); + let l1 = ListArray::new( + Arc::new(Field::new("item", a1.data_type().to_owned(), true)), + OffsetBuffer::from_lengths([2]), + arrow::compute::concat(&[&a1, &a2]).unwrap(), + None, ); - // [[5]] - let l2 = ScalarValue::new_list( - Some(vec![ScalarValue::new_list( - Some(vec![ScalarValue::from(5i32)]), - DataType::Int32, - )]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + let a1 = + ListArray::from_iter_primitive::(vec![Some(vec![Some(5)])]); + let l2 = ListArray::new( + Arc::new(Field::new("item", a1.data_type().to_owned(), true)), + OffsetBuffer::from_lengths([1]), + Arc::new(a1), + None, ); // [[6, 7], [8]] - let l3 = ScalarValue::new_list( - Some(vec![ - ScalarValue::new_list( - Some(vec![ScalarValue::from(6i32), ScalarValue::from(7i32)]), - DataType::Int32, - ), - ScalarValue::new_list( - Some(vec![ScalarValue::from(8i32)]), - DataType::Int32, - ), - ]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + let a1 = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(6), + Some(7), + ])]); + let a2 = + ListArray::from_iter_primitive::(vec![Some(vec![Some(8)])]); + let l3 = ListArray::new( + Arc::new(Field::new("item", a1.data_type().to_owned(), true)), + OffsetBuffer::from_lengths([2]), + arrow::compute::concat(&[&a1, &a2]).unwrap(), + None, ); - let expected = ScalarValue::new_list( - Some(vec![l1.clone(), l2.clone(), l3.clone()]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), - ); + let l1 = ScalarValue::List(Arc::new(l1)); + let l2 = ScalarValue::List(Arc::new(l2)); + let l3 = ScalarValue::List(Arc::new(l3)); // Duplicate l1 in the input array and check that it is deduped in the output. let input1 = ScalarValue::iter_to_array(vec![l1.clone(), l2]).unwrap(); let input2 = ScalarValue::iter_to_array(vec![l1, l3]).unwrap(); - check_merge_distinct_array_agg( - input1, - input2, - expected, - DataType::List(Arc::new(Field::new_list( - "item", - Field::new("item", DataType::Int32, true), - true, - ))), - ) + let expected = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + Some(4), + Some(5), + Some(6), + Some(7), + Some(8), + ])]), + )); + + check_merge_distinct_array_agg(input1, input2, expected, DataType::Int32) } } diff --git a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs index bf5dbfb4fda9..eb5ae8b0b0c3 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs @@ -30,10 +30,11 @@ use crate::{AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr}; use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Field}; -use arrow_array::{Array, ListArray}; +use arrow_array::cast::AsArray; +use arrow_array::Array; use arrow_schema::{Fields, SortOptions}; use datafusion_common::utils::{compare_rows, get_row_at_idx}; -use datafusion_common::{exec_err, internal_err, DataFusionError, Result, ScalarValue}; +use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::Accumulator; use itertools::izip; @@ -47,10 +48,17 @@ use itertools::izip; /// and that can merge aggregations from multiple partitions. #[derive(Debug)] pub struct OrderSensitiveArrayAgg { + /// Column name name: String, + /// The DataType for the input expression input_data_type: DataType, - order_by_data_types: Vec, + /// The input expression expr: Arc, + /// If the input expression can have NULLs + nullable: bool, + /// Ordering data types + order_by_data_types: Vec, + /// Ordering requirement ordering_req: LexOrdering, } @@ -60,13 +68,15 @@ impl OrderSensitiveArrayAgg { expr: Arc, name: impl Into, input_data_type: DataType, + nullable: bool, order_by_data_types: Vec, ordering_req: LexOrdering, ) -> Self { Self { name: name.into(), - expr, input_data_type, + expr, + nullable, order_by_data_types, ordering_req, } @@ -81,8 +91,9 @@ impl AggregateExpr for OrderSensitiveArrayAgg { fn field(&self) -> Result { Ok(Field::new_list( &self.name, + // This should be the same as return type of AggregateFunction::ArrayAgg Field::new("item", self.input_data_type.clone(), true), - false, + self.nullable, )) } @@ -98,13 +109,13 @@ impl AggregateExpr for OrderSensitiveArrayAgg { let mut fields = vec![Field::new_list( format_state_name(&self.name, "array_agg"), Field::new("item", self.input_data_type.clone(), true), - false, + self.nullable, // This should be the same as field() )]; let orderings = ordering_fields(&self.ordering_req, &self.order_by_data_types); fields.push(Field::new_list( format_state_name(&self.name, "array_agg_orderings"), Field::new("item", DataType::Struct(Fields::from(orderings)), true), - false, + self.nullable, )); Ok(fields) } @@ -181,12 +192,14 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { if values.is_empty() { return Ok(()); } + let n_row = values[0].len(); for index in 0..n_row { let row = get_row_at_idx(values, index)?; self.values.push(row[0].clone()); self.ordering_values.push(row[1..].to_vec()); } + Ok(()) } @@ -197,10 +210,11 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { // First entry in the state is the aggregation result. let array_agg_values = &states[0]; // 2nd entry stores values received for ordering requirement columns, for each aggregation value inside ARRAY_AGG list. - // For each `ScalarValue` inside ARRAY_AGG list, we will receive a `Vec` that stores + // For each `StructArray` inside ARRAY_AGG list, we will receive an `Array` that stores // values received from its ordering requirement expression. (This information is necessary for during merging). let agg_orderings = &states[1]; - if agg_orderings.as_any().is::() { + + if let Some(agg_orderings) = agg_orderings.as_list_opt::() { // Stores ARRAY_AGG results coming from each partition let mut partition_values = vec![]; // Stores ordering requirement expression results coming from each partition @@ -209,20 +223,32 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { // Existing values should be merged also. partition_values.push(self.values.clone()); partition_ordering_values.push(self.ordering_values.clone()); - for index in 0..agg_orderings.len() { - let ordering = ScalarValue::try_from_array(agg_orderings, index)?; - // Ordering requirement expression values for each entry in the ARRAY_AGG list - let other_ordering_values = - self.convert_array_agg_to_orderings(ordering)?; - // ARRAY_AGG result. (It is a `ScalarValue::List` under the hood, it stores `Vec`) - let array_agg_res = ScalarValue::try_from_array(array_agg_values, index)?; - if let ScalarValue::List(Some(other_values), _) = array_agg_res { - partition_values.push(other_values); - partition_ordering_values.push(other_ordering_values); - } else { - return internal_err!("ARRAY_AGG state must be list!"); - } + + let array_agg_res = + ScalarValue::convert_array_to_scalar_vec(array_agg_values)?; + + for v in array_agg_res.into_iter() { + partition_values.push(v); + } + + let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?; + + for partition_ordering_rows in orderings.into_iter() { + // Extract value from struct to ordering_rows for each group/partition + let ordering_value = partition_ordering_rows.into_iter().map(|ordering_row| { + if let ScalarValue::Struct(Some(ordering_columns_per_row), _) = ordering_row { + Ok(ordering_columns_per_row) + } else { + exec_err!( + "Expects to receive ScalarValue::Struct(Some(..), _) but got:{:?}", + ordering_row.data_type() + ) + } + }).collect::>>()?; + + partition_ordering_values.push(ordering_value); } + let sort_options = self .ordering_req .iter() @@ -248,10 +274,8 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { } fn evaluate(&self) -> Result { - Ok(ScalarValue::new_list( - Some(self.values.clone()), - self.datatypes[0].clone(), - )) + let arr = ScalarValue::new_list(&self.values, &self.datatypes[0]); + Ok(ScalarValue::List(arr)) } fn size(&self) -> usize { @@ -280,33 +304,11 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { } impl OrderSensitiveArrayAggAccumulator { - fn convert_array_agg_to_orderings( - &self, - in_data: ScalarValue, - ) -> Result>> { - if let ScalarValue::List(Some(list_vals), _field_ref) = in_data { - list_vals.into_iter().map(|struct_vals| { - if let ScalarValue::Struct(Some(orderings), _fields) = struct_vals { - Ok(orderings) - } else { - exec_err!( - "Expects to receive ScalarValue::Struct(Some(..), _) but got:{:?}", - struct_vals.data_type() - ) - } - }).collect::>>() - } else { - exec_err!( - "Expects to receive ScalarValue::List(Some(..), _) but got:{:?}", - in_data.data_type() - ) - } - } - fn evaluate_orderings(&self) -> Result { let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]); let struct_field = Fields::from(fields.clone()); - let orderings = self + + let orderings: Vec = self .ordering_values .iter() .map(|ordering| { @@ -314,7 +316,10 @@ impl OrderSensitiveArrayAggAccumulator { }) .collect(); let struct_type = DataType::Struct(Fields::from(fields)); - Ok(ScalarValue::new_list(Some(orderings), struct_type)) + + // Wrap in List, so we have the same data structure ListArray(StructArray..) for group by cases + let arr = ScalarValue::new_list(&orderings, &struct_type); + Ok(ScalarValue::List(arr)) } } diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs index 92c806f76f3c..91f2fb952dce 100644 --- a/datafusion/physical-expr/src/aggregate/average.rs +++ b/datafusion/physical-expr/src/aggregate/average.rs @@ -21,6 +21,7 @@ use arrow::array::{AsArray, PrimitiveBuilder}; use log::debug; use std::any::Any; +use std::fmt::Debug; use std::sync::Arc; use crate::aggregate::groups_accumulator::accumulate::NullState; @@ -33,15 +34,17 @@ use arrow::{ array::{ArrayRef, UInt64Array}, datatypes::Field, }; +use arrow_array::types::{Decimal256Type, DecimalType}; use arrow_array::{ Array, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, PrimitiveArray, }; +use arrow_buffer::{i256, ArrowNativeType}; use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::type_coercion::aggregates::avg_return_type; use datafusion_expr::Accumulator; use super::groups_accumulator::EmitTo; -use super::utils::Decimal128Averager; +use super::utils::DecimalAverager; /// AVG aggregate expression #[derive(Debug, Clone)] @@ -88,7 +91,19 @@ impl AggregateExpr for Avg { ( Decimal128(sum_precision, sum_scale), Decimal128(target_precision, target_scale), - ) => Ok(Box::new(DecimalAvgAccumulator { + ) => Ok(Box::new(DecimalAvgAccumulator:: { + sum: None, + count: 0, + sum_scale: *sum_scale, + sum_precision: *sum_precision, + target_precision: *target_precision, + target_scale: *target_scale, + })), + + ( + Decimal256(sum_precision, sum_scale), + Decimal256(target_precision, target_scale), + ) => Ok(Box::new(DecimalAvgAccumulator:: { sum: None, count: 0, sum_scale: *sum_scale, @@ -156,7 +171,7 @@ impl AggregateExpr for Avg { Decimal128(_sum_precision, sum_scale), Decimal128(target_precision, target_scale), ) => { - let decimal_averager = Decimal128Averager::try_new( + let decimal_averager = DecimalAverager::::try_new( *sum_scale, *target_precision, *target_scale, @@ -172,6 +187,27 @@ impl AggregateExpr for Avg { ))) } + ( + Decimal256(_sum_precision, sum_scale), + Decimal256(target_precision, target_scale), + ) => { + let decimal_averager = DecimalAverager::::try_new( + *sum_scale, + *target_precision, + *target_scale, + )?; + + let avg_fn = move |sum: i256, count: u64| { + decimal_averager.avg(sum, i256::from_usize(count as usize).unwrap()) + }; + + Ok(Box::new(AvgGroupsAccumulator::::new( + &self.input_data_type, + &self.result_data_type, + avg_fn, + ))) + } + _ => not_impl_err!( "AvgGroupsAccumulator for ({} --> {})", self.input_data_type, @@ -256,9 +292,8 @@ impl Accumulator for AvgAccumulator { } /// An accumulator to compute the average for decimals -#[derive(Debug)] -struct DecimalAvgAccumulator { - sum: Option, +struct DecimalAvgAccumulator { + sum: Option, count: u64, sum_scale: i8, sum_precision: u8, @@ -266,30 +301,46 @@ struct DecimalAvgAccumulator { target_scale: i8, } -impl Accumulator for DecimalAvgAccumulator { +impl Debug for DecimalAvgAccumulator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("DecimalAvgAccumulator") + .field("sum", &self.sum) + .field("count", &self.count) + .field("sum_scale", &self.sum_scale) + .field("sum_precision", &self.sum_precision) + .field("target_precision", &self.target_precision) + .field("target_scale", &self.target_scale) + .finish() + } +} + +impl Accumulator for DecimalAvgAccumulator { fn state(&self) -> Result> { Ok(vec![ ScalarValue::from(self.count), - ScalarValue::Decimal128(self.sum, self.sum_precision, self.sum_scale), + ScalarValue::new_primitive::( + self.sum, + &T::TYPE_CONSTRUCTOR(self.sum_precision, self.sum_scale), + )?, ]) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = values[0].as_primitive::(); + let values = values[0].as_primitive::(); self.count += (values.len() - values.null_count()) as u64; if let Some(x) = sum(values) { - let v = self.sum.get_or_insert(0); - *v += x; + let v = self.sum.get_or_insert(T::Native::default()); + self.sum = Some(v.add_wrapping(x)); } Ok(()) } fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = values[0].as_primitive::(); + let values = values[0].as_primitive::(); self.count -= (values.len() - values.null_count()) as u64; if let Some(x) = sum(values) { - self.sum = Some(self.sum.unwrap() - x); + self.sum = Some(self.sum.unwrap().sub_wrapping(x)); } Ok(()) } @@ -299,9 +350,9 @@ impl Accumulator for DecimalAvgAccumulator { self.count += sum(states[0].as_primitive::()).unwrap_or_default(); // sums are summed - if let Some(x) = sum(states[1].as_primitive::()) { - let v = self.sum.get_or_insert(0); - *v += x; + if let Some(x) = sum(states[1].as_primitive::()) { + let v = self.sum.get_or_insert(T::Native::default()); + self.sum = Some(v.add_wrapping(x)); } Ok(()) } @@ -310,20 +361,19 @@ impl Accumulator for DecimalAvgAccumulator { let v = self .sum .map(|v| { - Decimal128Averager::try_new( + DecimalAverager::::try_new( self.sum_scale, self.target_precision, self.target_scale, )? - .avg(v, self.count as _) + .avg(v, T::Native::from_usize(self.count as usize).unwrap()) }) .transpose()?; - Ok(ScalarValue::Decimal128( + ScalarValue::new_primitive::( v, - self.target_precision, - self.target_scale, - )) + &T::TYPE_CONSTRUCTOR(self.target_precision, self.target_scale), + ) } fn supports_retract_batch(&self) -> bool { true diff --git a/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs b/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs index 93b911c939d6..6c97d620616a 100644 --- a/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs +++ b/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs @@ -18,6 +18,7 @@ //! Defines BitAnd, BitOr, and BitXor Aggregate accumulators use ahash::RandomState; +use datafusion_common::cast::as_list_array; use std::any::Any; use std::sync::Arc; @@ -194,7 +195,7 @@ where } fn evaluate(&self) -> Result { - Ok(ScalarValue::new_primitive::(self.value, &T::DATA_TYPE)) + ScalarValue::new_primitive::(self.value, &T::DATA_TYPE) } fn size(&self) -> usize { @@ -355,7 +356,7 @@ where } fn evaluate(&self) -> Result { - Ok(ScalarValue::new_primitive::(self.value, &T::DATA_TYPE)) + ScalarValue::new_primitive::(self.value, &T::DATA_TYPE) } fn size(&self) -> usize { @@ -516,7 +517,7 @@ where } fn evaluate(&self) -> Result { - Ok(ScalarValue::new_primitive::(self.value, &T::DATA_TYPE)) + ScalarValue::new_primitive::(self.value, &T::DATA_TYPE) } fn size(&self) -> usize { @@ -641,9 +642,10 @@ where .values .iter() .map(|x| ScalarValue::new_primitive::(Some(*x), &T::DATA_TYPE)) - .collect(); + .collect::>>()?; - vec![ScalarValue::new_list(Some(values), T::DATA_TYPE)] + let arr = ScalarValue::new_list(&values, &T::DATA_TYPE); + vec![ScalarValue::List(arr)] }; Ok(state_out) } @@ -668,12 +670,11 @@ where } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); - } - - for x in states[0].as_list::().iter().flatten() { - self.update_batch(&[x])? + if let Some(state) = states.first() { + let list_arr = as_list_array(state)?; + for arr in list_arr.iter().flatten() { + self.update_batch(&[arr])?; + } } Ok(()) } @@ -684,7 +685,7 @@ where acc = acc ^ *distinct_value; } let v = (!self.values.is_empty()).then_some(acc); - Ok(ScalarValue::new_primitive::(v, &T::DATA_TYPE)) + ScalarValue::new_primitive::(v, &T::DATA_TYPE) } fn size(&self) -> usize { diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 6568457bc234..c40f0db19405 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -114,13 +114,16 @@ pub fn create_aggregate_expr( ), (AggregateFunction::ArrayAgg, false) => { let expr = input_phy_exprs[0].clone(); + let nullable = expr.nullable(input_schema)?; + if ordering_req.is_empty() { - Arc::new(expressions::ArrayAgg::new(expr, name, data_type)) + Arc::new(expressions::ArrayAgg::new(expr, name, data_type, nullable)) } else { Arc::new(expressions::OrderSensitiveArrayAgg::new( expr, name, data_type, + nullable, ordering_types, ordering_req.to_vec(), )) @@ -132,10 +135,13 @@ pub fn create_aggregate_expr( "ARRAY_AGG(DISTINCT ORDER BY a ASC) order-sensitive aggregations are not available" ); } + let expr = input_phy_exprs[0].clone(); + let is_expr_nullable = expr.nullable(input_schema)?; Arc::new(expressions::DistinctArrayAgg::new( - input_phy_exprs[0].clone(), + expr, name, data_type, + is_expr_nullable, )) } (AggregateFunction::Min, _) => Arc::new(expressions::Min::new( @@ -363,6 +369,22 @@ pub fn create_aggregate_expr( ordering_req.to_vec(), ordering_types, )), + (AggregateFunction::StringAgg, false) => { + if !ordering_req.is_empty() { + return not_impl_err!( + "STRING_AGG(ORDER BY a ASC) order-sensitive aggregations are not available" + ); + } + Arc::new(expressions::StringAgg::new( + input_phy_exprs[0].clone(), + input_phy_exprs[1].clone(), + name, + data_type, + )) + } + (AggregateFunction::StringAgg, true) => { + return not_impl_err!("STRING_AGG(DISTINCT) aggregations are not available"); + } }) } @@ -432,8 +454,8 @@ mod tests { assert_eq!( Field::new_list( "c1", - Field::new("item", data_type.clone(), true,), - false, + Field::new("item", data_type.clone(), true), + true, ), result_agg_phy_exprs.field().unwrap() ); @@ -471,8 +493,8 @@ mod tests { assert_eq!( Field::new_list( "c1", - Field::new("item", data_type.clone(), true,), - false, + Field::new("item", data_type.clone(), true), + true, ), result_agg_phy_exprs.field().unwrap() ); diff --git a/datafusion/physical-expr/src/aggregate/correlation.rs b/datafusion/physical-expr/src/aggregate/correlation.rs index 475bfa4ce0da..61f2db5c8ef9 100644 --- a/datafusion/physical-expr/src/aggregate/correlation.rs +++ b/datafusion/physical-expr/src/aggregate/correlation.rs @@ -505,13 +505,17 @@ mod tests { let values1 = expr1 .iter() - .map(|e| e.evaluate(batch1)) - .map(|r| r.map(|v| v.into_array(batch1.num_rows()))) + .map(|e| { + e.evaluate(batch1) + .and_then(|v| v.into_array(batch1.num_rows())) + }) .collect::>>()?; let values2 = expr2 .iter() - .map(|e| e.evaluate(batch2)) - .map(|r| r.map(|v| v.into_array(batch2.num_rows()))) + .map(|e| { + e.evaluate(batch2) + .and_then(|v| v.into_array(batch2.num_rows())) + }) .collect::>>()?; accum1.update_batch(&values1)?; accum2.update_batch(&values2)?; diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs index 738ca4e915f7..8e9ae5cea36b 100644 --- a/datafusion/physical-expr/src/aggregate/count.rs +++ b/datafusion/physical-expr/src/aggregate/count.rs @@ -123,7 +123,7 @@ impl GroupsAccumulator for CountGroupsAccumulator { self.counts.resize(total_num_groups, 0); accumulate_indices( group_indices, - values.nulls(), // ignore values + values.logical_nulls().as_ref(), opt_filter, |group_index| { self.counts[group_index] += 1; @@ -198,16 +198,18 @@ fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize { if values.len() > 1 { let result_bool_buf: Option = values .iter() - .map(|a| a.nulls()) + .map(|a| a.logical_nulls()) .fold(None, |acc, b| match (acc, b) { (Some(acc), Some(b)) => Some(acc.bitand(b.inner())), (Some(acc), None) => Some(acc), - (None, Some(b)) => Some(b.inner().clone()), + (None, Some(b)) => Some(b.into_inner()), _ => None, }); result_bool_buf.map_or(0, |b| values[0].len() - b.count_set_bits()) } else { - values[0].null_count() + values[0] + .logical_nulls() + .map_or(0, |nulls| nulls.null_count()) } } diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index 05be8cbccb5f..f7c13948b2dc 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -15,20 +15,32 @@ // specific language governing permissions and limitations // under the License. -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, TimeUnit}; +use arrow_array::types::{ + ArrowPrimitiveType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, + Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, + Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, + TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, + TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, +}; +use arrow_array::PrimitiveArray; + use std::any::Any; +use std::cmp::Eq; use std::fmt::Debug; +use std::hash::Hash; use std::sync::Arc; use ahash::RandomState; use arrow::array::{Array, ArrayRef}; use std::collections::HashSet; -use crate::aggregate::utils::down_cast_any_ref; +use crate::aggregate::utils::{down_cast_any_ref, Hashable}; use crate::expressions::format_state_name; use crate::{AggregateExpr, PhysicalExpr}; -use datafusion_common::ScalarValue; -use datafusion_common::{internal_err, DataFusionError, Result}; +use datafusion_common::cast::{as_list_array, as_primitive_array}; +use datafusion_common::utils::array_into_list_array; +use datafusion_common::{Result, ScalarValue}; use datafusion_expr::Accumulator; type DistinctScalarValues = ScalarValue; @@ -59,6 +71,18 @@ impl DistinctCount { } } +macro_rules! native_distinct_count_accumulator { + ($TYPE:ident) => {{ + Ok(Box::new(NativeDistinctCountAccumulator::<$TYPE>::new())) + }}; +} + +macro_rules! float_distinct_count_accumulator { + ($TYPE:ident) => {{ + Ok(Box::new(FloatDistinctCountAccumulator::<$TYPE>::new())) + }}; +} + impl AggregateExpr for DistinctCount { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { @@ -82,10 +106,57 @@ impl AggregateExpr for DistinctCount { } fn create_accumulator(&self) -> Result> { - Ok(Box::new(DistinctCountAccumulator { - values: HashSet::default(), - state_data_type: self.state_data_type.clone(), - })) + use DataType::*; + use TimeUnit::*; + + match &self.state_data_type { + Int8 => native_distinct_count_accumulator!(Int8Type), + Int16 => native_distinct_count_accumulator!(Int16Type), + Int32 => native_distinct_count_accumulator!(Int32Type), + Int64 => native_distinct_count_accumulator!(Int64Type), + UInt8 => native_distinct_count_accumulator!(UInt8Type), + UInt16 => native_distinct_count_accumulator!(UInt16Type), + UInt32 => native_distinct_count_accumulator!(UInt32Type), + UInt64 => native_distinct_count_accumulator!(UInt64Type), + Decimal128(_, _) => native_distinct_count_accumulator!(Decimal128Type), + Decimal256(_, _) => native_distinct_count_accumulator!(Decimal256Type), + + Date32 => native_distinct_count_accumulator!(Date32Type), + Date64 => native_distinct_count_accumulator!(Date64Type), + Time32(Millisecond) => { + native_distinct_count_accumulator!(Time32MillisecondType) + } + Time32(Second) => { + native_distinct_count_accumulator!(Time32SecondType) + } + Time64(Microsecond) => { + native_distinct_count_accumulator!(Time64MicrosecondType) + } + Time64(Nanosecond) => { + native_distinct_count_accumulator!(Time64NanosecondType) + } + Timestamp(Microsecond, _) => { + native_distinct_count_accumulator!(TimestampMicrosecondType) + } + Timestamp(Millisecond, _) => { + native_distinct_count_accumulator!(TimestampMillisecondType) + } + Timestamp(Nanosecond, _) => { + native_distinct_count_accumulator!(TimestampNanosecondType) + } + Timestamp(Second, _) => { + native_distinct_count_accumulator!(TimestampSecondType) + } + + Float16 => float_distinct_count_accumulator!(Float16Type), + Float32 => float_distinct_count_accumulator!(Float32Type), + Float64 => float_distinct_count_accumulator!(Float64Type), + + _ => Ok(Box::new(DistinctCountAccumulator { + values: HashSet::default(), + state_data_type: self.state_data_type.clone(), + })), + } } fn name(&self) -> &str { @@ -142,23 +213,21 @@ impl DistinctCountAccumulator { impl Accumulator for DistinctCountAccumulator { fn state(&self) -> Result> { - let mut cols_out = - ScalarValue::new_list(Some(Vec::new()), self.state_data_type.clone()); - self.values - .iter() - .enumerate() - .for_each(|(_, distinct_values)| { - if let ScalarValue::List(Some(ref mut v), _) = cols_out { - v.push(distinct_values.clone()); - } - }); - Ok(vec![cols_out]) + let scalars = self.values.iter().cloned().collect::>(); + let arr = ScalarValue::new_list(scalars.as_slice(), &self.state_data_type); + Ok(vec![ScalarValue::List(arr)]) } + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { if values.is_empty() { return Ok(()); } + let arr = &values[0]; + if arr.data_type() == &DataType::Null { + return Ok(()); + } + (0..arr.len()).try_for_each(|index| { if !arr.is_null(index) { let scalar = ScalarValue::try_from_array(arr, index)?; @@ -167,25 +236,17 @@ impl Accumulator for DistinctCountAccumulator { Ok(()) }) } + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { if states.is_empty() { return Ok(()); } - let arr = &states[0]; - (0..arr.len()).try_for_each(|index| { - let scalar = ScalarValue::try_from_array(arr, index)?; - - if let ScalarValue::List(Some(scalar), _) = scalar { - scalar.iter().for_each(|scalar| { - if !ScalarValue::is_null(scalar) { - self.values.insert(scalar.clone()); - } - }); - } else { - return internal_err!("Unexpected accumulator state"); - } - Ok(()) - }) + assert_eq!(states.len(), 1, "array_agg states must be singleton!"); + let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&states[0])?; + for scalars in scalar_vec.into_iter() { + self.values.extend(scalars) + } + Ok(()) } fn evaluate(&self) -> Result { @@ -201,6 +262,182 @@ impl Accumulator for DistinctCountAccumulator { } } +#[derive(Debug)] +struct NativeDistinctCountAccumulator +where + T: ArrowPrimitiveType + Send, + T::Native: Eq + Hash, +{ + values: HashSet, +} + +impl NativeDistinctCountAccumulator +where + T: ArrowPrimitiveType + Send, + T::Native: Eq + Hash, +{ + fn new() -> Self { + Self { + values: HashSet::default(), + } + } +} + +impl Accumulator for NativeDistinctCountAccumulator +where + T: ArrowPrimitiveType + Send + Debug, + T::Native: Eq + Hash, +{ + fn state(&self) -> Result> { + let arr = Arc::new(PrimitiveArray::::from_iter_values( + self.values.iter().cloned(), + )) as ArrayRef; + let list = Arc::new(array_into_list_array(arr)) as ArrayRef; + Ok(vec![ScalarValue::List(list)]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + + let arr = as_primitive_array::(&values[0])?; + arr.iter().for_each(|value| { + if let Some(value) = value { + self.values.insert(value); + } + }); + + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + } + assert_eq!( + states.len(), + 1, + "count_distinct states must be single array" + ); + + let arr = as_list_array(&states[0])?; + arr.iter().try_for_each(|maybe_list| { + if let Some(list) = maybe_list { + let list = as_primitive_array::(&list)?; + self.values.extend(list.values()) + }; + Ok(()) + }) + } + + fn evaluate(&self) -> Result { + Ok(ScalarValue::Int64(Some(self.values.len() as i64))) + } + + fn size(&self) -> usize { + let estimated_buckets = (self.values.len().checked_mul(8).unwrap_or(usize::MAX) + / 7) + .next_power_of_two(); + + // Size of accumulator + // + size of entry * number of buckets + // + 1 byte for each bucket + // + fixed size of HashSet + std::mem::size_of_val(self) + + std::mem::size_of::() * estimated_buckets + + estimated_buckets + + std::mem::size_of_val(&self.values) + } +} + +#[derive(Debug)] +struct FloatDistinctCountAccumulator +where + T: ArrowPrimitiveType + Send, +{ + values: HashSet, RandomState>, +} + +impl FloatDistinctCountAccumulator +where + T: ArrowPrimitiveType + Send, +{ + fn new() -> Self { + Self { + values: HashSet::default(), + } + } +} + +impl Accumulator for FloatDistinctCountAccumulator +where + T: ArrowPrimitiveType + Send + Debug, +{ + fn state(&self) -> Result> { + let arr = Arc::new(PrimitiveArray::::from_iter_values( + self.values.iter().map(|v| v.0), + )) as ArrayRef; + let list = Arc::new(array_into_list_array(arr)) as ArrayRef; + Ok(vec![ScalarValue::List(list)]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + + let arr = as_primitive_array::(&values[0])?; + arr.iter().for_each(|value| { + if let Some(value) = value { + self.values.insert(Hashable(value)); + } + }); + + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + } + assert_eq!( + states.len(), + 1, + "count_distinct states must be single array" + ); + + let arr = as_list_array(&states[0])?; + arr.iter().try_for_each(|maybe_list| { + if let Some(list) = maybe_list { + let list = as_primitive_array::(&list)?; + self.values + .extend(list.values().iter().map(|v| Hashable(*v))); + }; + Ok(()) + }) + } + + fn evaluate(&self) -> Result { + Ok(ScalarValue::Int64(Some(self.values.len() as i64))) + } + + fn size(&self) -> usize { + let estimated_buckets = (self.values.len().checked_mul(8).unwrap_or(usize::MAX) + / 7) + .next_power_of_two(); + + // Size of accumulator + // + size of entry * number of buckets + // + 1 byte for each bucket + // + fixed size of HashSet + std::mem::size_of_val(self) + + std::mem::size_of::() * estimated_buckets + + estimated_buckets + + std::mem::size_of_val(&self.values) + } +} + #[cfg(test)] mod tests { use crate::expressions::NoOp; @@ -211,33 +448,23 @@ mod tests { Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array, }; use arrow::datatypes::DataType; + use arrow::datatypes::{ + Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, + UInt32Type, UInt64Type, UInt8Type, + }; + use arrow_array::Decimal256Array; + use arrow_buffer::i256; + use datafusion_common::cast::{as_boolean_array, as_list_array, as_primitive_array}; use datafusion_common::internal_err; - - macro_rules! state_to_vec { - ($LIST:expr, $DATA_TYPE:ident, $PRIM_TY:ty) => {{ - match $LIST { - ScalarValue::List(_, field) => match field.data_type() { - &DataType::$DATA_TYPE => (), - _ => panic!("Unexpected DataType for list"), - }, - _ => panic!("Expected a ScalarValue::List"), - } - - match $LIST { - ScalarValue::List(None, _) => None, - ScalarValue::List(Some(scalar_values), _) => { - let vec = scalar_values - .iter() - .map(|scalar_value| match scalar_value { - ScalarValue::$DATA_TYPE(value) => *value, - _ => panic!("Unexpected ScalarValue variant"), - }) - .collect::>>(); - - Some(vec) - } - _ => unreachable!(), - } + use datafusion_common::DataFusionError; + + macro_rules! state_to_vec_primitive { + ($LIST:expr, $DATA_TYPE:ident) => {{ + let arr = ScalarValue::raw_data($LIST).unwrap(); + let list_arr = as_list_array(&arr).unwrap(); + let arr = list_arr.values(); + let arr = as_primitive_array::<$DATA_TYPE>(arr)?; + arr.values().iter().cloned().collect::>() }}; } @@ -259,18 +486,25 @@ mod tests { let (states, result) = run_update_batch(&arrays)?; - let mut state_vec = - state_to_vec!(&states[0], $DATA_TYPE, $PRIM_TYPE).unwrap(); + let mut state_vec = state_to_vec_primitive!(&states[0], $DATA_TYPE); state_vec.sort(); assert_eq!(states.len(), 1); - assert_eq!(state_vec, vec![Some(1), Some(2), Some(3)]); + assert_eq!(state_vec, vec![1, 2, 3]); assert_eq!(result, ScalarValue::Int64(Some(3))); Ok(()) }}; } + fn state_to_vec_bool(sv: &ScalarValue) -> Result> { + let arr = ScalarValue::raw_data(sv)?; + let list_arr = as_list_array(&arr)?; + let arr = list_arr.values(); + let bool_arr = as_boolean_array(arr)?; + Ok(bool_arr.iter().flatten().collect()) + } + fn run_update_batch(arrays: &[ArrayRef]) -> Result<(Vec, ScalarValue)> { let agg = DistinctCount::new( arrays[0].data_type().clone(), @@ -353,13 +587,11 @@ mod tests { let (states, result) = run_update_batch(&arrays)?; - let mut state_vec = - state_to_vec!(&states[0], $DATA_TYPE, $PRIM_TYPE).unwrap(); + let mut state_vec = state_to_vec_primitive!(&states[0], $DATA_TYPE); dbg!(&state_vec); state_vec.sort_by(|a, b| match (a, b) { - (Some(lhs), Some(rhs)) => lhs.total_cmp(rhs), - _ => a.partial_cmp(b).unwrap(), + (lhs, rhs) => lhs.total_cmp(rhs), }); let nan_idx = state_vec.len() - 1; @@ -367,79 +599,114 @@ mod tests { assert_eq!( &state_vec[..nan_idx], vec![ - Some(<$PRIM_TYPE>::NEG_INFINITY), - Some(-4.5), - Some(<$PRIM_TYPE as SubNormal>::SUBNORMAL), - Some(1.0), - Some(2.0), - Some(3.0), - Some(<$PRIM_TYPE>::INFINITY) + <$PRIM_TYPE>::NEG_INFINITY, + -4.5, + <$PRIM_TYPE as SubNormal>::SUBNORMAL, + 1.0, + 2.0, + 3.0, + <$PRIM_TYPE>::INFINITY ] ); - assert!(state_vec[nan_idx].unwrap_or_default().is_nan()); + assert!(state_vec[nan_idx].is_nan()); assert_eq!(result, ScalarValue::Int64(Some(8))); Ok(()) }}; } + macro_rules! test_count_distinct_update_batch_bigint { + ($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{ + let values: Vec> = vec![ + Some(i256::from(1)), + Some(i256::from(1)), + None, + Some(i256::from(3)), + Some(i256::from(2)), + None, + Some(i256::from(2)), + Some(i256::from(3)), + Some(i256::from(1)), + ]; + + let arrays = vec![Arc::new($ARRAY_TYPE::from(values)) as ArrayRef]; + + let (states, result) = run_update_batch(&arrays)?; + + let mut state_vec = state_to_vec_primitive!(&states[0], $DATA_TYPE); + state_vec.sort(); + + assert_eq!(states.len(), 1); + assert_eq!(state_vec, vec![i256::from(1), i256::from(2), i256::from(3)]); + assert_eq!(result, ScalarValue::Int64(Some(3))); + + Ok(()) + }}; + } + #[test] fn count_distinct_update_batch_i8() -> Result<()> { - test_count_distinct_update_batch_numeric!(Int8Array, Int8, i8) + test_count_distinct_update_batch_numeric!(Int8Array, Int8Type, i8) } #[test] fn count_distinct_update_batch_i16() -> Result<()> { - test_count_distinct_update_batch_numeric!(Int16Array, Int16, i16) + test_count_distinct_update_batch_numeric!(Int16Array, Int16Type, i16) } #[test] fn count_distinct_update_batch_i32() -> Result<()> { - test_count_distinct_update_batch_numeric!(Int32Array, Int32, i32) + test_count_distinct_update_batch_numeric!(Int32Array, Int32Type, i32) } #[test] fn count_distinct_update_batch_i64() -> Result<()> { - test_count_distinct_update_batch_numeric!(Int64Array, Int64, i64) + test_count_distinct_update_batch_numeric!(Int64Array, Int64Type, i64) } #[test] fn count_distinct_update_batch_u8() -> Result<()> { - test_count_distinct_update_batch_numeric!(UInt8Array, UInt8, u8) + test_count_distinct_update_batch_numeric!(UInt8Array, UInt8Type, u8) } #[test] fn count_distinct_update_batch_u16() -> Result<()> { - test_count_distinct_update_batch_numeric!(UInt16Array, UInt16, u16) + test_count_distinct_update_batch_numeric!(UInt16Array, UInt16Type, u16) } #[test] fn count_distinct_update_batch_u32() -> Result<()> { - test_count_distinct_update_batch_numeric!(UInt32Array, UInt32, u32) + test_count_distinct_update_batch_numeric!(UInt32Array, UInt32Type, u32) } #[test] fn count_distinct_update_batch_u64() -> Result<()> { - test_count_distinct_update_batch_numeric!(UInt64Array, UInt64, u64) + test_count_distinct_update_batch_numeric!(UInt64Array, UInt64Type, u64) } #[test] fn count_distinct_update_batch_f32() -> Result<()> { - test_count_distinct_update_batch_floating_point!(Float32Array, Float32, f32) + test_count_distinct_update_batch_floating_point!(Float32Array, Float32Type, f32) } #[test] fn count_distinct_update_batch_f64() -> Result<()> { - test_count_distinct_update_batch_floating_point!(Float64Array, Float64, f64) + test_count_distinct_update_batch_floating_point!(Float64Array, Float64Type, f64) + } + + #[test] + fn count_distinct_update_batch_i256() -> Result<()> { + test_count_distinct_update_batch_bigint!(Decimal256Array, Decimal256Type, i256) } #[test] fn count_distinct_update_batch_boolean() -> Result<()> { - let get_count = |data: BooleanArray| -> Result<(Vec>, i64)> { + let get_count = |data: BooleanArray| -> Result<(Vec, i64)> { let arrays = vec![Arc::new(data) as ArrayRef]; let (states, result) = run_update_batch(&arrays)?; - let mut state_vec = state_to_vec!(&states[0], Boolean, bool).unwrap(); + let mut state_vec = state_to_vec_bool(&states[0])?; state_vec.sort(); + let count = match result { ScalarValue::Int64(c) => c.ok_or_else(|| { DataFusionError::Internal("Found None count".to_string()) @@ -467,22 +734,13 @@ mod tests { Some(false), ]); - assert_eq!( - get_count(zero_count_values)?, - (Vec::>::new(), 0) - ); - assert_eq!(get_count(one_count_values)?, (vec![Some(false)], 1)); - assert_eq!( - get_count(one_count_values_with_null)?, - (vec![Some(true)], 1) - ); - assert_eq!( - get_count(two_count_values)?, - (vec![Some(false), Some(true)], 2) - ); + assert_eq!(get_count(zero_count_values)?, (Vec::::new(), 0)); + assert_eq!(get_count(one_count_values)?, (vec![false], 1)); + assert_eq!(get_count(one_count_values_with_null)?, (vec![true], 1)); + assert_eq!(get_count(two_count_values)?, (vec![false, true], 2)); assert_eq!( get_count(two_count_values_with_null)?, - (vec![Some(false), Some(true)], 2) + (vec![false, true], 2) ); Ok(()) } @@ -494,9 +752,9 @@ mod tests { )) as ArrayRef]; let (states, result) = run_update_batch(&arrays)?; - + let state_vec = state_to_vec_primitive!(&states[0], Int32Type); assert_eq!(states.len(), 1); - assert_eq!(state_to_vec!(&states[0], Int32, i32), Some(vec![])); + assert!(state_vec.is_empty()); assert_eq!(result, ScalarValue::Int64(Some(0))); Ok(()) @@ -507,9 +765,9 @@ mod tests { let arrays = vec![Arc::new(Int32Array::from(vec![0_i32; 0])) as ArrayRef]; let (states, result) = run_update_batch(&arrays)?; - + let state_vec = state_to_vec_primitive!(&states[0], Int32Type); assert_eq!(states.len(), 1); - assert_eq!(state_to_vec!(&states[0], Int32, i32), Some(vec![])); + assert!(state_vec.is_empty()); assert_eq!(result, ScalarValue::Int64(Some(0))); Ok(()) diff --git a/datafusion/physical-expr/src/aggregate/covariance.rs b/datafusion/physical-expr/src/aggregate/covariance.rs index 5e589d4e39fd..0f838eb6fa1c 100644 --- a/datafusion/physical-expr/src/aggregate/covariance.rs +++ b/datafusion/physical-expr/src/aggregate/covariance.rs @@ -754,13 +754,17 @@ mod tests { let values1 = expr1 .iter() - .map(|e| e.evaluate(batch1)) - .map(|r| r.map(|v| v.into_array(batch1.num_rows()))) + .map(|e| { + e.evaluate(batch1) + .and_then(|v| v.into_array(batch1.num_rows())) + }) .collect::>>()?; let values2 = expr2 .iter() - .map(|e| e.evaluate(batch2)) - .map(|r| r.map(|v| v.into_array(batch2.num_rows()))) + .map(|e| { + e.evaluate(batch2) + .and_then(|v| v.into_array(batch2.num_rows())) + }) .collect::>>()?; accum1.update_batch(&values1)?; accum2.update_batch(&values2)?; diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs index 6ae7b4895ad6..4afa8d0dd5ec 100644 --- a/datafusion/physical-expr/src/aggregate/first_last.rs +++ b/datafusion/physical-expr/src/aggregate/first_last.rs @@ -20,29 +20,30 @@ use std::any::Any; use std::sync::Arc; -use crate::aggregate::utils::{down_cast_any_ref, ordering_fields}; +use crate::aggregate::utils::{down_cast_any_ref, get_sort_options, ordering_fields}; use crate::expressions::format_state_name; -use crate::{AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr}; +use crate::{ + reverse_order_bys, AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr, +}; -use arrow::array::ArrayRef; -use arrow::compute; -use arrow::compute::{lexsort_to_indices, SortColumn}; +use arrow::array::{Array, ArrayRef, AsArray, BooleanArray}; +use arrow::compute::{self, lexsort_to_indices, SortColumn}; use arrow::datatypes::{DataType, Field}; -use arrow_array::cast::AsArray; -use arrow_array::{Array, BooleanArray}; -use arrow_schema::SortOptions; use datafusion_common::utils::{compare_rows, get_arrayref_at_indices, get_row_at_idx}; -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_common::{ + arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, +}; use datafusion_expr::Accumulator; /// FIRST_VALUE aggregate expression -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct FirstValue { name: String, input_data_type: DataType, order_by_data_types: Vec, expr: Arc, ordering_req: LexOrdering, + requirement_satisfied: bool, } impl FirstValue { @@ -54,14 +55,68 @@ impl FirstValue { ordering_req: LexOrdering, order_by_data_types: Vec, ) -> Self { + let requirement_satisfied = ordering_req.is_empty(); Self { name: name.into(), input_data_type, order_by_data_types, expr, ordering_req, + requirement_satisfied, } } + + /// Returns the name of the aggregate expression. + pub fn name(&self) -> &str { + &self.name + } + + /// Returns the input data type of the aggregate expression. + pub fn input_data_type(&self) -> &DataType { + &self.input_data_type + } + + /// Returns the data types of the order-by columns. + pub fn order_by_data_types(&self) -> &Vec { + &self.order_by_data_types + } + + /// Returns the expression associated with the aggregate function. + pub fn expr(&self) -> &Arc { + &self.expr + } + + /// Returns the lexical ordering requirements of the aggregate expression. + pub fn ordering_req(&self) -> &LexOrdering { + &self.ordering_req + } + + pub fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { + self.requirement_satisfied = requirement_satisfied; + self + } + + pub fn convert_to_last(self) -> LastValue { + let name = if self.name.starts_with("FIRST") { + format!("LAST{}", &self.name[5..]) + } else { + format!("LAST_VALUE({})", self.expr) + }; + let FirstValue { + expr, + input_data_type, + ordering_req, + order_by_data_types, + .. + } = self; + LastValue::new( + expr, + name, + input_data_type, + reverse_order_bys(&ordering_req), + order_by_data_types, + ) + } } impl AggregateExpr for FirstValue { @@ -75,11 +130,14 @@ impl AggregateExpr for FirstValue { } fn create_accumulator(&self) -> Result> { - Ok(Box::new(FirstValueAccumulator::try_new( + FirstValueAccumulator::try_new( &self.input_data_type, &self.order_by_data_types, self.ordering_req.clone(), - )?)) + ) + .map(|acc| { + Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ + }) } fn state_fields(&self) -> Result> { @@ -105,11 +163,7 @@ impl AggregateExpr for FirstValue { } fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { - if self.ordering_req.is_empty() { - None - } else { - Some(&self.ordering_req) - } + (!self.ordering_req.is_empty()).then_some(&self.ordering_req) } fn name(&self) -> &str { @@ -117,26 +171,18 @@ impl AggregateExpr for FirstValue { } fn reverse_expr(&self) -> Option> { - let name = if self.name.starts_with("FIRST") { - format!("LAST{}", &self.name[5..]) - } else { - format!("LAST_VALUE({})", self.expr) - }; - Some(Arc::new(LastValue::new( - self.expr.clone(), - name, - self.input_data_type.clone(), - self.ordering_req.clone(), - self.order_by_data_types.clone(), - ))) + Some(Arc::new(self.clone().convert_to_last())) } fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::new(FirstValueAccumulator::try_new( + FirstValueAccumulator::try_new( &self.input_data_type, &self.order_by_data_types, self.ordering_req.clone(), - )?)) + ) + .map(|acc| { + Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ + }) } } @@ -165,6 +211,8 @@ struct FirstValueAccumulator { orderings: Vec, // Stores the applicable ordering requirement. ordering_req: LexOrdering, + // Stores whether incoming data already satisfies the ordering requirement. + requirement_satisfied: bool, } impl FirstValueAccumulator { @@ -178,11 +226,13 @@ impl FirstValueAccumulator { .iter() .map(ScalarValue::try_from) .collect::>>()?; - ScalarValue::try_from(data_type).map(|value| Self { - first: value, + let requirement_satisfied = ordering_req.is_empty(); + ScalarValue::try_from(data_type).map(|first| Self { + first, is_set: false, orderings, ordering_req, + requirement_satisfied, }) } @@ -192,6 +242,31 @@ impl FirstValueAccumulator { self.orderings = row[1..].to_vec(); self.is_set = true; } + + fn get_first_idx(&self, values: &[ArrayRef]) -> Result> { + let [value, ordering_values @ ..] = values else { + return internal_err!("Empty row in FIRST_VALUE"); + }; + if self.requirement_satisfied { + // Get first entry according to the pre-existing ordering (0th index): + return Ok((!value.is_empty()).then_some(0)); + } + let sort_columns = ordering_values + .iter() + .zip(self.ordering_req.iter()) + .map(|(values, req)| SortColumn { + values: values.clone(), + options: Some(req.options), + }) + .collect::>(); + let indices = lexsort_to_indices(&sort_columns, Some(1))?; + Ok((!indices.is_empty()).then_some(indices.value(0) as _)) + } + + fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { + self.requirement_satisfied = requirement_satisfied; + self + } } impl Accumulator for FirstValueAccumulator { @@ -203,11 +278,25 @@ impl Accumulator for FirstValueAccumulator { } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - // If we have seen first value, we shouldn't update it - if !values[0].is_empty() && !self.is_set { - let row = get_row_at_idx(values, 0)?; - // Update with first value in the array. - self.update_with_new_row(&row); + if !self.is_set { + if let Some(first_idx) = self.get_first_idx(values)? { + let row = get_row_at_idx(values, first_idx)?; + self.update_with_new_row(&row); + } + } else if !self.requirement_satisfied { + if let Some(first_idx) = self.get_first_idx(values)? { + let row = get_row_at_idx(values, first_idx)?; + let orderings = &row[1..]; + if compare_rows( + &self.orderings, + orderings, + &get_sort_options(&self.ordering_req), + )? + .is_gt() + { + self.update_with_new_row(&row); + } + } } Ok(()) } @@ -236,7 +325,7 @@ impl Accumulator for FirstValueAccumulator { let sort_options = get_sort_options(&self.ordering_req); // Either there is no existing value, or there is an earlier version in new data. if !self.is_set - || compare_rows(first_ordering, &self.orderings, &sort_options)?.is_lt() + || compare_rows(&self.orderings, first_ordering, &sort_options)?.is_gt() { // Update with first value in the state. Note that we should exclude the // is_set flag from the state. Otherwise, we will end up with a state @@ -260,13 +349,14 @@ impl Accumulator for FirstValueAccumulator { } /// LAST_VALUE aggregate expression -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct LastValue { name: String, input_data_type: DataType, order_by_data_types: Vec, expr: Arc, ordering_req: LexOrdering, + requirement_satisfied: bool, } impl LastValue { @@ -278,14 +368,68 @@ impl LastValue { ordering_req: LexOrdering, order_by_data_types: Vec, ) -> Self { + let requirement_satisfied = ordering_req.is_empty(); Self { name: name.into(), input_data_type, order_by_data_types, expr, ordering_req, + requirement_satisfied, } } + + /// Returns the name of the aggregate expression. + pub fn name(&self) -> &str { + &self.name + } + + /// Returns the input data type of the aggregate expression. + pub fn input_data_type(&self) -> &DataType { + &self.input_data_type + } + + /// Returns the data types of the order-by columns. + pub fn order_by_data_types(&self) -> &Vec { + &self.order_by_data_types + } + + /// Returns the expression associated with the aggregate function. + pub fn expr(&self) -> &Arc { + &self.expr + } + + /// Returns the lexical ordering requirements of the aggregate expression. + pub fn ordering_req(&self) -> &LexOrdering { + &self.ordering_req + } + + pub fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { + self.requirement_satisfied = requirement_satisfied; + self + } + + pub fn convert_to_first(self) -> FirstValue { + let name = if self.name.starts_with("LAST") { + format!("FIRST{}", &self.name[4..]) + } else { + format!("FIRST_VALUE({})", self.expr) + }; + let LastValue { + expr, + input_data_type, + ordering_req, + order_by_data_types, + .. + } = self; + FirstValue::new( + expr, + name, + input_data_type, + reverse_order_bys(&ordering_req), + order_by_data_types, + ) + } } impl AggregateExpr for LastValue { @@ -299,11 +443,14 @@ impl AggregateExpr for LastValue { } fn create_accumulator(&self) -> Result> { - Ok(Box::new(LastValueAccumulator::try_new( + LastValueAccumulator::try_new( &self.input_data_type, &self.order_by_data_types, self.ordering_req.clone(), - )?)) + ) + .map(|acc| { + Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ + }) } fn state_fields(&self) -> Result> { @@ -329,11 +476,7 @@ impl AggregateExpr for LastValue { } fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { - if self.ordering_req.is_empty() { - None - } else { - Some(&self.ordering_req) - } + (!self.ordering_req.is_empty()).then_some(&self.ordering_req) } fn name(&self) -> &str { @@ -341,26 +484,18 @@ impl AggregateExpr for LastValue { } fn reverse_expr(&self) -> Option> { - let name = if self.name.starts_with("LAST") { - format!("FIRST{}", &self.name[4..]) - } else { - format!("FIRST_VALUE({})", self.expr) - }; - Some(Arc::new(FirstValue::new( - self.expr.clone(), - name, - self.input_data_type.clone(), - self.ordering_req.clone(), - self.order_by_data_types.clone(), - ))) + Some(Arc::new(self.clone().convert_to_first())) } fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::new(LastValueAccumulator::try_new( + LastValueAccumulator::try_new( &self.input_data_type, &self.order_by_data_types, self.ordering_req.clone(), - )?)) + ) + .map(|acc| { + Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ + }) } } @@ -388,6 +523,8 @@ struct LastValueAccumulator { orderings: Vec, // Stores the applicable ordering requirement. ordering_req: LexOrdering, + // Stores whether incoming data already satisfies the ordering requirement. + requirement_satisfied: bool, } impl LastValueAccumulator { @@ -401,11 +538,13 @@ impl LastValueAccumulator { .iter() .map(ScalarValue::try_from) .collect::>>()?; - Ok(Self { - last: ScalarValue::try_from(data_type)?, + let requirement_satisfied = ordering_req.is_empty(); + ScalarValue::try_from(data_type).map(|last| Self { + last, is_set: false, orderings, ordering_req, + requirement_satisfied, }) } @@ -415,6 +554,35 @@ impl LastValueAccumulator { self.orderings = row[1..].to_vec(); self.is_set = true; } + + fn get_last_idx(&self, values: &[ArrayRef]) -> Result> { + let [value, ordering_values @ ..] = values else { + return internal_err!("Empty row in LAST_VALUE"); + }; + if self.requirement_satisfied { + // Get last entry according to the order of data: + return Ok((!value.is_empty()).then_some(value.len() - 1)); + } + let sort_columns = ordering_values + .iter() + .zip(self.ordering_req.iter()) + .map(|(values, req)| { + // Take the reverse ordering requirement. This enables us to + // use "fetch = 1" to get the last value. + SortColumn { + values: values.clone(), + options: Some(!req.options), + } + }) + .collect::>(); + let indices = lexsort_to_indices(&sort_columns, Some(1))?; + Ok((!indices.is_empty()).then_some(indices.value(0) as _)) + } + + fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { + self.requirement_satisfied = requirement_satisfied; + self + } } impl Accumulator for LastValueAccumulator { @@ -426,11 +594,26 @@ impl Accumulator for LastValueAccumulator { } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if !values[0].is_empty() { - let row = get_row_at_idx(values, values[0].len() - 1)?; - // Update with last value in the array. - self.update_with_new_row(&row); + if !self.is_set || self.requirement_satisfied { + if let Some(last_idx) = self.get_last_idx(values)? { + let row = get_row_at_idx(values, last_idx)?; + self.update_with_new_row(&row); + } + } else if let Some(last_idx) = self.get_last_idx(values)? { + let row = get_row_at_idx(values, last_idx)?; + let orderings = &row[1..]; + // Update when there is a more recent entry + if compare_rows( + &self.orderings, + orderings, + &get_sort_options(&self.ordering_req), + )? + .is_lt() + { + self.update_with_new_row(&row); + } } + Ok(()) } @@ -461,7 +644,7 @@ impl Accumulator for LastValueAccumulator { // Either there is no existing value, or there is a newer (latest) // version in the new data: if !self.is_set - || compare_rows(last_ordering, &self.orderings, &sort_options)?.is_gt() + || compare_rows(&self.orderings, last_ordering, &sort_options)?.is_lt() { // Update with last value in the state. Note that we should exclude the // is_set flag from the state. Otherwise, we will end up with a state @@ -492,7 +675,7 @@ fn filter_states_according_to_is_set( ) -> Result> { states .iter() - .map(|state| compute::filter(state, flags).map_err(DataFusionError::ArrowError)) + .map(|state| compute::filter(state, flags).map_err(|e| arrow_datafusion_err!(e))) .collect::>>() } @@ -510,26 +693,18 @@ fn convert_to_sort_cols( .collect::>() } -/// Selects the sort option attribute from all the given `PhysicalSortExpr`s. -fn get_sort_options(ordering_req: &[PhysicalSortExpr]) -> Vec { - ordering_req - .iter() - .map(|item| item.options) - .collect::>() -} - #[cfg(test)] mod tests { + use std::sync::Arc; + use crate::aggregate::first_last::{FirstValueAccumulator, LastValueAccumulator}; + use arrow::compute::concat; use arrow_array::{ArrayRef, Int64Array}; use arrow_schema::DataType; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::Accumulator; - use arrow::compute::concat; - use std::sync::Arc; - #[test] fn test_first_last_value_value() -> Result<()> { let mut first_accumulator = @@ -588,7 +763,10 @@ mod tests { let mut states = vec![]; for idx in 0..state1.len() { - states.push(concat(&[&state1[idx].to_array(), &state2[idx].to_array()])?); + states.push(concat(&[ + &state1[idx].to_array()?, + &state2[idx].to_array()?, + ])?); } let mut first_accumulator = @@ -615,7 +793,10 @@ mod tests { let mut states = vec![]; for idx in 0..state1.len() { - states.push(concat(&[&state1[idx].to_array(), &state2[idx].to_array()])?); + states.push(concat(&[ + &state1[idx].to_array()?, + &state2[idx].to_array()?, + ])?); } let mut last_accumulator = diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs index dcc8c37e7484..c6fd17a69b39 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs @@ -25,7 +25,8 @@ use arrow::{ }; use arrow_array::{ArrayRef, BooleanArray, PrimitiveArray}; use datafusion_common::{ - utils::get_arrayref_at_indices, DataFusionError, Result, ScalarValue, + arrow_datafusion_err, utils::get_arrayref_at_indices, DataFusionError, Result, + ScalarValue, }; use datafusion_expr::Accumulator; @@ -309,7 +310,7 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter { // double check each array has the same length (aka the // accumulator was implemented correctly - if let Some(first_col) = arrays.get(0) { + if let Some(first_col) = arrays.first() { for arr in &arrays { assert_eq!(arr.len(), first_col.len()) } @@ -372,7 +373,7 @@ fn get_filter_at_indices( ) }) .transpose() - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) } // Copied from physical-plan @@ -394,7 +395,7 @@ pub(crate) fn slice_and_maybe_filter( sliced_arrays .iter() .map(|array| { - compute::filter(array, filter_array).map_err(DataFusionError::ArrowError) + compute::filter(array, filter_array).map_err(|e| arrow_datafusion_err!(e)) }) .collect() } else { diff --git a/datafusion/physical-expr/src/aggregate/median.rs b/datafusion/physical-expr/src/aggregate/median.rs index 1ec412402638..691b1c1752f4 100644 --- a/datafusion/physical-expr/src/aggregate/median.rs +++ b/datafusion/physical-expr/src/aggregate/median.rs @@ -150,10 +150,10 @@ impl Accumulator for MedianAccumulator { .all_values .iter() .map(|x| ScalarValue::new_primitive::(Some(*x), &self.data_type)) - .collect(); - let state = ScalarValue::new_list(Some(all_values), self.data_type.clone()); + .collect::>>()?; - Ok(vec![state]) + let arr = ScalarValue::new_list(&all_values, &self.data_type); + Ok(vec![ScalarValue::List(arr)]) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { @@ -188,7 +188,7 @@ impl Accumulator for MedianAccumulator { let (_, median, _) = d.select_nth_unstable_by(len / 2, cmp); Some(*median) }; - Ok(ScalarValue::new_primitive::(median, &self.data_type)) + ScalarValue::new_primitive::(median, &self.data_type) } fn size(&self) -> usize { diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/physical-expr/src/aggregate/min_max.rs index 5c4c48b15803..7e3ef2a2abab 100644 --- a/datafusion/physical-expr/src/aggregate/min_max.rs +++ b/datafusion/physical-expr/src/aggregate/min_max.rs @@ -53,6 +53,9 @@ use crate::aggregate::utils::down_cast_any_ref; use crate::expressions::format_state_name; use arrow::array::Array; use arrow::array::Decimal128Array; +use arrow::array::Decimal256Array; +use arrow::datatypes::i256; +use arrow::datatypes::Decimal256Type; use super::moving_min_max; @@ -183,6 +186,7 @@ impl AggregateExpr for Max { | Float32 | Float64 | Decimal128(_, _) + | Decimal256(_, _) | Date32 | Date64 | Time32(_) @@ -239,6 +243,9 @@ impl AggregateExpr for Max { Decimal128(_, _) => { instantiate_max_accumulator!(self, i128, Decimal128Type) } + Decimal256(_, _) => { + instantiate_max_accumulator!(self, i256, Decimal256Type) + } // It would be nice to have a fast implementation for Strings as well // https://github.com/apache/arrow-datafusion/issues/6906 @@ -318,6 +325,16 @@ macro_rules! min_max_batch { scale ) } + DataType::Decimal256(precision, scale) => { + typed_min_max_batch!( + $VALUES, + Decimal256Array, + Decimal256, + $OP, + precision, + scale + ) + } // all types that have a natural order DataType::Float64 => { typed_min_max_batch!($VALUES, Float64Array, Float64, $OP) @@ -522,6 +539,19 @@ macro_rules! min_max { ); } } + ( + lhs @ ScalarValue::Decimal256(lhsv, lhsp, lhss), + rhs @ ScalarValue::Decimal256(rhsv, rhsp, rhss) + ) => { + if lhsp.eq(rhsp) && lhss.eq(rhss) { + typed_min_max!(lhsv, rhsv, Decimal256, $OP, lhsp, lhss) + } else { + return internal_err!( + "MIN/MAX is not expected to receive scalars of incompatible types {:?}", + (lhs, rhs) + ); + } + } (ScalarValue::Boolean(lhs), ScalarValue::Boolean(rhs)) => { typed_min_max!(lhs, rhs, Boolean, $OP) } @@ -880,6 +910,7 @@ impl AggregateExpr for Min { | Float32 | Float64 | Decimal128(_, _) + | Decimal256(_, _) | Date32 | Date64 | Time32(_) @@ -935,6 +966,9 @@ impl AggregateExpr for Min { Decimal128(_, _) => { instantiate_min_accumulator!(self, i128, Decimal128Type) } + Decimal256(_, _) => { + instantiate_min_accumulator!(self, i256, Decimal256Type) + } // This is only reached if groups_accumulator_supported is out of sync _ => internal_err!( "GroupsAccumulator not supported for min({})", @@ -1263,12 +1297,7 @@ mod tests { #[test] fn max_utf8() -> Result<()> { let a: ArrayRef = Arc::new(StringArray::from(vec!["d", "a", "c", "b"])); - generic_test_op!( - a, - DataType::Utf8, - Max, - ScalarValue::Utf8(Some("d".to_string())) - ) + generic_test_op!(a, DataType::Utf8, Max, ScalarValue::from("d")) } #[test] @@ -1285,12 +1314,7 @@ mod tests { #[test] fn min_utf8() -> Result<()> { let a: ArrayRef = Arc::new(StringArray::from(vec!["d", "a", "c", "b"])); - generic_test_op!( - a, - DataType::Utf8, - Min, - ScalarValue::Utf8(Some("a".to_string())) - ) + generic_test_op!(a, DataType::Utf8, Min, ScalarValue::from("a")) } #[test] diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 4920f7a3e07f..5bd1fca385b1 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -15,16 +15,20 @@ // specific language governing permissions and limitations // under the License. -use crate::expressions::{FirstValue, LastValue, OrderSensitiveArrayAgg}; -use crate::{PhysicalExpr, PhysicalSortExpr}; -use arrow::datatypes::Field; -use datafusion_common::{not_impl_err, DataFusionError, Result}; -use datafusion_expr::Accumulator; use std::any::Any; use std::fmt::Debug; use std::sync::Arc; use self::groups_accumulator::GroupsAccumulator; +use crate::expressions::OrderSensitiveArrayAgg; +use crate::{PhysicalExpr, PhysicalSortExpr}; + +use arrow::datatypes::Field; +use datafusion_common::{not_impl_err, DataFusionError, Result}; +use datafusion_expr::Accumulator; + +mod hyperloglog; +mod tdigest; pub(crate) mod approx_distinct; pub(crate) mod approx_median; @@ -43,21 +47,21 @@ pub(crate) mod covariance; pub(crate) mod first_last; pub(crate) mod grouping; pub(crate) mod median; +pub(crate) mod string_agg; #[macro_use] pub(crate) mod min_max; -pub mod build_in; pub(crate) mod groups_accumulator; -mod hyperloglog; -pub mod moving_min_max; pub(crate) mod regr; pub(crate) mod stats; pub(crate) mod stddev; pub(crate) mod sum; pub(crate) mod sum_distinct; -mod tdigest; -pub mod utils; pub(crate) mod variance; +pub mod build_in; +pub mod moving_min_max; +pub mod utils; + /// An aggregate expression that: /// * knows its resulting field /// * knows how to create its accumulator @@ -68,7 +72,7 @@ pub(crate) mod variance; /// `PartialEq` to allows comparing equality between the /// trait objects. pub trait AggregateExpr: Send + Sync + Debug + PartialEq { - /// Returns the aggregate expression as [`Any`](std::any::Any) so that it can be + /// Returns the aggregate expression as [`Any`] so that it can be /// downcast to a specific implementation. fn as_any(&self) -> &dyn Any; @@ -133,10 +137,7 @@ pub trait AggregateExpr: Send + Sync + Debug + PartialEq { /// Checks whether the given aggregate expression is order-sensitive. /// For instance, a `SUM` aggregation doesn't depend on the order of its inputs. -/// However, a `FirstValue` depends on the input ordering (if the order changes, -/// the first value in the list would change). +/// However, an `ARRAY_AGG` with `ORDER BY` depends on the input ordering. pub fn is_order_sensitive(aggr_expr: &Arc) -> bool { - aggr_expr.as_any().is::() - || aggr_expr.as_any().is::() - || aggr_expr.as_any().is::() + aggr_expr.as_any().is::() } diff --git a/datafusion/physical-expr/src/aggregate/stddev.rs b/datafusion/physical-expr/src/aggregate/stddev.rs index 330507d6ffa6..64e19ef502c7 100644 --- a/datafusion/physical-expr/src/aggregate/stddev.rs +++ b/datafusion/physical-expr/src/aggregate/stddev.rs @@ -445,13 +445,17 @@ mod tests { let values1 = expr1 .iter() - .map(|e| e.evaluate(batch1)) - .map(|r| r.map(|v| v.into_array(batch1.num_rows()))) + .map(|e| { + e.evaluate(batch1) + .and_then(|v| v.into_array(batch1.num_rows())) + }) .collect::>>()?; let values2 = expr2 .iter() - .map(|e| e.evaluate(batch2)) - .map(|r| r.map(|v| v.into_array(batch2.num_rows()))) + .map(|e| { + e.evaluate(batch2) + .and_then(|v| v.into_array(batch2.num_rows())) + }) .collect::>>()?; accum1.update_batch(&values1)?; accum2.update_batch(&values2)?; diff --git a/datafusion/physical-expr/src/aggregate/string_agg.rs b/datafusion/physical-expr/src/aggregate/string_agg.rs new file mode 100644 index 000000000000..7adc736932ad --- /dev/null +++ b/datafusion/physical-expr/src/aggregate/string_agg.rs @@ -0,0 +1,246 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`StringAgg`] and [`StringAggAccumulator`] accumulator for the `string_agg` function + +use crate::aggregate::utils::down_cast_any_ref; +use crate::expressions::{format_state_name, Literal}; +use crate::{AggregateExpr, PhysicalExpr}; +use arrow::array::ArrayRef; +use arrow::datatypes::{DataType, Field}; +use datafusion_common::cast::as_generic_string_array; +use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::Accumulator; +use std::any::Any; +use std::sync::Arc; + +/// STRING_AGG aggregate expression +#[derive(Debug)] +pub struct StringAgg { + name: String, + data_type: DataType, + expr: Arc, + delimiter: Arc, + nullable: bool, +} + +impl StringAgg { + /// Create a new StringAgg aggregate function + pub fn new( + expr: Arc, + delimiter: Arc, + name: impl Into, + data_type: DataType, + ) -> Self { + Self { + name: name.into(), + data_type, + delimiter, + expr, + nullable: true, + } + } +} + +impl AggregateExpr for StringAgg { + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new( + &self.name, + self.data_type.clone(), + self.nullable, + )) + } + + fn create_accumulator(&self) -> Result> { + if let Some(delimiter) = self.delimiter.as_any().downcast_ref::() { + match delimiter.value() { + ScalarValue::Utf8(Some(delimiter)) + | ScalarValue::LargeUtf8(Some(delimiter)) => { + return Ok(Box::new(StringAggAccumulator::new(delimiter))); + } + ScalarValue::Null => { + return Ok(Box::new(StringAggAccumulator::new(""))); + } + _ => return not_impl_err!("StringAgg not supported for {}", self.name), + } + } + not_impl_err!("StringAgg not supported for {}", self.name) + } + + fn state_fields(&self) -> Result> { + Ok(vec![Field::new( + format_state_name(&self.name, "string_agg"), + self.data_type.clone(), + self.nullable, + )]) + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone(), self.delimiter.clone()] + } + + fn name(&self) -> &str { + &self.name + } +} + +impl PartialEq for StringAgg { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| { + self.name == x.name + && self.data_type == x.data_type + && self.expr.eq(&x.expr) + && self.delimiter.eq(&x.delimiter) + }) + .unwrap_or(false) + } +} + +#[derive(Debug)] +pub(crate) struct StringAggAccumulator { + values: Option, + delimiter: String, +} + +impl StringAggAccumulator { + pub fn new(delimiter: &str) -> Self { + Self { + values: None, + delimiter: delimiter.to_string(), + } + } +} + +impl Accumulator for StringAggAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let string_array: Vec<_> = as_generic_string_array::(&values[0])? + .iter() + .filter_map(|v| v.as_ref().map(ToString::to_string)) + .collect(); + if !string_array.is_empty() { + let s = string_array.join(self.delimiter.as_str()); + let v = self.values.get_or_insert("".to_string()); + if !v.is_empty() { + v.push_str(self.delimiter.as_str()); + } + v.push_str(s.as_str()); + } + Ok(()) + } + + fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.update_batch(values)?; + Ok(()) + } + + fn state(&self) -> Result> { + Ok(vec![self.evaluate()?]) + } + + fn evaluate(&self) -> Result { + Ok(ScalarValue::LargeUtf8(self.values.clone())) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + + self.values.as_ref().map(|v| v.capacity()).unwrap_or(0) + + self.delimiter.capacity() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::expressions::tests::aggregate; + use crate::expressions::{col, create_aggregate_expr, try_cast}; + use arrow::array::ArrayRef; + use arrow::datatypes::*; + use arrow::record_batch::RecordBatch; + use arrow_array::LargeStringArray; + use arrow_array::StringArray; + use datafusion_expr::type_coercion::aggregates::coerce_types; + use datafusion_expr::AggregateFunction; + + fn assert_string_aggregate( + array: ArrayRef, + function: AggregateFunction, + distinct: bool, + expected: ScalarValue, + delimiter: String, + ) { + let data_type = array.data_type(); + let sig = function.signature(); + let coerced = + coerce_types(&function, &[data_type.clone(), DataType::Utf8], &sig).unwrap(); + + let input_schema = Schema::new(vec![Field::new("a", data_type.clone(), true)]); + let batch = + RecordBatch::try_new(Arc::new(input_schema.clone()), vec![array]).unwrap(); + + let input = try_cast( + col("a", &input_schema).unwrap(), + &input_schema, + coerced[0].clone(), + ) + .unwrap(); + + let delimiter = Arc::new(Literal::new(ScalarValue::from(delimiter))); + let schema = Schema::new(vec![Field::new("a", coerced[0].clone(), true)]); + let agg = create_aggregate_expr( + &function, + distinct, + &[input, delimiter], + &[], + &schema, + "agg", + ) + .unwrap(); + + let result = aggregate(&batch, agg).unwrap(); + assert_eq!(expected, result); + } + + #[test] + fn string_agg_utf8() { + let a: ArrayRef = Arc::new(StringArray::from(vec!["h", "e", "l", "l", "o"])); + assert_string_aggregate( + a, + AggregateFunction::StringAgg, + false, + ScalarValue::LargeUtf8(Some("h,e,l,l,o".to_owned())), + ",".to_owned(), + ); + } + + #[test] + fn string_agg_largeutf8() { + let a: ArrayRef = Arc::new(LargeStringArray::from(vec!["h", "e", "l", "l", "o"])); + assert_string_aggregate( + a, + AggregateFunction::StringAgg, + false, + ScalarValue::LargeUtf8(Some("h|e|l|l|o".to_owned())), + "|".to_owned(), + ); + } +} diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs index 5cc8e933324e..03f666cc4e5d 100644 --- a/datafusion/physical-expr/src/aggregate/sum.rs +++ b/datafusion/physical-expr/src/aggregate/sum.rs @@ -41,7 +41,10 @@ use datafusion_expr::Accumulator; #[derive(Debug, Clone)] pub struct Sum { name: String, + // The DataType for the input expression data_type: DataType, + // The DataType for the final sum + return_type: DataType, expr: Arc, nullable: bool, } @@ -53,11 +56,12 @@ impl Sum { name: impl Into, data_type: DataType, ) -> Self { - let data_type = sum_return_type(&data_type).unwrap(); + let return_type = sum_return_type(&data_type).unwrap(); Self { name: name.into(), - expr, data_type, + return_type, + expr, nullable: true, } } @@ -70,13 +74,13 @@ impl Sum { /// `s` is a `Sum`, `helper` is a macro accepting (ArrowPrimitiveType, DataType) macro_rules! downcast_sum { ($s:ident, $helper:ident) => { - match $s.data_type { - DataType::UInt64 => $helper!(UInt64Type, $s.data_type), - DataType::Int64 => $helper!(Int64Type, $s.data_type), - DataType::Float64 => $helper!(Float64Type, $s.data_type), - DataType::Decimal128(_, _) => $helper!(Decimal128Type, $s.data_type), - DataType::Decimal256(_, _) => $helper!(Decimal256Type, $s.data_type), - _ => not_impl_err!("Sum not supported for {}: {}", $s.name, $s.data_type), + match $s.return_type { + DataType::UInt64 => $helper!(UInt64Type, $s.return_type), + DataType::Int64 => $helper!(Int64Type, $s.return_type), + DataType::Float64 => $helper!(Float64Type, $s.return_type), + DataType::Decimal128(_, _) => $helper!(Decimal128Type, $s.return_type), + DataType::Decimal256(_, _) => $helper!(Decimal256Type, $s.return_type), + _ => not_impl_err!("Sum not supported for {}: {}", $s.name, $s.return_type), } }; } @@ -91,7 +95,7 @@ impl AggregateExpr for Sum { fn field(&self) -> Result { Ok(Field::new( &self.name, - self.data_type.clone(), + self.return_type.clone(), self.nullable, )) } @@ -108,7 +112,7 @@ impl AggregateExpr for Sum { fn state_fields(&self) -> Result> { Ok(vec![Field::new( format_state_name(&self.name, "sum"), - self.data_type.clone(), + self.return_type.clone(), self.nullable, )]) } @@ -205,7 +209,7 @@ impl Accumulator for SumAccumulator { } fn evaluate(&self) -> Result { - Ok(ScalarValue::new_primitive::(self.sum, &self.data_type)) + ScalarValue::new_primitive::(self.sum, &self.data_type) } fn size(&self) -> usize { @@ -265,7 +269,7 @@ impl Accumulator for SlidingSumAccumulator { fn evaluate(&self) -> Result { let v = (self.count != 0).then_some(self.sum); - Ok(ScalarValue::new_primitive::(v, &self.data_type)) + ScalarValue::new_primitive::(v, &self.data_type) } fn size(&self) -> usize { diff --git a/datafusion/physical-expr/src/aggregate/sum_distinct.rs b/datafusion/physical-expr/src/aggregate/sum_distinct.rs index c3d8d5e87068..6dbb39224629 100644 --- a/datafusion/physical-expr/src/aggregate/sum_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/sum_distinct.rs @@ -25,11 +25,11 @@ use arrow::array::{Array, ArrayRef}; use arrow_array::cast::AsArray; use arrow_array::types::*; use arrow_array::{ArrowNativeTypeOp, ArrowPrimitiveType}; -use arrow_buffer::{ArrowNativeType, ToByteSlice}; +use arrow_buffer::ArrowNativeType; use std::collections::HashSet; use crate::aggregate::sum::downcast_sum; -use crate::aggregate::utils::down_cast_any_ref; +use crate::aggregate::utils::{down_cast_any_ref, Hashable}; use crate::{AggregateExpr, PhysicalExpr}; use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::type_coercion::aggregates::sum_return_type; @@ -40,8 +40,10 @@ use datafusion_expr::Accumulator; pub struct DistinctSum { /// Column name name: String, - /// The DataType for the final sum + // The DataType for the input expression data_type: DataType, + // The DataType for the final sum + return_type: DataType, /// The input arguments, only contains 1 item for sum exprs: Vec>, } @@ -53,10 +55,11 @@ impl DistinctSum { name: String, data_type: DataType, ) -> Self { - let data_type = sum_return_type(&data_type).unwrap(); + let return_type = sum_return_type(&data_type).unwrap(); Self { name, data_type, + return_type, exprs, } } @@ -68,14 +71,14 @@ impl AggregateExpr for DistinctSum { } fn field(&self) -> Result { - Ok(Field::new(&self.name, self.data_type.clone(), true)) + Ok(Field::new(&self.name, self.return_type.clone(), true)) } fn state_fields(&self) -> Result> { // State field is a List which stores items to rebuild hash set. Ok(vec![Field::new_list( format_state_name(&self.name, "sum distinct"), - Field::new("item", self.data_type.clone(), true), + Field::new("item", self.return_type.clone(), true), false, )]) } @@ -116,24 +119,6 @@ impl PartialEq for DistinctSum { } } -/// A wrapper around a type to provide hash for floats -#[derive(Copy, Clone)] -struct Hashable(T); - -impl std::hash::Hash for Hashable { - fn hash(&self, state: &mut H) { - self.0.to_byte_slice().hash(state) - } -} - -impl PartialEq for Hashable { - fn eq(&self, other: &Self) -> bool { - self.0.is_eq(other.0) - } -} - -impl Eq for Hashable {} - struct DistinctSumAccumulator { values: HashSet, RandomState>, data_type: DataType, @@ -159,17 +144,18 @@ impl Accumulator for DistinctSumAccumulator { // 1. Stores aggregate state in `ScalarValue::List` // 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set let state_out = { - let mut distinct_values = Vec::new(); - self.values.iter().for_each(|distinct_value| { - distinct_values.push(ScalarValue::new_primitive::( - Some(distinct_value.0), - &self.data_type, - )) - }); - vec![ScalarValue::new_list( - Some(distinct_values), - self.data_type.clone(), - )] + let distinct_values = self + .values + .iter() + .map(|value| { + ScalarValue::new_primitive::(Some(value.0), &self.data_type) + }) + .collect::>>()?; + + vec![ScalarValue::List(ScalarValue::new_list( + &distinct_values, + &self.data_type, + ))] }; Ok(state_out) } @@ -206,7 +192,7 @@ impl Accumulator for DistinctSumAccumulator { acc = acc.add_wrapping(distinct_value.0) } let v = (!self.values.is_empty()).then_some(acc); - Ok(ScalarValue::new_primitive::(v, &self.data_type)) + ScalarValue::new_primitive::(v, &self.data_type) } fn size(&self) -> usize { diff --git a/datafusion/physical-expr/src/aggregate/tdigest.rs b/datafusion/physical-expr/src/aggregate/tdigest.rs index 7e6d2dcf8f4f..90f5244f477d 100644 --- a/datafusion/physical-expr/src/aggregate/tdigest.rs +++ b/datafusion/physical-expr/src/aggregate/tdigest.rs @@ -28,6 +28,9 @@ //! [Facebook's Folly TDigest]: https://github.com/facebook/folly/blob/main/folly/stats/TDigest.h use arrow::datatypes::DataType; +use arrow_array::cast::as_list_array; +use arrow_array::types::Float64Type; +use datafusion_common::cast::as_primitive_array; use datafusion_common::Result; use datafusion_common::ScalarValue; use std::cmp::Ordering; @@ -566,20 +569,22 @@ impl TDigest { /// [`TDigest`]. pub(crate) fn to_scalar_state(&self) -> Vec { // Gather up all the centroids - let centroids: Vec<_> = self + let centroids: Vec = self .centroids .iter() .flat_map(|c| [c.mean(), c.weight()]) .map(|v| ScalarValue::Float64(Some(v))) .collect(); + let arr = ScalarValue::new_list(¢roids, &DataType::Float64); + vec![ ScalarValue::UInt64(Some(self.max_size as u64)), ScalarValue::Float64(Some(self.sum)), ScalarValue::Float64(Some(self.count)), ScalarValue::Float64(Some(self.max)), ScalarValue::Float64(Some(self.min)), - ScalarValue::new_list(Some(centroids), DataType::Float64), + ScalarValue::List(arr), ] } @@ -600,10 +605,18 @@ impl TDigest { }; let centroids: Vec<_> = match &state[5] { - ScalarValue::List(Some(c), f) if *f.data_type() == DataType::Float64 => c - .chunks(2) - .map(|v| Centroid::new(cast_scalar_f64!(v[0]), cast_scalar_f64!(v[1]))) - .collect(), + ScalarValue::List(arr) => { + let list_array = as_list_array(arr); + let arr = list_array.values(); + + let f64arr = + as_primitive_array::(arr).expect("expected f64 array"); + f64arr + .values() + .chunks(2) + .map(|v| Centroid::new(v[0], v[1])) + .collect() + } v => panic!("invalid centroids type {v:?}"), }; diff --git a/datafusion/physical-expr/src/aggregate/utils.rs b/datafusion/physical-expr/src/aggregate/utils.rs index e1af67071260..d73c46a0f687 100644 --- a/datafusion/physical-expr/src/aggregate/utils.rs +++ b/datafusion/physical-expr/src/aggregate/utils.rs @@ -17,52 +17,52 @@ //! Utilities used in aggregates +use std::any::Any; +use std::sync::Arc; + use crate::{AggregateExpr, PhysicalSortExpr}; -use arrow::array::ArrayRef; -use arrow::datatypes::{MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION}; + +use arrow::array::{ArrayRef, ArrowNativeTypeOp}; use arrow_array::cast::AsArray; use arrow_array::types::{ - Decimal128Type, TimestampMicrosecondType, TimestampMillisecondType, + Decimal128Type, DecimalType, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }; -use arrow_schema::{DataType, Field}; +use arrow_buffer::{ArrowNativeType, ToByteSlice}; +use arrow_schema::{DataType, Field, SortOptions}; use datafusion_common::{exec_err, DataFusionError, Result}; use datafusion_expr::Accumulator; -use std::any::Any; -use std::sync::Arc; /// Convert scalar values from an accumulator into arrays. pub fn get_accum_scalar_values_as_arrays( accum: &dyn Accumulator, ) -> Result> { - Ok(accum + accum .state()? .iter() .map(|s| s.to_array_of_size(1)) - .collect::>()) + .collect() } -/// Computes averages for `Decimal128` values, checking for overflow +/// Computes averages for `Decimal128`/`Decimal256` values, checking for overflow /// -/// This is needed because different precisions for Decimal128 can +/// This is needed because different precisions for Decimal128/Decimal256 can /// store different ranges of values and thus sum/count may not fit in /// the target type. /// /// For example, the precision is 3, the max of value is `999` and the min /// value is `-999` -pub(crate) struct Decimal128Averager { +pub(crate) struct DecimalAverager { /// scale factor for sum values (10^sum_scale) - sum_mul: i128, + sum_mul: T::Native, /// scale factor for target (10^target_scale) - target_mul: i128, - /// The minimum output value possible to represent with the target precision - target_min: i128, - /// The maximum output value possible to represent with the target precision - target_max: i128, + target_mul: T::Native, + /// the output precision + target_precision: u8, } -impl Decimal128Averager { - /// Create a new `Decimal128Averager`: +impl DecimalAverager { + /// Create a new `DecimalAverager`: /// /// * sum_scale: the scale of `sum` values passed to [`Self::avg`] /// * target_precision: the output precision @@ -74,17 +74,23 @@ impl Decimal128Averager { target_precision: u8, target_scale: i8, ) -> Result { - let sum_mul = 10_i128.pow(sum_scale as u32); - let target_mul = 10_i128.pow(target_scale as u32); - let target_min = MIN_DECIMAL_FOR_EACH_PRECISION[target_precision as usize - 1]; - let target_max = MAX_DECIMAL_FOR_EACH_PRECISION[target_precision as usize - 1]; + let sum_mul = T::Native::from_usize(10_usize) + .map(|b| b.pow_wrapping(sum_scale as u32)) + .ok_or(DataFusionError::Internal( + "Failed to compute sum_mul in DecimalAverager".to_string(), + ))?; + + let target_mul = T::Native::from_usize(10_usize) + .map(|b| b.pow_wrapping(target_scale as u32)) + .ok_or(DataFusionError::Internal( + "Failed to compute target_mul in DecimalAverager".to_string(), + ))?; if target_mul >= sum_mul { Ok(Self { sum_mul, target_mul, - target_min, - target_max, + target_precision, }) } else { // can't convert the lit decimal to the returned data type @@ -92,17 +98,21 @@ impl Decimal128Averager { } } - /// Returns the `sum`/`count` as a i128 Decimal128 with + /// Returns the `sum`/`count` as a i128/i256 Decimal128/Decimal256 with /// target_scale and target_precision and reporting overflow. /// /// * sum: The total sum value stored as Decimal128 with sum_scale /// (passed to `Self::try_new`) - /// * count: total count, stored as a i128 (*NOT* a Decimal128 value) + /// * count: total count, stored as a i128/i256 (*NOT* a Decimal128/Decimal256 value) #[inline(always)] - pub fn avg(&self, sum: i128, count: i128) -> Result { - if let Some(value) = sum.checked_mul(self.target_mul / self.sum_mul) { - let new_value = value / count; - if new_value >= self.target_min && new_value <= self.target_max { + pub fn avg(&self, sum: T::Native, count: T::Native) -> Result { + if let Ok(value) = sum.mul_checked(self.target_mul.div_wrapping(self.sum_mul)) { + let new_value = value.div_wrapping(count); + + let validate = + T::validate_decimal_precision(new_value, self.target_precision); + + if validate.is_ok() { Ok(new_value) } else { exec_err!("Arithmetic Overflow in AvgAccumulator") @@ -161,21 +171,17 @@ pub fn adjust_output_array( } /// Downcast a `Box` or `Arc` -/// and return the inner trait object as [`Any`](std::any::Any) so +/// and return the inner trait object as [`Any`] so /// that it can be downcast to a specific implementation. /// /// This method is used when implementing the `PartialEq` /// for [`AggregateExpr`] aggregation expressions and allows comparing the equality /// between the trait objects. pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any { - if any.is::>() { - any.downcast_ref::>() - .unwrap() - .as_any() - } else if any.is::>() { - any.downcast_ref::>() - .unwrap() - .as_any() + if let Some(obj) = any.downcast_ref::>() { + obj.as_any() + } else if let Some(obj) = any.downcast_ref::>() { + obj.as_any() } else { any } @@ -200,3 +206,26 @@ pub(crate) fn ordering_fields( }) .collect() } + +/// Selects the sort option attribute from all the given `PhysicalSortExpr`s. +pub fn get_sort_options(ordering_req: &[PhysicalSortExpr]) -> Vec { + ordering_req.iter().map(|item| item.options).collect() +} + +/// A wrapper around a type to provide hash for floats +#[derive(Copy, Clone, Debug)] +pub(crate) struct Hashable(pub T); + +impl std::hash::Hash for Hashable { + fn hash(&self, state: &mut H) { + self.0.to_byte_slice().hash(state) + } +} + +impl PartialEq for Hashable { + fn eq(&self, other: &Self) -> bool { + self.0.is_eq(other.0) + } +} + +impl Eq for Hashable {} diff --git a/datafusion/physical-expr/src/aggregate/variance.rs b/datafusion/physical-expr/src/aggregate/variance.rs index a720dd833a87..d82c5ad5626f 100644 --- a/datafusion/physical-expr/src/aggregate/variance.rs +++ b/datafusion/physical-expr/src/aggregate/variance.rs @@ -519,13 +519,17 @@ mod tests { let values1 = expr1 .iter() - .map(|e| e.evaluate(batch1)) - .map(|r| r.map(|v| v.into_array(batch1.num_rows()))) + .map(|e| { + e.evaluate(batch1) + .and_then(|v| v.into_array(batch1.num_rows())) + }) .collect::>>()?; let values2 = expr2 .iter() - .map(|e| e.evaluate(batch2)) - .map(|r| r.map(|v| v.into_array(batch2.num_rows()))) + .map(|e| { + e.evaluate(batch2) + .and_then(|v| v.into_array(batch2.num_rows())) + }) .collect::>>()?; accum1.update_batch(&values1)?; accum2.update_batch(&values2)?; diff --git a/datafusion/physical-expr/src/analysis.rs b/datafusion/physical-expr/src/analysis.rs index 990c643c6b08..6d36e2233cdd 100644 --- a/datafusion/physical-expr/src/analysis.rs +++ b/datafusion/physical-expr/src/analysis.rs @@ -17,19 +17,20 @@ //! Interval and selectivity in [`AnalysisContext`] +use std::fmt::Debug; +use std::sync::Arc; + use crate::expressions::Column; -use crate::intervals::cp_solver::PropagationResult; -use crate::intervals::{cardinality_ratio, ExprIntervalGraph, Interval, IntervalBound}; +use crate::intervals::cp_solver::{ExprIntervalGraph, PropagationResult}; use crate::utils::collect_columns; use crate::PhysicalExpr; use arrow::datatypes::Schema; +use datafusion_common::stats::Precision; use datafusion_common::{ internal_err, ColumnStatistics, DataFusionError, Result, ScalarValue, }; - -use std::fmt::Debug; -use std::sync::Arc; +use datafusion_expr::interval_arithmetic::{cardinality_ratio, Interval}; /// The shared context used during the analysis of an expression. Includes /// the boundaries for all known columns. @@ -37,7 +38,7 @@ use std::sync::Arc; pub struct AnalysisContext { // A list of known column boundaries, ordered by the index // of the column in the current schema. - pub boundaries: Option>, + pub boundaries: Vec, /// The estimated percentage of rows that this expression would select, if /// it were to be used as a boolean predicate on a filter. The value will be /// between 0.0 (selects nothing) and 1.0 (selects everything). @@ -47,7 +48,7 @@ pub struct AnalysisContext { impl AnalysisContext { pub fn new(boundaries: Vec) -> Self { Self { - boundaries: Some(boundaries), + boundaries, selectivity: None, } } @@ -58,48 +59,79 @@ impl AnalysisContext { } /// Create a new analysis context from column statistics. - pub fn from_statistics( + pub fn try_from_statistics( input_schema: &Schema, statistics: &[ColumnStatistics], - ) -> Self { - let mut column_boundaries = vec![]; - for (idx, stats) in statistics.iter().enumerate() { - column_boundaries.push(ExprBoundaries::from_column( - stats, - input_schema.fields()[idx].name().clone(), - idx, - )); - } - Self::new(column_boundaries) + ) -> Result { + statistics + .iter() + .enumerate() + .map(|(idx, stats)| ExprBoundaries::try_from_column(input_schema, stats, idx)) + .collect::>>() + .map(Self::new) } } -/// Represents the boundaries of the resulting value from a physical expression, -/// if it were to be an expression, if it were to be evaluated. +/// Represents the boundaries (e.g. min and max values) of a particular column +/// +/// This is used range analysis of expressions, to determine if the expression +/// limits the value of particular columns (e.g. analyzing an expression such as +/// `time < 50` would result in a boundary interval for `time` having a max +/// value of `50`). #[derive(Clone, Debug, PartialEq)] pub struct ExprBoundaries { pub column: Column, /// Minimum and maximum values this expression can have. pub interval: Interval, /// Maximum number of distinct values this expression can produce, if known. - pub distinct_count: Option, + pub distinct_count: Precision, } impl ExprBoundaries { /// Create a new `ExprBoundaries` object from column level statistics. - pub fn from_column(stats: &ColumnStatistics, col: String, index: usize) -> Self { - Self { - column: Column::new(&col, index), - interval: Interval::new( - IntervalBound::new_closed( - stats.min_value.clone().unwrap_or(ScalarValue::Null), - ), - IntervalBound::new_closed( - stats.max_value.clone().unwrap_or(ScalarValue::Null), - ), - ), - distinct_count: stats.distinct_count, - } + pub fn try_from_column( + schema: &Schema, + col_stats: &ColumnStatistics, + col_index: usize, + ) -> Result { + let field = &schema.fields()[col_index]; + let empty_field = + ScalarValue::try_from(field.data_type()).unwrap_or(ScalarValue::Null); + let interval = Interval::try_new( + col_stats + .min_value + .get_value() + .cloned() + .unwrap_or(empty_field.clone()), + col_stats + .max_value + .get_value() + .cloned() + .unwrap_or(empty_field), + )?; + let column = Column::new(field.name(), col_index); + Ok(ExprBoundaries { + column, + interval, + distinct_count: col_stats.distinct_count.clone(), + }) + } + + /// Create `ExprBoundaries` that represent no known bounds for all the + /// columns in `schema` + pub fn try_new_unbounded(schema: &Schema) -> Result> { + schema + .fields() + .iter() + .enumerate() + .map(|(i, field)| { + Ok(Self { + column: Column::new(field.name(), i), + interval: Interval::make_unbounded(field.data_type())?, + distinct_count: Precision::Absent, + }) + }) + .collect() } } @@ -121,37 +153,36 @@ impl ExprBoundaries { pub fn analyze( expr: &Arc, context: AnalysisContext, + schema: &Schema, ) -> Result { - let target_boundaries = context.boundaries.ok_or_else(|| { - DataFusionError::Internal("No column exists at the input to filter".to_string()) - })?; + let target_boundaries = context.boundaries; - let mut graph = ExprIntervalGraph::try_new(expr.clone())?; + let mut graph = ExprIntervalGraph::try_new(expr.clone(), schema)?; - let columns: Vec> = collect_columns(expr) + let columns = collect_columns(expr) .into_iter() - .map(|c| Arc::new(c) as Arc) - .collect(); - - let target_expr_and_indices: Vec<(Arc, usize)> = - graph.gather_node_indices(columns.as_slice()); - - let mut target_indices_and_boundaries: Vec<(usize, Interval)> = - target_expr_and_indices - .iter() - .filter_map(|(expr, i)| { - target_boundaries.iter().find_map(|bound| { - expr.as_any() - .downcast_ref::() - .filter(|expr_column| bound.column.eq(*expr_column)) - .map(|_| (*i, bound.interval.clone())) - }) + .map(|c| Arc::new(c) as _) + .collect::>(); + + let target_expr_and_indices = graph.gather_node_indices(columns.as_slice()); + + let mut target_indices_and_boundaries = target_expr_and_indices + .iter() + .filter_map(|(expr, i)| { + target_boundaries.iter().find_map(|bound| { + expr.as_any() + .downcast_ref::() + .filter(|expr_column| bound.column.eq(*expr_column)) + .map(|_| (*i, bound.interval.clone())) }) - .collect(); + }) + .collect::>(); - match graph.update_ranges(&mut target_indices_and_boundaries)? { + match graph + .update_ranges(&mut target_indices_and_boundaries, Interval::CERTAINLY_TRUE)? + { PropagationResult::Success => { - shrink_boundaries(expr, graph, target_boundaries, target_expr_and_indices) + shrink_boundaries(graph, target_boundaries, target_expr_and_indices) } PropagationResult::Infeasible => { Ok(AnalysisContext::new(target_boundaries).with_selectivity(0.0)) @@ -167,8 +198,7 @@ pub fn analyze( /// Following this, it constructs and returns a new `AnalysisContext` with the /// updated parameters. fn shrink_boundaries( - expr: &Arc, - mut graph: ExprIntervalGraph, + graph: ExprIntervalGraph, mut target_boundaries: Vec, target_expr_and_indices: Vec<(Arc, usize)>, ) -> Result { @@ -183,21 +213,8 @@ fn shrink_boundaries( }; } }); - let graph_nodes = graph.gather_node_indices(&[expr.clone()]); - let (_, root_index) = graph_nodes.first().ok_or_else(|| { - DataFusionError::Internal("Error in constructing predicate graph".to_string()) - })?; - let final_result = graph.get_interval(*root_index); - - // If during selectivity calculation we encounter an error, use 1.0 as cardinality estimate - // safest estimate(e.q largest possible value). - let selectivity = calculate_selectivity( - &final_result.lower.value, - &final_result.upper.value, - &target_boundaries, - &initial_boundaries, - ) - .unwrap_or(1.0); + + let selectivity = calculate_selectivity(&target_boundaries, &initial_boundaries); if !(0.0..=1.0).contains(&selectivity) { return internal_err!("Selectivity is out of limit: {}", selectivity); @@ -209,33 +226,17 @@ fn shrink_boundaries( /// This function calculates the filter predicate's selectivity by comparing /// the initial and pruned column boundaries. Selectivity is defined as the /// ratio of rows in a table that satisfy the filter's predicate. -/// -/// An exact propagation result at the root, i.e. `[true, true]` or `[false, false]`, -/// leads to early exit (returning a selectivity value of either 1.0 or 0.0). In such -/// a case, `[true, true]` indicates that all data values satisfy the predicate (hence, -/// selectivity is 1.0), and `[false, false]` suggests that no data value meets the -/// predicate (therefore, selectivity is 0.0). fn calculate_selectivity( - lower_value: &ScalarValue, - upper_value: &ScalarValue, target_boundaries: &[ExprBoundaries], initial_boundaries: &[ExprBoundaries], -) -> Result { - match (lower_value, upper_value) { - (ScalarValue::Boolean(Some(true)), ScalarValue::Boolean(Some(true))) => Ok(1.0), - (ScalarValue::Boolean(Some(false)), ScalarValue::Boolean(Some(false))) => Ok(0.0), - _ => { - // Since the intervals are assumed uniform and the values - // are not correlated, we need to multiply the selectivities - // of multiple columns to get the overall selectivity. - target_boundaries.iter().enumerate().try_fold( - 1.0, - |acc, (i, ExprBoundaries { interval, .. })| { - let temp = - cardinality_ratio(&initial_boundaries[i].interval, interval)?; - Ok(acc * temp) - }, - ) - } - } +) -> f64 { + // Since the intervals are assumed uniform and the values + // are not correlated, we need to multiply the selectivities + // of multiple columns to get the overall selectivity. + initial_boundaries + .iter() + .zip(target_boundaries.iter()) + .fold(1.0, |acc, (initial, target)| { + acc * cardinality_ratio(&initial.interval, &target.interval) + }) } diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 34fbfc3c0269..9665116b04ab 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -17,18 +17,29 @@ //! Array expressions +use std::any::type_name; +use std::collections::HashSet; +use std::fmt::{Display, Formatter}; +use std::sync::Arc; + use arrow::array::*; -use arrow::buffer::{Buffer, OffsetBuffer}; +use arrow::buffer::OffsetBuffer; use arrow::compute; use arrow::datatypes::{DataType, Field, UInt64Type}; +use arrow::row::{RowConverter, SortField}; use arrow_buffer::NullBuffer; -use core::any::type_name; -use datafusion_common::cast::{as_generic_string_array, as_int64_array, as_list_array}; -use datafusion_common::{exec_err, internal_err, not_impl_err, plan_err, ScalarValue}; -use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::ColumnarValue; + +use arrow_schema::{FieldRef, SortOptions}; +use datafusion_common::cast::{ + as_generic_list_array, as_generic_string_array, as_int64_array, as_large_list_array, + as_list_array, as_null_array, as_string_array, +}; +use datafusion_common::utils::{array_into_list_array, list_ndims}; +use datafusion_common::{ + exec_err, internal_err, not_impl_err, plan_err, DataFusionError, Result, +}; + use itertools::Itertools; -use std::sync::Arc; macro_rules! downcast_arg { ($ARG:expr, $ARRAY_TYPE:ident) => {{ @@ -41,84 +52,105 @@ macro_rules! downcast_arg { }}; } -/// Downcasts multiple arguments into a single concrete type -/// $ARGS: &[ArrayRef] -/// $ARRAY_TYPE: type to downcast to +/// Computes a BooleanArray indicating equality or inequality between elements in a list array and a specified element array. /// -/// $returns a Vec<$ARRAY_TYPE> -macro_rules! downcast_vec { - ($ARGS:expr, $ARRAY_TYPE:ident) => {{ - $ARGS - .iter() - .map(|e| match e.as_any().downcast_ref::<$ARRAY_TYPE>() { - Some(array) => Ok(array), - _ => internal_err!("failed to downcast"), - }) - }}; -} - -macro_rules! new_builder { - (BooleanBuilder, $len:expr) => { - BooleanBuilder::with_capacity($len) - }; - (StringBuilder, $len:expr) => { - StringBuilder::new() - }; - (LargeStringBuilder, $len:expr) => { - LargeStringBuilder::new() - }; - ($el:ident, $len:expr) => {{ - <$el>::with_capacity($len) - }}; -} - -/// Combines multiple arrays into a single ListArray +/// # Arguments /// -/// $ARGS: slice of arrays, each with $ARRAY_TYPE -/// $ARRAY_TYPE: the type of the list elements -/// $BUILDER_TYPE: the type of ArrayBuilder for the list elements +/// * `list_array_row` - A reference to a trait object implementing the Arrow `Array` trait. It represents the list array for which the equality or inequality will be compared. /// -/// Returns: a ListArray where the elements each have the same type as -/// $ARRAY_TYPE and each element have a length of $ARGS.len() -macro_rules! array { - ($ARGS:expr, $ARRAY_TYPE:ident, $BUILDER_TYPE:ident) => {{ - let builder = new_builder!($BUILDER_TYPE, $ARGS[0].len()); - let mut builder = - ListBuilder::<$BUILDER_TYPE>::with_capacity(builder, $ARGS.len()); - - let num_rows = $ARGS[0].len(); - assert!( - $ARGS.iter().all(|a| a.len() == num_rows), - "all arguments must have the same number of rows" +/// * `element_array` - A reference to a trait object implementing the Arrow `Array` trait. It represents the array with which each element in the `list_array_row` will be compared. +/// +/// * `row_index` - The index of the row in the `element_array` and `list_array` to use for the comparison. +/// +/// * `eq` - A boolean flag. If `true`, the function computes equality; if `false`, it computes inequality. +/// +/// # Returns +/// +/// Returns a `Result` representing the comparison results. The result may contain an error if there are issues with the computation. +/// +/// # Example +/// +/// ```text +/// compare_element_to_list( +/// [1, 2, 3], [1, 2, 3], 0, true => [true, false, false] +/// [1, 2, 3, 3, 2, 1], [1, 2, 3], 1, true => [false, true, false, false, true, false] +/// +/// [[1, 2, 3], [2, 3, 4], [3, 4, 5]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 0, true => [true, false, false] +/// [[1, 2, 3], [2, 3, 4], [2, 3, 4]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 1, false => [true, false, false] +/// ) +/// ``` +fn compare_element_to_list( + list_array_row: &dyn Array, + element_array: &dyn Array, + row_index: usize, + eq: bool, +) -> Result { + if list_array_row.data_type() != element_array.data_type() { + return exec_err!( + "compare_element_to_list received incompatible types: '{:?}' and '{:?}'.", + list_array_row.data_type(), + element_array.data_type() ); + } + + let indices = UInt32Array::from(vec![row_index as u32]); + let element_array_row = arrow::compute::take(element_array, &indices, None)?; + + // Compute all positions in list_row_array (that is itself an + // array) that are equal to `from_array_row` + let res = match element_array_row.data_type() { + // arrow_ord::cmp::eq does not support ListArray, so we need to compare it by loop + DataType::List(_) => { + // compare each element of the from array + let element_array_row_inner = as_list_array(&element_array_row)?.value(0); + let list_array_row_inner = as_list_array(list_array_row)?; - // for each entry in the array - for index in 0..num_rows { - // for each column - for arg in $ARGS { - match arg.as_any().downcast_ref::<$ARRAY_TYPE>() { - // Copy the source array value into the target ListArray - Some(arr) => { - if arr.is_valid(index) { - builder.values().append_value(arr.value(index)); + list_array_row_inner + .iter() + // compare element by element the current row of list_array + .map(|row| { + row.map(|row| { + if eq { + row.eq(&element_array_row_inner) } else { - builder.values().append_null(); + row.ne(&element_array_row_inner) } - } - None => match arg.as_any().downcast_ref::() { - Some(arr) => { - for _ in 0..arr.len() { - builder.values().append_null(); - } + }) + }) + .collect::() + } + DataType::LargeList(_) => { + // compare each element of the from array + let element_array_row_inner = + as_large_list_array(&element_array_row)?.value(0); + let list_array_row_inner = as_large_list_array(list_array_row)?; + + list_array_row_inner + .iter() + // compare element by element the current row of list_array + .map(|row| { + row.map(|row| { + if eq { + row.eq(&element_array_row_inner) + } else { + row.ne(&element_array_row_inner) } - None => return internal_err!("failed to downcast"), - }, - } + }) + }) + .collect::() + } + _ => { + let element_arr = Scalar::new(element_array_row); + // use not_distinct so we can compare NULL + if eq { + arrow_ord::cmp::not_distinct(&list_array_row, &element_arr)? + } else { + arrow_ord::cmp::distinct(&list_array_row, &element_arr)? } - builder.append(true); } - Arc::new(builder.finish()) - }}; + }; + + Ok(res) } /// Returns the length of a concrete array dimension @@ -152,36 +184,11 @@ fn compute_array_length( value = downcast_arg!(value, ListArray).value(0); current_dimension += 1; } - _ => return Ok(None), - } - } -} - -/// Returns the dimension of the array -fn compute_array_ndims(arr: Option) -> Result> { - Ok(compute_array_ndims_with_datatype(arr)?.0) -} - -/// Returns the dimension and the datatype of elements of the array -fn compute_array_ndims_with_datatype( - arr: Option, -) -> Result<(Option, DataType)> { - let mut res: u64 = 1; - let mut value = match arr { - Some(arr) => arr, - None => return Ok((None, DataType::Null)), - }; - if value.is_empty() { - return Ok((None, DataType::Null)); - } - - loop { - match value.data_type() { - DataType::List(..) => { - value = downcast_arg!(value, ListArray).value(0); - res += 1; + DataType::LargeList(..) => { + value = downcast_arg!(value, LargeListArray).value(0); + current_dimension += 1; } - data_type => return Ok((Some(res), data_type.clone())), + _ => return Ok(None), } } } @@ -210,10 +217,10 @@ fn compute_array_dims(arr: Option) -> Result>>> fn check_datatypes(name: &str, args: &[&ArrayRef]) -> Result<()> { let data_type = args[0].data_type(); - if !args - .iter() - .all(|arg| arg.data_type().equals_datatype(data_type)) - { + if !args.iter().all(|arg| { + arg.data_type().equals_datatype(data_type) + || arg.data_type().equals_datatype(&DataType::Null) + }) { let types = args.iter().map(|arg| arg.data_type()).collect::>(); return plan_err!("{name} received incompatible types: '{types:?}'."); } @@ -261,14 +268,8 @@ macro_rules! call_array_function { }}; } -#[derive(Debug)] -enum ListOrNull<'a> { - List(&'a dyn Array), - Null, -} - /// Convert one or more [`ArrayRef`] of the same type into a -/// `ListArray` +/// `ListArray` or 'LargeListArray' depending on the offset size. /// /// # Example (non nested) /// @@ -307,496 +308,794 @@ enum ListOrNull<'a> { /// └──────────────┘ └──────────────┘ └─────────────────────────────┘ /// col1 col2 output /// ``` -fn array_array(args: &[ArrayRef], data_type: DataType) -> Result { +fn array_array( + args: &[ArrayRef], + data_type: DataType, +) -> Result { // do not accept 0 arguments. if args.is_empty() { return plan_err!("Array requires at least one argument"); } - let res = match data_type { - DataType::List(..) => { - let mut arrays = vec![]; - let mut row_count = 0; + let mut data = vec![]; + let mut total_len = 0; + for arg in args { + let arg_data = if arg.as_any().is::() { + ArrayData::new_empty(&data_type) + } else { + arg.to_data() + }; + total_len += arg_data.len(); + data.push(arg_data); + } - for arg in args { - let list_arr = arg.as_list_opt::(); - if let Some(list_arr) = list_arr { - // Assume number of rows is the same for all arrays - row_count = list_arr.len(); - arrays.push(ListOrNull::List(list_arr)); - } else if arg.as_any().downcast_ref::().is_some() { - arrays.push(ListOrNull::Null); - } else { - return internal_err!("Unsupported argument type for array"); - } - } + let mut offsets: Vec = Vec::with_capacity(total_len); + offsets.push(O::usize_as(0)); - let mut total_capacity = 0; - let mut array_data = vec![]; - for arr in arrays.iter() { - if let ListOrNull::List(arr) = arr { - total_capacity += arr.len(); - array_data.push(arr.to_data()); - } - } - let capacity = Capacities::Array(total_capacity); - let array_data = array_data.iter().collect(); + let capacity = Capacities::Array(total_len); + let data_ref = data.iter().collect::>(); + let mut mutable = MutableArrayData::with_capacities(data_ref, true, capacity); - let mut mutable = - MutableArrayData::with_capacities(array_data, true, capacity); - - for i in 0..row_count { - let mut nulls = 0; - for (j, arr) in arrays.iter().enumerate() { - match arr { - ListOrNull::List(_) => { - mutable.extend(j - nulls, i, i + 1); - } - ListOrNull::Null => { - mutable.extend_nulls(1); - nulls += 1; - } - } - } + let num_rows = args[0].len(); + for row_idx in 0..num_rows { + for (arr_idx, arg) in args.iter().enumerate() { + if !arg.as_any().is::() + && !arg.is_null(row_idx) + && arg.is_valid(row_idx) + { + mutable.extend(arr_idx, row_idx, row_idx + 1); + } else { + mutable.extend_nulls(1); } - - let list_data_type = - DataType::List(Arc::new(Field::new("item", data_type, true))); - - let offsets: Vec = (0..row_count as i32 + 1) - .map(|i| i * arrays.len() as i32) - .collect(); - - let list_data = ArrayData::builder(list_data_type) - .len(row_count) - .buffers(vec![Buffer::from_vec(offsets)]) - .add_child_data(mutable.freeze()) - .build()?; - Arc::new(ListArray::from(list_data)) - } - DataType::Utf8 => array!(args, StringArray, StringBuilder), - DataType::LargeUtf8 => array!(args, LargeStringArray, LargeStringBuilder), - DataType::Boolean => array!(args, BooleanArray, BooleanBuilder), - DataType::Float32 => array!(args, Float32Array, Float32Builder), - DataType::Float64 => array!(args, Float64Array, Float64Builder), - DataType::Int8 => array!(args, Int8Array, Int8Builder), - DataType::Int16 => array!(args, Int16Array, Int16Builder), - DataType::Int32 => array!(args, Int32Array, Int32Builder), - DataType::Int64 => array!(args, Int64Array, Int64Builder), - DataType::UInt8 => array!(args, UInt8Array, UInt8Builder), - DataType::UInt16 => array!(args, UInt16Array, UInt16Builder), - DataType::UInt32 => array!(args, UInt32Array, UInt32Builder), - DataType::UInt64 => array!(args, UInt64Array, UInt64Builder), - data_type => { - return not_impl_err!("Array is not implemented for type '{data_type:?}'.") } - }; + offsets.push(O::usize_as(mutable.len())); + } + let data = mutable.freeze(); - Ok(res) + Ok(Arc::new(GenericListArray::::try_new( + Arc::new(Field::new("item", data_type, true)), + OffsetBuffer::new(offsets.into()), + arrow_array::make_array(data), + None, + )?)) } -/// Convert one or more [`ColumnarValue`] of the same type into a -/// `ListArray` -/// -/// See [`array_array`] for more details. -fn array(values: &[ColumnarValue]) -> Result { - let arrays: Vec = values - .iter() - .map(|x| match x { - ColumnarValue::Array(array) => array.clone(), - ColumnarValue::Scalar(scalar) => scalar.to_array().clone(), - }) - .collect(); - - let mut data_type = None; - for arg in &arrays { +/// `make_array` SQL function +pub fn make_array(arrays: &[ArrayRef]) -> Result { + let mut data_type = DataType::Null; + for arg in arrays { let arg_data_type = arg.data_type(); if !arg_data_type.equals_datatype(&DataType::Null) { - data_type = Some(arg_data_type.clone()); + data_type = arg_data_type.clone(); break; - } else { - data_type = Some(DataType::Null); } } match data_type { - // empty array - None => Ok(ColumnarValue::Scalar(ScalarValue::new_list( - Some(vec![]), - DataType::Null, - ))), - // all nulls, set default data type as int32 - Some(DataType::Null) => { - let nulls = arrays.len(); - let null_arr = Int32Array::from(vec![None; nulls]); - let field = Arc::new(Field::new("item", DataType::Int32, true)); - let offsets = OffsetBuffer::from_lengths([nulls]); - let values = Arc::new(null_arr) as ArrayRef; - let nulls = None; - Ok(ColumnarValue::Array(Arc::new(ListArray::new( - field, offsets, values, nulls, - )))) - } - Some(data_type) => Ok(ColumnarValue::Array(array_array( - arrays.as_slice(), - data_type, - )?)), + // Either an empty array or all nulls: + DataType::Null => { + let array = + new_null_array(&DataType::Null, arrays.iter().map(|a| a.len()).sum()); + Ok(Arc::new(array_into_list_array(array))) + } + DataType::LargeList(..) => array_array::(arrays, data_type), + _ => array_array::(arrays, data_type), } } -/// `make_array` SQL function -pub fn make_array(arrays: &[ArrayRef]) -> Result { - let values: Vec = arrays - .iter() - .map(|x| ColumnarValue::Array(x.clone())) - .collect(); +fn general_array_element( + array: &GenericListArray, + indexes: &Int64Array, +) -> Result +where + i64: TryInto, +{ + let values = array.values(); + let original_data = values.to_data(); + let capacity = Capacities::Array(original_data.len()); + + // use_nulls: true, we don't construct List for array_element, so we need explicit nulls. + let mut mutable = + MutableArrayData::with_capacities(vec![&original_data], true, capacity); + + fn adjusted_array_index(index: i64, len: O) -> Result> + where + i64: TryInto, + { + let index: O = index.try_into().map_err(|_| { + DataFusionError::Execution(format!( + "array_element got invalid index: {}", + index + )) + })?; + // 0 ~ len - 1 + let adjusted_zero_index = if index < O::usize_as(0) { + index + len + } else { + index - O::usize_as(1) + }; + + if O::usize_as(0) <= adjusted_zero_index && adjusted_zero_index < len { + Ok(Some(adjusted_zero_index)) + } else { + // Out of bounds + Ok(None) + } + } + + for (row_index, offset_window) in array.offsets().windows(2).enumerate() { + let start = offset_window[0]; + let end = offset_window[1]; + let len = end - start; + + // array is null + if len == O::usize_as(0) { + mutable.extend_nulls(1); + continue; + } + + let index = adjusted_array_index::(indexes.value(row_index), len)?; + + if let Some(index) = index { + let start = start.as_usize() + index.as_usize(); + mutable.extend(0, start, start + 1_usize); + } else { + // Index out of bounds + mutable.extend_nulls(1); + } + } + + let data = mutable.freeze(); + Ok(arrow_array::make_array(data)) +} + +/// array_element SQL function +/// +/// There are two arguments for array_element, the first one is the array, the second one is the 1-indexed index. +/// `array_element(array, index)` +/// +/// For example: +/// > array_element(\[1, 2, 3], 2) -> 2 +pub fn array_element(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_element needs two arguments"); + } - match array(values.as_slice())? { - ColumnarValue::Array(array) => Ok(array), - ColumnarValue::Scalar(scalar) => Ok(scalar.to_array().clone()), + match &args[0].data_type() { + DataType::List(_) => { + let array = as_list_array(&args[0])?; + let indexes = as_int64_array(&args[1])?; + general_array_element::(array, indexes) + } + DataType::LargeList(_) => { + let array = as_large_list_array(&args[0])?; + let indexes = as_int64_array(&args[1])?; + general_array_element::(array, indexes) + } + _ => exec_err!( + "array_element does not support type: {:?}", + args[0].data_type() + ), } } -fn return_empty(return_null: bool, data_type: DataType) -> Arc { - if return_null { - new_null_array(&data_type, 1) +fn general_except( + l: &GenericListArray, + r: &GenericListArray, + field: &FieldRef, +) -> Result> { + let converter = RowConverter::new(vec![SortField::new(l.value_type())])?; + + let l_values = l.values().to_owned(); + let r_values = r.values().to_owned(); + let l_values = converter.convert_columns(&[l_values])?; + let r_values = converter.convert_columns(&[r_values])?; + + let mut offsets = Vec::::with_capacity(l.len() + 1); + offsets.push(OffsetSize::usize_as(0)); + + let mut rows = Vec::with_capacity(l_values.num_rows()); + let mut dedup = HashSet::new(); + + for (l_w, r_w) in l.offsets().windows(2).zip(r.offsets().windows(2)) { + let l_slice = l_w[0].as_usize()..l_w[1].as_usize(); + let r_slice = r_w[0].as_usize()..r_w[1].as_usize(); + for i in r_slice { + let right_row = r_values.row(i); + dedup.insert(right_row); + } + for i in l_slice { + let left_row = l_values.row(i); + if dedup.insert(left_row) { + rows.push(left_row); + } + } + + offsets.push(OffsetSize::usize_as(rows.len())); + dedup.clear(); + } + + if let Some(values) = converter.convert_rows(rows)?.first() { + Ok(GenericListArray::::new( + field.to_owned(), + OffsetBuffer::new(offsets.into()), + values.to_owned(), + l.nulls().cloned(), + )) } else { - new_empty_array(&data_type) + internal_err!("array_except failed to convert rows") } } -macro_rules! list_slice { - ($ARRAY:expr, $I:expr, $J:expr, $RETURN_ELEMENT:expr, $ARRAY_TYPE:ident) => {{ - let array = $ARRAY.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); - if $I == 0 && $J == 0 || $ARRAY.is_empty() { - return return_empty($RETURN_ELEMENT, $ARRAY.data_type().clone()); +pub fn array_except(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_except needs two arguments"); + } + + let array1 = &args[0]; + let array2 = &args[1]; + + match (array1.data_type(), array2.data_type()) { + (DataType::Null, _) | (_, DataType::Null) => Ok(array1.to_owned()), + (DataType::List(field), DataType::List(_)) => { + check_datatypes("array_except", &[array1, array2])?; + let list1 = array1.as_list::(); + let list2 = array2.as_list::(); + let result = general_except::(list1, list2, field)?; + Ok(Arc::new(result)) } + (DataType::LargeList(field), DataType::LargeList(_)) => { + check_datatypes("array_except", &[array1, array2])?; + let list1 = array1.as_list::(); + let list2 = array2.as_list::(); + let result = general_except::(list1, list2, field)?; + Ok(Arc::new(result)) + } + (dt1, dt2) => { + internal_err!("array_except got unexpected types: {dt1:?} and {dt2:?}") + } + } +} - let i = if $I < 0 { - if $I.abs() as usize > array.len() { - return return_empty(true, $ARRAY.data_type().clone()); - } +/// array_slice SQL function +/// +/// We follow the behavior of array_slice in DuckDB +/// Note that array_slice is 1-indexed. And there are two additional arguments `from` and `to` in array_slice. +/// +/// > array_slice(array, from, to) +/// +/// Positive index is treated as the index from the start of the array. If the +/// `from` index is smaller than 1, it is treated as 1. If the `to` index is larger than the +/// length of the array, it is treated as the length of the array. +/// +/// Negative index is treated as the index from the end of the array. If the index +/// is larger than the length of the array, it is NOT VALID, either in `from` or `to`. +/// The `to` index is exclusive like python slice syntax. +/// +/// See test cases in `array.slt` for more details. +pub fn array_slice(args: &[ArrayRef]) -> Result { + if args.len() != 3 { + return exec_err!("array_slice needs three arguments"); + } - (array.len() as i64 + $I + 1) as usize + let array_data_type = args[0].data_type(); + match array_data_type { + DataType::List(_) => { + let array = as_list_array(&args[0])?; + let from_array = as_int64_array(&args[1])?; + let to_array = as_int64_array(&args[2])?; + general_array_slice::(array, from_array, to_array) + } + DataType::LargeList(_) => { + let array = as_large_list_array(&args[0])?; + let from_array = as_int64_array(&args[1])?; + let to_array = as_int64_array(&args[2])?; + general_array_slice::(array, from_array, to_array) + } + _ => exec_err!("array_slice does not support type: {:?}", array_data_type), + } +} + +fn general_array_slice( + array: &GenericListArray, + from_array: &Int64Array, + to_array: &Int64Array, +) -> Result +where + i64: TryInto, +{ + let values = array.values(); + let original_data = values.to_data(); + let capacity = Capacities::Array(original_data.len()); + + // use_nulls: false, we don't need nulls but empty array for array_slice, so we don't need explicit nulls but adjust offset to indicate nulls. + let mut mutable = + MutableArrayData::with_capacities(vec![&original_data], false, capacity); + + // We have the slice syntax compatible with DuckDB v0.8.1. + // The rule `adjusted_from_index` and `adjusted_to_index` follows the rule of array_slice in duckdb. + + fn adjusted_from_index(index: i64, len: O) -> Result> + where + i64: TryInto, + { + // 0 ~ len - 1 + let adjusted_zero_index = if index < 0 { + if let Ok(index) = index.try_into() { + index + len + } else { + return exec_err!("array_slice got invalid index: {}", index); + } } else { - if $I == 0 { - 1 + // array_slice(arr, 1, to) is the same as array_slice(arr, 0, to) + if let Ok(index) = index.try_into() { + std::cmp::max(index - O::usize_as(1), O::usize_as(0)) } else { - $I as usize + return exec_err!("array_slice got invalid index: {}", index); } }; - let j = if $J < 0 { - if $J.abs() as usize > array.len() { - return return_empty(true, $ARRAY.data_type().clone()); - } - if $RETURN_ELEMENT { - (array.len() as i64 + $J + 1) as usize + if O::usize_as(0) <= adjusted_zero_index && adjusted_zero_index < len { + Ok(Some(adjusted_zero_index)) + } else { + // Out of bounds + Ok(None) + } + } + + fn adjusted_to_index(index: i64, len: O) -> Result> + where + i64: TryInto, + { + // 0 ~ len - 1 + let adjusted_zero_index = if index < 0 { + // array_slice in duckdb with negative to_index is python-like, so index itself is exclusive + if let Ok(index) = index.try_into() { + index + len - O::usize_as(1) } else { - (array.len() as i64 + $J) as usize + return exec_err!("array_slice got invalid index: {}", index); } } else { - if $J == 0 { - 1 + // array_slice(arr, from, len + 1) is the same as array_slice(arr, from, len) + if let Ok(index) = index.try_into() { + std::cmp::min(index - O::usize_as(1), len - O::usize_as(1)) } else { - if $J as usize > array.len() { - array.len() - } else { - $J as usize - } + return exec_err!("array_slice got invalid index: {}", index); } }; - if i > j || i as usize > $ARRAY.len() { - return_empty($RETURN_ELEMENT, $ARRAY.data_type().clone()) + if O::usize_as(0) <= adjusted_zero_index && adjusted_zero_index < len { + Ok(Some(adjusted_zero_index)) } else { - Arc::new(array.slice((i - 1), (j + 1 - i))) + // Out of bounds + Ok(None) } - }}; -} + } -macro_rules! slice { - ($ARRAY:expr, $KEY:expr, $EXTRA_KEY:expr, $RETURN_ELEMENT:expr, $ARRAY_TYPE:ident) => {{ - let sliced_array: Vec> = $ARRAY - .iter() - .zip($KEY.iter()) - .zip($EXTRA_KEY.iter()) - .map(|((arr, i), j)| match (arr, i, j) { - (Some(arr), Some(i), Some(j)) => { - list_slice!(arr, i, j, $RETURN_ELEMENT, $ARRAY_TYPE) - } - (Some(arr), None, Some(j)) => { - list_slice!(arr, 1i64, j, $RETURN_ELEMENT, $ARRAY_TYPE) - } - (Some(arr), Some(i), None) => { - list_slice!(arr, i, arr.len() as i64, $RETURN_ELEMENT, $ARRAY_TYPE) - } - (Some(arr), None, None) if !$RETURN_ELEMENT => arr, - _ => return_empty($RETURN_ELEMENT, $ARRAY.value_type().clone()), - }) - .collect(); + let mut offsets = vec![O::usize_as(0)]; - // concat requires input of at least one array - if sliced_array.is_empty() { - Ok(return_empty($RETURN_ELEMENT, $ARRAY.value_type())) + for (row_index, offset_window) in array.offsets().windows(2).enumerate() { + let start = offset_window[0]; + let end = offset_window[1]; + let len = end - start; + + // len 0 indicate array is null, return empty array in this row. + if len == O::usize_as(0) { + offsets.push(offsets[row_index]); + continue; + } + + // If index is null, we consider it as the minimum / maximum index of the array. + let from_index = if from_array.is_null(row_index) { + Some(O::usize_as(0)) } else { - let vec = sliced_array - .iter() - .map(|a| a.as_ref()) - .collect::>(); - let mut i: i32 = 0; - let mut offsets = vec![i]; - offsets.extend( - vec.iter() - .map(|a| { - i += a.len() as i32; - i - }) - .collect::>(), - ); - let values = compute::concat(vec.as_slice()).unwrap(); + adjusted_from_index::(from_array.value(row_index), len)? + }; + + let to_index = if to_array.is_null(row_index) { + Some(len - O::usize_as(1)) + } else { + adjusted_to_index::(to_array.value(row_index), len)? + }; - if $RETURN_ELEMENT { - Ok(values) + if let (Some(from), Some(to)) = (from_index, to_index) { + if from <= to { + assert!(start + to <= end); + mutable.extend( + 0, + (start + from).to_usize().unwrap(), + (start + to + O::usize_as(1)).to_usize().unwrap(), + ); + offsets.push(offsets[row_index] + (to - from + O::usize_as(1))); } else { - let field = - Arc::new(Field::new("item", $ARRAY.value_type().clone(), true)); - Ok(Arc::new(ListArray::try_new( - field, - OffsetBuffer::new(offsets.into()), - values, - None, - )?)) + // invalid range, return empty array + offsets.push(offsets[row_index]); } + } else { + // invalid range, return empty array + offsets.push(offsets[row_index]); } - }}; -} - -fn define_array_slice( - list_array: &ListArray, - key: &Int64Array, - extra_key: &Int64Array, - return_element: bool, -) -> Result { - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - slice!(list_array, key, extra_key, return_element, $ARRAY_TYPE) - }; } - call_array_function!(list_array.value_type(), true) + + let data = mutable.freeze(); + + Ok(Arc::new(GenericListArray::::try_new( + Arc::new(Field::new("item", array.value_type(), true)), + OffsetBuffer::::new(offsets.into()), + arrow_array::make_array(data), + None, + )?)) } -pub fn array_element(args: &[ArrayRef]) -> Result { - let list_array = as_list_array(&args[0])?; - let key = as_int64_array(&args[1])?; - define_array_slice(list_array, key, key, true) +fn general_pop_front_list( + array: &GenericListArray, +) -> Result +where + i64: TryInto, +{ + let from_array = Int64Array::from(vec![2; array.len()]); + let to_array = Int64Array::from( + array + .iter() + .map(|arr| arr.map_or(0, |arr| arr.len() as i64)) + .collect::>(), + ); + general_array_slice::(array, &from_array, &to_array) +} + +fn general_pop_back_list( + array: &GenericListArray, +) -> Result +where + i64: TryInto, +{ + let from_array = Int64Array::from(vec![1; array.len()]); + let to_array = Int64Array::from( + array + .iter() + .map(|arr| arr.map_or(0, |arr| arr.len() as i64 - 1)) + .collect::>(), + ); + general_array_slice::(array, &from_array, &to_array) } -pub fn array_slice(args: &[ArrayRef]) -> Result { - let list_array = as_list_array(&args[0])?; - let key = as_int64_array(&args[1])?; - let extra_key = as_int64_array(&args[2])?; - define_array_slice(list_array, key, extra_key, false) +/// array_pop_front SQL function +pub fn array_pop_front(args: &[ArrayRef]) -> Result { + let array_data_type = args[0].data_type(); + match array_data_type { + DataType::List(_) => { + let array = as_list_array(&args[0])?; + general_pop_front_list::(array) + } + DataType::LargeList(_) => { + let array = as_large_list_array(&args[0])?; + general_pop_front_list::(array) + } + _ => exec_err!( + "array_pop_front does not support type: {:?}", + array_data_type + ), + } } +/// array_pop_back SQL function pub fn array_pop_back(args: &[ArrayRef]) -> Result { - let list_array = as_list_array(&args[0])?; - let key = vec![0; list_array.len()]; - let extra_key: Vec<_> = list_array - .iter() - .map(|x| x.map_or(0, |arr| arr.len() as i64 - 1)) - .collect(); + if args.len() != 1 { + return exec_err!("array_pop_back needs one argument"); + } - define_array_slice( - list_array, - &Int64Array::from(key), - &Int64Array::from(extra_key), + let array_data_type = args[0].data_type(); + match array_data_type { + DataType::List(_) => { + let array = as_list_array(&args[0])?; + general_pop_back_list::(array) + } + DataType::LargeList(_) => { + let array = as_large_list_array(&args[0])?; + general_pop_back_list::(array) + } + _ => exec_err!( + "array_pop_back does not support type: {:?}", + array_data_type + ), + } +} + +/// Appends or prepends elements to a ListArray. +/// +/// This function takes a ListArray, an ArrayRef, a FieldRef, and a boolean flag +/// indicating whether to append or prepend the elements. It returns a `Result` +/// representing the resulting ListArray after the operation. +/// +/// # Arguments +/// +/// * `list_array` - A reference to the ListArray to which elements will be appended/prepended. +/// * `element_array` - A reference to the Array containing elements to be appended/prepended. +/// * `field` - A reference to the Field describing the data type of the arrays. +/// * `is_append` - A boolean flag indicating whether to append (`true`) or prepend (`false`) elements. +/// +/// # Examples +/// +/// generic_append_and_prepend( +/// [1, 2, 3], 4, append => [1, 2, 3, 4] +/// 5, [6, 7, 8], prepend => [5, 6, 7, 8] +/// ) +fn generic_append_and_prepend( + list_array: &GenericListArray, + element_array: &ArrayRef, + data_type: &DataType, + is_append: bool, +) -> Result +where + i64: TryInto, +{ + let mut offsets = vec![O::usize_as(0)]; + let values = list_array.values(); + let original_data = values.to_data(); + let element_data = element_array.to_data(); + let capacity = Capacities::Array(original_data.len() + element_data.len()); + + let mut mutable = MutableArrayData::with_capacities( + vec![&original_data, &element_data], false, - ) -} - -macro_rules! append { - ($ARRAY:expr, $ELEMENT:expr, $ARRAY_TYPE:ident) => {{ - let mut offsets: Vec = vec![0]; - let mut values = - downcast_arg!(new_empty_array($ELEMENT.data_type()), $ARRAY_TYPE).clone(); - - let element = downcast_arg!($ELEMENT, $ARRAY_TYPE); - for (arr, el) in $ARRAY.iter().zip(element.iter()) { - let last_offset: i32 = offsets.last().copied().ok_or_else(|| { - DataFusionError::Internal(format!("offsets should not be empty")) - })?; - match arr { - Some(arr) => { - let child_array = downcast_arg!(arr, $ARRAY_TYPE); - values = downcast_arg!( - compute::concat(&[ - &values, - child_array, - &$ARRAY_TYPE::from(vec![el]) - ])? - .clone(), - $ARRAY_TYPE - ) - .clone(); - offsets.push(last_offset + child_array.len() as i32 + 1i32); - } - None => { - values = downcast_arg!( - compute::concat(&[ - &values, - &$ARRAY_TYPE::from(vec![el.clone()]) - ])? - .clone(), - $ARRAY_TYPE - ) - .clone(); - offsets.push(last_offset + 1i32); - } - } + capacity, + ); + + let values_index = 0; + let element_index = 1; + + for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { + let start = offset_window[0].to_usize().unwrap(); + let end = offset_window[1].to_usize().unwrap(); + if is_append { + mutable.extend(values_index, start, end); + mutable.extend(element_index, row_index, row_index + 1); + } else { + mutable.extend(element_index, row_index, row_index + 1); + mutable.extend(values_index, start, end); } + offsets.push(offsets[row_index] + O::usize_as(end - start + 1)); + } - let field = Arc::new(Field::new("item", $ELEMENT.data_type().clone(), true)); + let data = mutable.freeze(); - Arc::new(ListArray::try_new( - field, - OffsetBuffer::new(offsets.into()), - Arc::new(values), - None, - )?) - }}; + Ok(Arc::new(GenericListArray::::try_new( + Arc::new(Field::new("item", data_type.to_owned(), true)), + OffsetBuffer::new(offsets.into()), + arrow_array::make_array(data), + None, + )?)) } -/// Array_append SQL function -pub fn array_append(args: &[ArrayRef]) -> Result { - let arr = as_list_array(&args[0])?; - let element = &args[1]; +/// Generates an array of integers from start to stop with a given step. +/// +/// This function takes 1 to 3 ArrayRefs as arguments, representing start, stop, and step values. +/// It returns a `Result` representing the resulting ListArray after the operation. +/// +/// # Arguments +/// +/// * `args` - An array of 1 to 3 ArrayRefs representing start, stop, and step(step value can not be zero.) values. +/// +/// # Examples +/// +/// gen_range(3) => [0, 1, 2] +/// gen_range(1, 4) => [1, 2, 3] +/// gen_range(1, 7, 2) => [1, 3, 5] +pub fn gen_range(args: &[ArrayRef]) -> Result { + let (start_array, stop_array, step_array) = match args.len() { + 1 => (None, as_int64_array(&args[0])?, None), + 2 => ( + Some(as_int64_array(&args[0])?), + as_int64_array(&args[1])?, + None, + ), + 3 => ( + Some(as_int64_array(&args[0])?), + as_int64_array(&args[1])?, + Some(as_int64_array(&args[2])?), + ), + _ => return exec_err!("gen_range expects 1 to 3 arguments"), + }; - check_datatypes("array_append", &[arr.values(), element])?; - let res = match arr.value_type() { - DataType::List(_) => concat_internal(args)?, - DataType::Null => { - return Ok(array(&[ColumnarValue::Array(args[1].clone())])?.into_array(1)) + let mut values = vec![]; + let mut offsets = vec![0]; + for (idx, stop) in stop_array.iter().enumerate() { + let stop = stop.unwrap_or(0); + let start = start_array.as_ref().map(|arr| arr.value(idx)).unwrap_or(0); + let step = step_array.as_ref().map(|arr| arr.value(idx)).unwrap_or(1); + if step == 0 { + return exec_err!("step can't be 0 for function range(start [, stop, step]"); } - data_type => { - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - append!(arr, element, $ARRAY_TYPE) - }; - } - call_array_function!(data_type, false) + if step < 0 { + // Decreasing range + values.extend((stop + 1..start + 1).rev().step_by((-step) as usize)); + } else { + // Increasing range + values.extend((start..stop).step_by(step as usize)); } - }; - Ok(res) + offsets.push(values.len() as i32); + } + let arr = Arc::new(ListArray::try_new( + Arc::new(Field::new("item", DataType::Int64, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(Int64Array::from(values)), + None, + )?); + Ok(arr) } -macro_rules! prepend { - ($ARRAY:expr, $ELEMENT:expr, $ARRAY_TYPE:ident) => {{ - let mut offsets: Vec = vec![0]; - let mut values = - downcast_arg!(new_empty_array($ELEMENT.data_type()), $ARRAY_TYPE).clone(); - - let element = downcast_arg!($ELEMENT, $ARRAY_TYPE); - for (arr, el) in $ARRAY.iter().zip(element.iter()) { - let last_offset: i32 = offsets.last().copied().ok_or_else(|| { - DataFusionError::Internal(format!("offsets should not be empty")) - })?; - match arr { - Some(arr) => { - let child_array = downcast_arg!(arr, $ARRAY_TYPE); - values = downcast_arg!( - compute::concat(&[ - &values, - &$ARRAY_TYPE::from(vec![el]), - child_array - ])? - .clone(), - $ARRAY_TYPE - ) - .clone(); - offsets.push(last_offset + child_array.len() as i32 + 1i32); - } - None => { - values = downcast_arg!( - compute::concat(&[ - &values, - &$ARRAY_TYPE::from(vec![el.clone()]) - ])? - .clone(), - $ARRAY_TYPE - ) - .clone(); - offsets.push(last_offset + 1i32); - } - } +/// Array_sort SQL function +pub fn array_sort(args: &[ArrayRef]) -> Result { + if args.is_empty() || args.len() > 3 { + return exec_err!("array_sort expects one to three arguments"); + } + + let sort_option = match args.len() { + 1 => None, + 2 => { + let sort = as_string_array(&args[1])?.value(0); + Some(SortOptions { + descending: order_desc(sort)?, + nulls_first: true, + }) + } + 3 => { + let sort = as_string_array(&args[1])?.value(0); + let nulls_first = as_string_array(&args[2])?.value(0); + Some(SortOptions { + descending: order_desc(sort)?, + nulls_first: order_nulls_first(nulls_first)?, + }) } + _ => return exec_err!("array_sort expects 1 to 3 arguments"), + }; - let field = Arc::new(Field::new("item", $ELEMENT.data_type().clone(), true)); + let list_array = as_list_array(&args[0])?; + let row_count = list_array.len(); + + let mut array_lengths = vec![]; + let mut arrays = vec![]; + let mut valid = BooleanBufferBuilder::new(row_count); + for i in 0..row_count { + if list_array.is_null(i) { + array_lengths.push(0); + valid.append(false); + } else { + let arr_ref = list_array.value(i); + let arr_ref = arr_ref.as_ref(); - Arc::new(ListArray::try_new( - field, - OffsetBuffer::new(offsets.into()), - Arc::new(values), - None, - )?) - }}; -} + let sorted_array = compute::sort(arr_ref, sort_option)?; + array_lengths.push(sorted_array.len()); + arrays.push(sorted_array); + valid.append(true); + } + } -/// Array_prepend SQL function -pub fn array_prepend(args: &[ArrayRef]) -> Result { - let element = &args[0]; - let arr = as_list_array(&args[1])?; + // Assume all arrays have the same data type + let data_type = list_array.value_type(); + let buffer = valid.finish(); + + let elements = arrays + .iter() + .map(|a| a.as_ref()) + .collect::>(); + + let list_arr = ListArray::new( + Arc::new(Field::new("item", data_type, true)), + OffsetBuffer::from_lengths(array_lengths), + Arc::new(compute::concat(elements.as_slice())?), + Some(NullBuffer::new(buffer)), + ); + Ok(Arc::new(list_arr)) +} + +fn order_desc(modifier: &str) -> Result { + match modifier.to_uppercase().as_str() { + "DESC" => Ok(true), + "ASC" => Ok(false), + _ => exec_err!("the second parameter of array_sort expects DESC or ASC"), + } +} + +fn order_nulls_first(modifier: &str) -> Result { + match modifier.to_uppercase().as_str() { + "NULLS FIRST" => Ok(true), + "NULLS LAST" => Ok(false), + _ => exec_err!( + "the third parameter of array_sort expects NULLS FIRST or NULLS LAST" + ), + } +} + +fn general_append_and_prepend( + args: &[ArrayRef], + is_append: bool, +) -> Result +where + i64: TryInto, +{ + let (list_array, element_array) = if is_append { + let list_array = as_generic_list_array::(&args[0])?; + let element_array = &args[1]; + check_datatypes("array_append", &[element_array, list_array.values()])?; + (list_array, element_array) + } else { + let list_array = as_generic_list_array::(&args[1])?; + let element_array = &args[0]; + check_datatypes("array_prepend", &[list_array.values(), element_array])?; + (list_array, element_array) + }; - check_datatypes("array_prepend", &[element, arr.values()])?; - let res = match arr.value_type() { - DataType::List(_) => concat_internal(args)?, + let res = match list_array.value_type() { + DataType::List(_) => concat_internal::(args)?, + DataType::LargeList(_) => concat_internal::(args)?, DataType::Null => { - return Ok(array(&[ColumnarValue::Array(args[0].clone())])?.into_array(1)) + return make_array(&[ + list_array.values().to_owned(), + element_array.to_owned(), + ]); } data_type => { - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - prepend!(arr, element, $ARRAY_TYPE) - }; - } - call_array_function!(data_type, false) + return generic_append_and_prepend::( + list_array, + element_array, + &data_type, + is_append, + ); } }; Ok(res) } +/// Array_append SQL function +pub fn array_append(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_append expects two arguments"); + } + + match args[0].data_type() { + DataType::LargeList(_) => general_append_and_prepend::(args, true), + _ => general_append_and_prepend::(args, true), + } +} + +/// Array_prepend SQL function +pub fn array_prepend(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_prepend expects two arguments"); + } + + match args[1].data_type() { + DataType::LargeList(_) => general_append_and_prepend::(args, false), + _ => general_append_and_prepend::(args, false), + } +} + fn align_array_dimensions(args: Vec) -> Result> { - // Find the maximum number of dimensions - let max_ndim: u64 = (*args + let args_ndim = args .iter() - .map(|arr| compute_array_ndims(Some(arr.clone()))) - .collect::>>>()? - .iter() - .max() - .unwrap()) - .unwrap(); + .map(|arg| datafusion_common::utils::list_ndims(arg.data_type())) + .collect::>(); + let max_ndim = args_ndim.iter().max().unwrap_or(&0); // Align the dimensions of the arrays let aligned_args: Result> = args .into_iter() - .map(|array| { - let ndim = compute_array_ndims(Some(array.clone()))?.unwrap(); + .zip(args_ndim.iter()) + .map(|(array, ndim)| { if ndim < max_ndim { let mut aligned_array = array.clone(); for _ in 0..(max_ndim - ndim) { - let data_type = aligned_array.as_ref().data_type().clone(); - let offsets: Vec = - (0..downcast_arg!(aligned_array, ListArray).offsets().len()) - .map(|i| i as i32) - .collect(); - let field = Arc::new(Field::new("item", data_type, true)); + let data_type = aligned_array.data_type().to_owned(); + let array_lengths = vec![1; aligned_array.len()]; + let offsets = OffsetBuffer::::from_lengths(array_lengths); aligned_array = Arc::new(ListArray::try_new( - field, - OffsetBuffer::new(offsets.into()), - Arc::new(aligned_array.clone()), + Arc::new(Field::new("item", data_type, true)), + offsets, + aligned_array, None, )?) } @@ -810,201 +1109,111 @@ fn align_array_dimensions(args: Vec) -> Result> { aligned_args } -fn concat_internal(args: &[ArrayRef]) -> Result { +// Concatenate arrays on the same row. +fn concat_internal(args: &[ArrayRef]) -> Result { let args = align_array_dimensions(args.to_vec())?; - let list_arrays = - downcast_vec!(args, ListArray).collect::>>()?; + let list_arrays = args + .iter() + .map(|arg| as_generic_list_array::(arg)) + .collect::>>()?; // Assume number of rows is the same for all arrays let row_count = list_arrays[0].len(); - let capacity = Capacities::Array(list_arrays.iter().map(|a| a.len()).sum()); - let array_data: Vec<_> = list_arrays.iter().map(|a| a.to_data()).collect::>(); - let array_data: Vec<&ArrayData> = array_data.iter().collect(); - - let mut mutable = MutableArrayData::with_capacities(array_data, true, capacity); - let mut array_lens = vec![0; row_count]; - let mut null_bit_map: Vec = vec![true; row_count]; + let mut array_lengths = vec![]; + let mut arrays = vec![]; + let mut valid = BooleanBufferBuilder::new(row_count); + for i in 0..row_count { + let nulls = list_arrays + .iter() + .map(|arr| arr.is_null(i)) + .collect::>(); + + // If all the arrays are null, the concatenated array is null + let is_null = nulls.iter().all(|&x| x); + if is_null { + array_lengths.push(0); + valid.append(false); + } else { + // Get all the arrays on i-th row + let values = list_arrays + .iter() + .map(|arr| arr.value(i)) + .collect::>(); - for (i, array_len) in array_lens.iter_mut().enumerate().take(row_count) { - let null_count = mutable.null_count(); - for (j, a) in list_arrays.iter().enumerate() { - mutable.extend(j, i, i + 1); - *array_len += a.value_length(i); - } + let elements = values + .iter() + .map(|a| a.as_ref()) + .collect::>(); - // This means all arrays are null - if mutable.null_count() == null_count + list_arrays.len() { - null_bit_map[i] = false; + // Concatenated array on i-th row + let concated_array = compute::concat(elements.as_slice())?; + array_lengths.push(concated_array.len()); + arrays.push(concated_array); + valid.append(true); } } + // Assume all arrays have the same data type + let data_type = list_arrays[0].value_type(); + let buffer = valid.finish(); - let mut buffer = BooleanBufferBuilder::new(row_count); - buffer.append_slice(null_bit_map.as_slice()); - let nulls = Some(NullBuffer::from(buffer.finish())); - - let offsets: Vec = std::iter::once(0) - .chain(array_lens.iter().scan(0, |state, &x| { - *state += x; - Some(*state) - })) - .collect(); - - let builder = mutable.into_builder(); + let elements = arrays + .iter() + .map(|a| a.as_ref()) + .collect::>(); - let list = builder - .len(row_count) - .buffers(vec![Buffer::from_vec(offsets)]) - .nulls(nulls) - .build()?; + let list_arr = GenericListArray::::new( + Arc::new(Field::new("item", data_type, true)), + OffsetBuffer::from_lengths(array_lengths), + Arc::new(compute::concat(elements.as_slice())?), + Some(NullBuffer::new(buffer)), + ); - let list = arrow::array::make_array(list); - Ok(Arc::new(list)) + Ok(Arc::new(list_arr)) } /// Array_concat/Array_cat SQL function pub fn array_concat(args: &[ArrayRef]) -> Result { + if args.is_empty() { + return exec_err!("array_concat expects at least one arguments"); + } + let mut new_args = vec![]; for arg in args { - let (ndim, lower_data_type) = - compute_array_ndims_with_datatype(Some(arg.clone()))?; - if ndim.is_none() || ndim == Some(1) { - return not_impl_err!("Array is not type '{lower_data_type:?}'."); - } else if !lower_data_type.equals_datatype(&DataType::Null) { + let ndim = list_ndims(arg.data_type()); + let base_type = datafusion_common::utils::base_type(arg.data_type()); + if ndim == 0 { + return not_impl_err!("Array is not type '{base_type:?}'."); + } else if !base_type.eq(&DataType::Null) { new_args.push(arg.clone()); } } - concat_internal(new_args.as_slice()) -} - -macro_rules! general_repeat { - ($ELEMENT:expr, $COUNT:expr, $ARRAY_TYPE:ident) => {{ - let mut offsets: Vec = vec![0]; - let mut values = - downcast_arg!(new_empty_array($ELEMENT.data_type()), $ARRAY_TYPE).clone(); - - let element_array = downcast_arg!($ELEMENT, $ARRAY_TYPE); - for (el, c) in element_array.iter().zip($COUNT.iter()) { - let last_offset: i32 = offsets.last().copied().ok_or_else(|| { - DataFusionError::Internal(format!("offsets should not be empty")) - })?; - match el { - Some(el) => { - let c = if c < Some(0) { 0 } else { c.unwrap() } as usize; - let repeated_array = - [Some(el.clone())].repeat(c).iter().collect::<$ARRAY_TYPE>(); - - values = downcast_arg!( - compute::concat(&[&values, &repeated_array])?.clone(), - $ARRAY_TYPE - ) - .clone(); - offsets.push(last_offset + repeated_array.len() as i32); - } - None => { - offsets.push(last_offset); - } - } - } - - let field = Arc::new(Field::new("item", $ELEMENT.data_type().clone(), true)); - - Arc::new(ListArray::try_new( - field, - OffsetBuffer::new(offsets.into()), - Arc::new(values), - None, - )?) - }}; -} - -macro_rules! general_repeat_list { - ($ELEMENT:expr, $COUNT:expr, $ARRAY_TYPE:ident) => {{ - let mut offsets: Vec = vec![0]; - let mut values = - downcast_arg!(new_empty_array($ELEMENT.data_type()), ListArray).clone(); - - let element_array = downcast_arg!($ELEMENT, ListArray); - for (el, c) in element_array.iter().zip($COUNT.iter()) { - let last_offset: i32 = offsets.last().copied().ok_or_else(|| { - DataFusionError::Internal(format!("offsets should not be empty")) - })?; - match el { - Some(el) => { - let c = if c < Some(0) { 0 } else { c.unwrap() } as usize; - let repeated_vec = vec![el; c]; - - let mut i: i32 = 0; - let mut repeated_offsets = vec![i]; - repeated_offsets.extend( - repeated_vec - .clone() - .into_iter() - .map(|a| { - i += a.len() as i32; - i - }) - .collect::>(), - ); - - let mut repeated_values = downcast_arg!( - new_empty_array(&element_array.value_type()), - $ARRAY_TYPE - ) - .clone(); - for repeated_list in repeated_vec { - repeated_values = downcast_arg!( - compute::concat(&[&repeated_values, &repeated_list])?, - $ARRAY_TYPE - ) - .clone(); - } - - let field = Arc::new(Field::new( - "item", - element_array.value_type().clone(), - true, - )); - let repeated_array = ListArray::try_new( - field, - OffsetBuffer::new(repeated_offsets.clone().into()), - Arc::new(repeated_values), - None, - )?; - - values = downcast_arg!( - compute::concat(&[&values, &repeated_array,])?.clone(), - ListArray - ) - .clone(); - offsets.push(last_offset + repeated_array.len() as i32); - } - None => { - offsets.push(last_offset); - } - } - } - - let field = Arc::new(Field::new("item", $ELEMENT.data_type().clone(), true)); - - Arc::new(ListArray::try_new( - field, - OffsetBuffer::new(offsets.into()), - Arc::new(values), - None, - )?) - }}; + concat_internal::(new_args.as_slice()) } /// Array_empty SQL function pub fn array_empty(args: &[ArrayRef]) -> Result { - if args[0].as_any().downcast_ref::().is_some() { - return Ok(args[0].clone()); + if args.len() != 1 { + return exec_err!("array_empty expects one argument"); } - let array = as_list_array(&args[0])?; + if as_null_array(&args[0]).is_ok() { + // Make sure to return Boolean type. + return Ok(Arc::new(BooleanArray::new_null(args[0].len()))); + } + let array_type = args[0].data_type(); + + match array_type { + DataType::List(_) => array_empty_dispatch::(&args[0]), + DataType::LargeList(_) => array_empty_dispatch::(&args[0]), + _ => exec_err!("array_empty does not support type '{array_type:?}'."), + } +} + +fn array_empty_dispatch(array: &ArrayRef) -> Result { + let array = as_generic_list_array::(array)?; let builder = array .iter() .map(|arr| arr.map(|arr| arr.len() == arr.null_count())) @@ -1014,461 +1223,591 @@ pub fn array_empty(args: &[ArrayRef]) -> Result { /// Array_repeat SQL function pub fn array_repeat(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_repeat expects two arguments"); + } + let element = &args[0]; - let count = as_int64_array(&args[1])?; + let count_array = as_int64_array(&args[1])?; - let res = match element.data_type() { - DataType::List(field) => { - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - general_repeat_list!(element, count, $ARRAY_TYPE) - }; - } - call_array_function!(field.data_type(), true) + match element.data_type() { + DataType::List(_) => { + let list_array = as_list_array(element)?; + general_list_repeat::(list_array, count_array) } - data_type => { - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - general_repeat!(element, count, $ARRAY_TYPE) - }; - } - call_array_function!(data_type, false) + DataType::LargeList(_) => { + let list_array = as_large_list_array(element)?; + general_list_repeat::(list_array, count_array) } - }; + _ => general_repeat(element, count_array), + } +} - Ok(res) +/// For each element of `array[i]` repeat `count_array[i]` times. +/// +/// Assumption for the input: +/// 1. `count[i] >= 0` +/// 2. `array.len() == count_array.len()` +/// +/// For example, +/// ```text +/// array_repeat( +/// [1, 2, 3], [2, 0, 1] => [[1, 1], [], [3]] +/// ) +/// ``` +fn general_repeat(array: &ArrayRef, count_array: &Int64Array) -> Result { + let data_type = array.data_type(); + let mut new_values = vec![]; + + let count_vec = count_array + .values() + .to_vec() + .iter() + .map(|x| *x as usize) + .collect::>(); + + for (row_index, &count) in count_vec.iter().enumerate() { + let repeated_array = if array.is_null(row_index) { + new_null_array(data_type, count) + } else { + let original_data = array.to_data(); + let capacity = Capacities::Array(count); + let mut mutable = + MutableArrayData::with_capacities(vec![&original_data], false, capacity); + + for _ in 0..count { + mutable.extend(0, row_index, row_index + 1); + } + + let data = mutable.freeze(); + arrow_array::make_array(data) + }; + new_values.push(repeated_array); + } + + let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect(); + let values = compute::concat(&new_values)?; + + Ok(Arc::new(ListArray::try_new( + Arc::new(Field::new("item", data_type.to_owned(), true)), + OffsetBuffer::from_lengths(count_vec), + values, + None, + )?)) } -macro_rules! position { - ($ARRAY:expr, $ELEMENT:expr, $INDEX:expr, $ARRAY_TYPE:ident) => {{ - let element = downcast_arg!($ELEMENT, $ARRAY_TYPE); - $ARRAY - .iter() - .zip(element.iter()) - .zip($INDEX.iter()) - .map(|((arr, el), i)| { - let index = match i { - Some(i) => { - if i <= 0 { - 0 - } else { - i - 1 - } - } - None => return exec_err!("initial position must not be null"), - }; +/// Handle List version of `general_repeat` +/// +/// For each element of `list_array[i]` repeat `count_array[i]` times. +/// +/// For example, +/// ```text +/// array_repeat( +/// [[1, 2, 3], [4, 5], [6]], [2, 0, 1] => [[[1, 2, 3], [1, 2, 3]], [], [[6]]] +/// ) +/// ``` +fn general_list_repeat( + list_array: &GenericListArray, + count_array: &Int64Array, +) -> Result { + let data_type = list_array.data_type(); + let value_type = list_array.value_type(); + let mut new_values = vec![]; - match arr { - Some(arr) => { - let child_array = downcast_arg!(arr, $ARRAY_TYPE); - - match child_array - .iter() - .skip(index as usize) - .position(|x| x == el) - { - Some(value) => Ok(Some(value as u64 + index as u64 + 1u64)), - None => Ok(None), - } - } - None => Ok(None), + let count_vec = count_array + .values() + .to_vec() + .iter() + .map(|x| *x as usize) + .collect::>(); + + for (list_array_row, &count) in list_array.iter().zip(count_vec.iter()) { + let list_arr = match list_array_row { + Some(list_array_row) => { + let original_data = list_array_row.to_data(); + let capacity = Capacities::Array(original_data.len() * count); + let mut mutable = MutableArrayData::with_capacities( + vec![&original_data], + false, + capacity, + ); + + for _ in 0..count { + mutable.extend(0, 0, original_data.len()); } - }) - .collect::>()? - }}; + + let data = mutable.freeze(); + let repeated_array = arrow_array::make_array(data); + + let list_arr = GenericListArray::::try_new( + Arc::new(Field::new("item", value_type.clone(), true)), + OffsetBuffer::::from_lengths(vec![original_data.len(); count]), + repeated_array, + None, + )?; + Arc::new(list_arr) as ArrayRef + } + None => new_null_array(data_type, count), + }; + new_values.push(list_arr); + } + + let lengths = new_values.iter().map(|a| a.len()).collect::>(); + let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect(); + let values = compute::concat(&new_values)?; + + Ok(Arc::new(ListArray::try_new( + Arc::new(Field::new("item", data_type.to_owned(), true)), + OffsetBuffer::::from_lengths(lengths), + values, + None, + )?)) } /// Array_position SQL function pub fn array_position(args: &[ArrayRef]) -> Result { - let arr = as_list_array(&args[0])?; - let element = &args[1]; + if args.len() < 2 || args.len() > 3 { + return exec_err!("array_position expects two or three arguments"); + } + match &args[0].data_type() { + DataType::List(_) => general_position_dispatch::(args), + DataType::LargeList(_) => general_position_dispatch::(args), + array_type => exec_err!("array_position does not support type '{array_type:?}'."), + } +} +fn general_position_dispatch(args: &[ArrayRef]) -> Result { + let list_array = as_generic_list_array::(&args[0])?; + let element_array = &args[1]; - let index = if args.len() == 3 { - as_int64_array(&args[2])?.clone() + check_datatypes("array_position", &[list_array.values(), element_array])?; + + let arr_from = if args.len() == 3 { + as_int64_array(&args[2])? + .values() + .to_vec() + .iter() + .map(|&x| x - 1) + .collect::>() } else { - Int64Array::from_value(0, arr.len()) + vec![0; list_array.len()] }; - check_datatypes("array_position", &[arr.values(), element])?; - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - position!(arr, element, index, $ARRAY_TYPE) - }; + // if `start_from` index is out of bounds, return error + for (arr, &from) in list_array.iter().zip(arr_from.iter()) { + if let Some(arr) = arr { + if from < 0 || from as usize >= arr.len() { + return internal_err!("start_from index out of bounds"); + } + } else { + // We will get null if we got null in the array, so we don't need to check + } } - let res = call_array_function!(arr.value_type(), true); - Ok(Arc::new(res)) + generic_position::(list_array, element_array, arr_from) } -macro_rules! positions { - ($ARRAY:expr, $ELEMENT:expr, $ARRAY_TYPE:ident) => {{ - let element = downcast_arg!($ELEMENT, $ARRAY_TYPE); - let mut offsets: Vec = vec![0]; - let mut values = - downcast_arg!(new_empty_array(&DataType::UInt64), UInt64Array).clone(); - for comp in $ARRAY - .iter() - .zip(element.iter()) - .map(|(arr, el)| match arr { - Some(arr) => { - let child_array = downcast_arg!(arr, $ARRAY_TYPE); - let res = child_array - .iter() - .enumerate() - .filter(|(_, x)| *x == el) - .flat_map(|(i, _)| Some((i + 1) as u64)) - .collect::(); +fn generic_position( + list_array: &GenericListArray, + element_array: &ArrayRef, + arr_from: Vec, // 0-indexed +) -> Result { + let mut data = Vec::with_capacity(list_array.len()); - Ok(res) - } - None => Ok(downcast_arg!( - new_empty_array(&DataType::UInt64), - UInt64Array - ) - .clone()), - }) - .collect::>>()? - { - let last_offset: i32 = offsets.last().copied().ok_or_else(|| { - DataFusionError::Internal(format!("offsets should not be empty",)) - })?; - values = - downcast_arg!(compute::concat(&[&values, &comp,])?.clone(), UInt64Array) - .clone(); - offsets.push(last_offset + comp.len() as i32); - } + for (row_index, (list_array_row, &from)) in + list_array.iter().zip(arr_from.iter()).enumerate() + { + let from = from as usize; - let field = Arc::new(Field::new("item", DataType::UInt64, true)); + if let Some(list_array_row) = list_array_row { + let eq_array = + compare_element_to_list(&list_array_row, element_array, row_index, true)?; - Arc::new(ListArray::try_new( - field, - OffsetBuffer::new(offsets.into()), - Arc::new(values), - None, - )?) - }}; + // Collect `true`s in 1-indexed positions + let index = eq_array + .iter() + .skip(from) + .position(|e| e == Some(true)) + .map(|index| (from + index + 1) as u64); + + data.push(index); + } else { + data.push(None); + } + } + + Ok(Arc::new(UInt64Array::from(data))) } /// Array_positions SQL function pub fn array_positions(args: &[ArrayRef]) -> Result { - let arr = as_list_array(&args[0])?; + if args.len() != 2 { + return exec_err!("array_positions expects two arguments"); + } + let element = &args[1]; - check_datatypes("array_positions", &[arr.values(), element])?; - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - positions!(arr, element, $ARRAY_TYPE) - }; + match &args[0].data_type() { + DataType::List(_) => { + let arr = as_list_array(&args[0])?; + check_datatypes("array_positions", &[arr.values(), element])?; + general_positions::(arr, element) + } + DataType::LargeList(_) => { + let arr = as_large_list_array(&args[0])?; + check_datatypes("array_positions", &[arr.values(), element])?; + general_positions::(arr, element) + } + array_type => { + exec_err!("array_positions does not support type '{array_type:?}'.") + } } - let res = call_array_function!(arr.value_type(), true); +} - Ok(res) +fn general_positions( + list_array: &GenericListArray, + element_array: &ArrayRef, +) -> Result { + let mut data = Vec::with_capacity(list_array.len()); + + for (row_index, list_array_row) in list_array.iter().enumerate() { + if let Some(list_array_row) = list_array_row { + let eq_array = + compare_element_to_list(&list_array_row, element_array, row_index, true)?; + + // Collect `true`s in 1-indexed positions + let indexes = eq_array + .iter() + .positions(|e| e == Some(true)) + .map(|index| Some(index as u64 + 1)) + .collect::>(); + + data.push(Some(indexes)); + } else { + data.push(None); + } + } + + Ok(Arc::new( + ListArray::from_iter_primitive::(data), + )) } -macro_rules! general_remove { - ($ARRAY:expr, $ELEMENT:expr, $MAX:expr, $ARRAY_TYPE:ident) => {{ - let mut offsets: Vec = vec![0]; - let mut values = - downcast_arg!(new_empty_array($ELEMENT.data_type()), $ARRAY_TYPE).clone(); - - let element = downcast_arg!($ELEMENT, $ARRAY_TYPE); - for ((arr, el), max) in $ARRAY.iter().zip(element.iter()).zip($MAX.iter()) { - let last_offset: i32 = offsets.last().copied().ok_or_else(|| { - DataFusionError::Internal(format!("offsets should not be empty")) - })?; - match arr { - Some(arr) => { - let child_array = downcast_arg!(arr, $ARRAY_TYPE); - let mut counter = 0; - let max = if max < Some(1) { 1 } else { max.unwrap() }; - - let filter_array = child_array +/// For each element of `list_array[i]`, removed up to `arr_n[i]` occurences +/// of `element_array[i]`. +/// +/// The type of each **element** in `list_array` must be the same as the type of +/// `element_array`. This function also handles nested arrays +/// ([`ListArray`] of [`ListArray`]s) +/// +/// For example, when called to remove a list array (where each element is a +/// list of int32s, the second argument are int32 arrays, and the +/// third argument is the number of occurrences to remove +/// +/// ```text +/// general_remove( +/// [1, 2, 3, 2], 2, 1 ==> [1, 3, 2] (only the first 2 is removed) +/// [4, 5, 6, 5], 5, 2 ==> [4, 6] (both 5s are removed) +/// ) +/// ``` +fn general_remove( + list_array: &GenericListArray, + element_array: &ArrayRef, + arr_n: Vec, +) -> Result { + let data_type = list_array.value_type(); + let mut new_values = vec![]; + // Build up the offsets for the final output array + let mut offsets = Vec::::with_capacity(arr_n.len() + 1); + offsets.push(OffsetSize::zero()); + + // n is the number of elements to remove in this row + for (row_index, (list_array_row, n)) in + list_array.iter().zip(arr_n.iter()).enumerate() + { + match list_array_row { + Some(list_array_row) => { + let eq_array = compare_element_to_list( + &list_array_row, + element_array, + row_index, + false, + )?; + + // We need to keep at most first n elements as `false`, which represent the elements to remove. + let eq_array = if eq_array.false_count() < *n as usize { + eq_array + } else { + let mut count = 0; + eq_array .iter() - .map(|element| { - if counter != max && element == el { - counter += 1; - Some(false) + .map(|e| { + // Keep first n `false` elements, and reverse other elements to `true`. + if let Some(false) = e { + if count < *n { + count += 1; + e + } else { + Some(true) + } } else { - Some(true) + e } }) - .collect::(); - - let filtered_array = compute::filter(&child_array, &filter_array)?; - values = downcast_arg!( - compute::concat(&[&values, &filtered_array,])?.clone(), - $ARRAY_TYPE - ) - .clone(); - offsets.push(last_offset + filtered_array.len() as i32); - } - None => offsets.push(last_offset), + .collect::() + }; + + let filtered_array = arrow::compute::filter(&list_array_row, &eq_array)?; + offsets.push( + offsets[row_index] + OffsetSize::usize_as(filtered_array.len()), + ); + new_values.push(filtered_array); + } + None => { + // Null element results in a null row (no new offsets) + offsets.push(offsets[row_index]); } } + } - let field = Arc::new(Field::new("item", $ELEMENT.data_type().clone(), true)); + let values = if new_values.is_empty() { + new_empty_array(&data_type) + } else { + let new_values = new_values.iter().map(|x| x.as_ref()).collect::>(); + arrow::compute::concat(&new_values)? + }; - Arc::new(ListArray::try_new( - field, - OffsetBuffer::new(offsets.into()), - Arc::new(values), - None, - )?) - }}; + Ok(Arc::new(GenericListArray::::try_new( + Arc::new(Field::new("item", data_type, true)), + OffsetBuffer::new(offsets.into()), + values, + list_array.nulls().cloned(), + )?)) } -macro_rules! array_removement_function { - ($FUNC:ident, $MAX_FUNC:expr, $DOC:expr) => { - #[doc = $DOC] - pub fn $FUNC(args: &[ArrayRef]) -> Result { - let arr = as_list_array(&args[0])?; - let element = &args[1]; - let max = $MAX_FUNC(args)?; - - check_datatypes(stringify!($FUNC), &[arr.values(), element])?; - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - general_remove!(arr, element, max, $ARRAY_TYPE) - }; - } - let res = call_array_function!(arr.value_type(), true); - - Ok(res) +fn array_remove_internal( + array: &ArrayRef, + element_array: &ArrayRef, + arr_n: Vec, +) -> Result { + match array.data_type() { + DataType::List(_) => { + let list_array = array.as_list::(); + general_remove::(list_array, element_array, arr_n) } - }; + DataType::LargeList(_) => { + let list_array = array.as_list::(); + general_remove::(list_array, element_array, arr_n) + } + array_type => { + exec_err!("array_remove_all does not support type '{array_type:?}'.") + } + } } -fn remove_one(args: &[ArrayRef]) -> Result { - Ok(Int64Array::from_value(1, args[0].len())) +pub fn array_remove_all(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_remove_all expects two arguments"); + } + + let arr_n = vec![i64::MAX; args[0].len()]; + array_remove_internal(&args[0], &args[1], arr_n) } -fn remove_n(args: &[ArrayRef]) -> Result { - as_int64_array(&args[2]).cloned() +pub fn array_remove(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_remove expects two arguments"); + } + + let arr_n = vec![1; args[0].len()]; + array_remove_internal(&args[0], &args[1], arr_n) } -fn remove_all(args: &[ArrayRef]) -> Result { - Ok(Int64Array::from_value(i64::MAX, args[0].len())) +pub fn array_remove_n(args: &[ArrayRef]) -> Result { + if args.len() != 3 { + return exec_err!("array_remove_n expects three arguments"); + } + + let arr_n = as_int64_array(&args[2])?.values().to_vec(); + array_remove_internal(&args[0], &args[1], arr_n) } -// array removement functions -array_removement_function!(array_remove, remove_one, "Array_remove SQL function"); -array_removement_function!(array_remove_n, remove_n, "Array_remove_n SQL function"); -array_removement_function!( - array_remove_all, - remove_all, - "Array_remove_all SQL function" -); +/// For each element of `list_array[i]`, replaces up to `arr_n[i]` occurences +/// of `from_array[i]`, `to_array[i]`. +/// +/// The type of each **element** in `list_array` must be the same as the type of +/// `from_array` and `to_array`. This function also handles nested arrays +/// ([`ListArray`] of [`ListArray`]s) +/// +/// For example, when called to replace a list array (where each element is a +/// list of int32s, the second and third argument are int32 arrays, and the +/// fourth argument is the number of occurrences to replace +/// +/// ```text +/// general_replace( +/// [1, 2, 3, 2], 2, 10, 1 ==> [1, 10, 3, 2] (only the first 2 is replaced) +/// [4, 5, 6, 5], 5, 20, 2 ==> [4, 20, 6, 20] (both 5s are replaced) +/// ) +/// ``` +fn general_replace( + list_array: &GenericListArray, + from_array: &ArrayRef, + to_array: &ArrayRef, + arr_n: Vec, +) -> Result { + // Build up the offsets for the final output array + let mut offsets: Vec = vec![O::usize_as(0)]; + let values = list_array.values(); + let original_data = values.to_data(); + let to_data = to_array.to_data(); + let capacity = Capacities::Array(original_data.len()); + + // First array is the original array, second array is the element to replace with. + let mut mutable = MutableArrayData::with_capacities( + vec![&original_data, &to_data], + false, + capacity, + ); -macro_rules! general_replace { - ($ARRAY:expr, $FROM:expr, $TO:expr, $MAX:expr, $ARRAY_TYPE:ident) => {{ - let mut offsets: Vec = vec![0]; - let mut values = - downcast_arg!(new_empty_array($FROM.data_type()), $ARRAY_TYPE).clone(); + let mut valid = BooleanBufferBuilder::new(list_array.len()); - let from_array = downcast_arg!($FROM, $ARRAY_TYPE); - let to_array = downcast_arg!($TO, $ARRAY_TYPE); - for (((arr, from), to), max) in $ARRAY - .iter() - .zip(from_array.iter()) - .zip(to_array.iter()) - .zip($MAX.iter()) - { - let last_offset: i32 = offsets.last().copied().ok_or_else(|| { - DataFusionError::Internal(format!("offsets should not be empty")) - })?; - match arr { - Some(arr) => { - let child_array = downcast_arg!(arr, $ARRAY_TYPE); - let mut counter = 0; - let max = if max < Some(1) { 1 } else { max.unwrap() }; - - let replaced_array = child_array - .iter() - .map(|el| { - if counter != max && el == from { - counter += 1; - to - } else { - el - } - }) - .collect::<$ARRAY_TYPE>(); - - values = downcast_arg!( - compute::concat(&[&values, &replaced_array])?.clone(), - $ARRAY_TYPE - ) - .clone(); - offsets.push(last_offset + replaced_array.len() as i32); - } - None => { - offsets.push(last_offset); - } - } + for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { + if list_array.is_null(row_index) { + offsets.push(offsets[row_index]); + valid.append(false); + continue; } - let field = Arc::new(Field::new("item", $FROM.data_type().clone(), true)); + let start = offset_window[0]; + let end = offset_window[1]; - Arc::new(ListArray::try_new( - field, - OffsetBuffer::new(offsets.into()), - Arc::new(values), - None, - )?) - }}; -} + let list_array_row = list_array.value(row_index); -macro_rules! general_replace_list { - ($ARRAY:expr, $FROM:expr, $TO:expr, $MAX:expr, $ARRAY_TYPE:ident) => {{ - let mut offsets: Vec = vec![0]; - let mut values = - downcast_arg!(new_empty_array($FROM.data_type()), ListArray).clone(); - - let from_array = downcast_arg!($FROM, ListArray); - let to_array = downcast_arg!($TO, ListArray); - for (((arr, from), to), max) in $ARRAY - .iter() - .zip(from_array.iter()) - .zip(to_array.iter()) - .zip($MAX.iter()) - { - let last_offset: i32 = offsets.last().copied().ok_or_else(|| { - DataFusionError::Internal(format!("offsets should not be empty")) - })?; - match arr { - Some(arr) => { - let child_array = downcast_arg!(arr, ListArray); - let mut counter = 0; - let max = if max < Some(1) { 1 } else { max.unwrap() }; - - let replaced_vec = child_array - .iter() - .map(|el| { - if counter != max && el == from { - counter += 1; - to.clone().unwrap() - } else { - el.clone().unwrap() - } - }) - .collect::>(); - - let mut i: i32 = 0; - let mut replaced_offsets = vec![i]; - replaced_offsets.extend( - replaced_vec - .clone() - .into_iter() - .map(|a| { - i += a.len() as i32; - i - }) - .collect::>(), - ); + // Compute all positions in list_row_array (that is itself an + // array) that are equal to `from_array_row` + let eq_array = + compare_element_to_list(&list_array_row, &from_array, row_index, true)?; - let mut replaced_values = downcast_arg!( - new_empty_array(&from_array.value_type()), - $ARRAY_TYPE - ) - .clone(); - for replaced_list in replaced_vec { - replaced_values = downcast_arg!( - compute::concat(&[&replaced_values, &replaced_list])?, - $ARRAY_TYPE - ) - .clone(); - } + let original_idx = O::usize_as(0); + let replace_idx = O::usize_as(1); + let n = arr_n[row_index]; + let mut counter = 0; - let field = Arc::new(Field::new( - "item", - from_array.value_type().clone(), - true, - )); - let replaced_array = ListArray::try_new( - field, - OffsetBuffer::new(replaced_offsets.clone().into()), - Arc::new(replaced_values), - None, - )?; + // All elements are false, no need to replace, just copy original data + if eq_array.false_count() == eq_array.len() { + mutable.extend( + original_idx.to_usize().unwrap(), + start.to_usize().unwrap(), + end.to_usize().unwrap(), + ); + offsets.push(offsets[row_index] + (end - start)); + valid.append(true); + continue; + } - values = downcast_arg!( - compute::concat(&[&values, &replaced_array,])?.clone(), - ListArray - ) - .clone(); - offsets.push(last_offset + replaced_array.len() as i32); - } - None => { - offsets.push(last_offset); + for (i, to_replace) in eq_array.iter().enumerate() { + let i = O::usize_as(i); + if let Some(true) = to_replace { + mutable.extend(replace_idx.to_usize().unwrap(), row_index, row_index + 1); + counter += 1; + if counter == n { + // copy original data for any matches past n + mutable.extend( + original_idx.to_usize().unwrap(), + (start + i).to_usize().unwrap() + 1, + end.to_usize().unwrap(), + ); + break; } + } else { + // copy original data for false / null matches + mutable.extend( + original_idx.to_usize().unwrap(), + (start + i).to_usize().unwrap(), + (start + i).to_usize().unwrap() + 1, + ); } } - let field = Arc::new(Field::new("item", $FROM.data_type().clone(), true)); + offsets.push(offsets[row_index] + (end - start)); + valid.append(true); + } + + let data = mutable.freeze(); - Arc::new(ListArray::try_new( - field, - OffsetBuffer::new(offsets.into()), - Arc::new(values), - None, - )?) - }}; + Ok(Arc::new(GenericListArray::::try_new( + Arc::new(Field::new("item", list_array.value_type(), true)), + OffsetBuffer::::new(offsets.into()), + arrow_array::make_array(data), + Some(NullBuffer::new(valid.finish())), + )?)) } -macro_rules! array_replacement_function { - ($FUNC:ident, $MAX_FUNC:expr, $DOC:expr) => { - #[doc = $DOC] - pub fn $FUNC(args: &[ArrayRef]) -> Result { - let arr = as_list_array(&args[0])?; - let from = &args[1]; - let to = &args[2]; - let max = $MAX_FUNC(args)?; - - check_datatypes(stringify!($FUNC), &[arr.values(), from, to])?; - let res = match arr.value_type() { - DataType::List(field) => { - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - general_replace_list!(arr, from, to, max, $ARRAY_TYPE) - }; - } - call_array_function!(field.data_type(), true) - } - data_type => { - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - general_replace!(arr, from, to, max, $ARRAY_TYPE) - }; - } - call_array_function!(data_type, false) - } - }; +pub fn array_replace(args: &[ArrayRef]) -> Result { + if args.len() != 3 { + return exec_err!("array_replace expects three arguments"); + } - Ok(res) + // replace at most one occurence for each element + let arr_n = vec![1; args[0].len()]; + let array = &args[0]; + match array.data_type() { + DataType::List(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) } - }; + DataType::LargeList(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) + } + array_type => exec_err!("array_replace does not support type '{array_type:?}'."), + } } -fn replace_one(args: &[ArrayRef]) -> Result { - Ok(Int64Array::from_value(1, args[0].len())) -} +pub fn array_replace_n(args: &[ArrayRef]) -> Result { + if args.len() != 4 { + return exec_err!("array_replace_n expects four arguments"); + } -fn replace_n(args: &[ArrayRef]) -> Result { - as_int64_array(&args[3]).cloned() + // replace the specified number of occurences + let arr_n = as_int64_array(&args[3])?.values().to_vec(); + let array = &args[0]; + match array.data_type() { + DataType::List(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) + } + DataType::LargeList(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) + } + array_type => { + exec_err!("array_replace_n does not support type '{array_type:?}'.") + } + } } -fn replace_all(args: &[ArrayRef]) -> Result { - Ok(Int64Array::from_value(i64::MAX, args[0].len())) -} +pub fn array_replace_all(args: &[ArrayRef]) -> Result { + if args.len() != 3 { + return exec_err!("array_replace_all expects three arguments"); + } -// array replacement functions -array_replacement_function!(array_replace, replace_one, "Array_replace SQL function"); -array_replacement_function!(array_replace_n, replace_n, "Array_replace_n SQL function"); -array_replacement_function!( - array_replace_all, - replace_all, - "Array_replace_all SQL function" -); + // replace all occurrences (up to "i64::MAX") + let arr_n = vec![i64::MAX; args[0].len()]; + let array = &args[0]; + match array.data_type() { + DataType::List(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) + } + DataType::LargeList(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) + } + array_type => { + exec_err!("array_replace_all does not support type '{array_type:?}'.") + } + } +} macro_rules! to_string { ($ARG:expr, $ARRAY:expr, $DELIMITER:expr, $NULL_STRING:expr, $WITH_NULL_STRING:expr, $ARRAY_TYPE:ident) => {{ @@ -1491,19 +1830,188 @@ macro_rules! to_string { }}; } +#[derive(Debug, PartialEq)] +enum SetOp { + Union, + Intersect, +} + +impl Display for SetOp { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + SetOp::Union => write!(f, "array_union"), + SetOp::Intersect => write!(f, "array_intersect"), + } + } +} + +fn generic_set_lists( + l: &GenericListArray, + r: &GenericListArray, + field: Arc, + set_op: SetOp, +) -> Result { + if matches!(l.value_type(), DataType::Null) { + let field = Arc::new(Field::new("item", r.value_type(), true)); + return general_array_distinct::(r, &field); + } else if matches!(r.value_type(), DataType::Null) { + let field = Arc::new(Field::new("item", l.value_type(), true)); + return general_array_distinct::(l, &field); + } + + if l.value_type() != r.value_type() { + return internal_err!("{set_op:?} is not implemented for '{l:?}' and '{r:?}'"); + } + + let dt = l.value_type(); + + let mut offsets = vec![OffsetSize::usize_as(0)]; + let mut new_arrays = vec![]; + + let converter = RowConverter::new(vec![SortField::new(dt)])?; + for (first_arr, second_arr) in l.iter().zip(r.iter()) { + if let (Some(first_arr), Some(second_arr)) = (first_arr, second_arr) { + let l_values = converter.convert_columns(&[first_arr])?; + let r_values = converter.convert_columns(&[second_arr])?; + + let l_iter = l_values.iter().sorted().dedup(); + let values_set: HashSet<_> = l_iter.clone().collect(); + let mut rows = if set_op == SetOp::Union { + l_iter.collect::>() + } else { + vec![] + }; + for r_val in r_values.iter().sorted().dedup() { + match set_op { + SetOp::Union => { + if !values_set.contains(&r_val) { + rows.push(r_val); + } + } + SetOp::Intersect => { + if values_set.contains(&r_val) { + rows.push(r_val); + } + } + } + } + + let last_offset = match offsets.last().copied() { + Some(offset) => offset, + None => return internal_err!("offsets should not be empty"), + }; + offsets.push(last_offset + OffsetSize::usize_as(rows.len())); + let arrays = converter.convert_rows(rows)?; + let array = match arrays.first() { + Some(array) => array.clone(), + None => { + return internal_err!("{set_op}: failed to get array from rows"); + } + }; + new_arrays.push(array); + } + } + + let offsets = OffsetBuffer::new(offsets.into()); + let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::>(); + let values = compute::concat(&new_arrays_ref)?; + let arr = GenericListArray::::try_new(field, offsets, values, None)?; + Ok(Arc::new(arr)) +} + +fn general_set_op( + array1: &ArrayRef, + array2: &ArrayRef, + set_op: SetOp, +) -> Result { + match (array1.data_type(), array2.data_type()) { + (DataType::Null, DataType::List(field)) => { + if set_op == SetOp::Intersect { + return Ok(new_empty_array(&DataType::Null)); + } + let array = as_list_array(&array2)?; + general_array_distinct::(array, field) + } + + (DataType::List(field), DataType::Null) => { + if set_op == SetOp::Intersect { + return make_array(&[]); + } + let array = as_list_array(&array1)?; + general_array_distinct::(array, field) + } + (DataType::Null, DataType::LargeList(field)) => { + if set_op == SetOp::Intersect { + return Ok(new_empty_array(&DataType::Null)); + } + let array = as_large_list_array(&array2)?; + general_array_distinct::(array, field) + } + (DataType::LargeList(field), DataType::Null) => { + if set_op == SetOp::Intersect { + return make_array(&[]); + } + let array = as_large_list_array(&array1)?; + general_array_distinct::(array, field) + } + (DataType::Null, DataType::Null) => Ok(new_empty_array(&DataType::Null)), + + (DataType::List(field), DataType::List(_)) => { + let array1 = as_list_array(&array1)?; + let array2 = as_list_array(&array2)?; + generic_set_lists::(array1, array2, field.clone(), set_op) + } + (DataType::LargeList(field), DataType::LargeList(_)) => { + let array1 = as_large_list_array(&array1)?; + let array2 = as_large_list_array(&array2)?; + generic_set_lists::(array1, array2, field.clone(), set_op) + } + (data_type1, data_type2) => { + internal_err!( + "{set_op} does not support types '{data_type1:?}' and '{data_type2:?}'" + ) + } + } +} + +/// Array_union SQL function +pub fn array_union(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_union needs two arguments"); + } + let array1 = &args[0]; + let array2 = &args[1]; + + general_set_op(array1, array2, SetOp::Union) +} + +/// array_intersect SQL function +pub fn array_intersect(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_intersect needs two arguments"); + } + + let array1 = &args[0]; + let array2 = &args[1]; + + general_set_op(array1, array2, SetOp::Intersect) +} + /// Array_to_string SQL function pub fn array_to_string(args: &[ArrayRef]) -> Result { + if args.len() < 2 || args.len() > 3 { + return exec_err!("array_to_string expects two or three arguments"); + } + let arr = &args[0]; - let delimiters = as_generic_string_array::(&args[1])?; + let delimiters = as_string_array(&args[1])?; let delimiters: Vec> = delimiters.iter().collect(); let mut null_string = String::from(""); let mut with_null_string = false; if args.len() == 3 { - null_string = as_generic_string_array::(&args[2])? - .value(0) - .to_string(); + null_string = as_string_array(&args[2])?.value(0).to_string(); with_null_string = true; } @@ -1516,8 +2024,21 @@ pub fn array_to_string(args: &[ArrayRef]) -> Result { ) -> Result<&mut String> { match arr.data_type() { DataType::List(..) => { - let list_array = downcast_arg!(arr, ListArray); + let list_array = as_list_array(&arr)?; + for i in 0..list_array.len() { + compute_array_to_string( + arg, + list_array.value(i), + delimiter.clone(), + null_string.clone(), + with_null_string, + )?; + } + Ok(arg) + } + DataType::LargeList(..) => { + let list_array = as_large_list_array(&arr)?; for i in 0..list_array.len() { compute_array_to_string( arg, @@ -1549,35 +2070,61 @@ pub fn array_to_string(args: &[ArrayRef]) -> Result { } } - let mut arg = String::from(""); - let mut res: Vec> = Vec::new(); - - match arr.data_type() { - DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _) => { - let list_array = arr.as_list::(); - for (arr, &delimiter) in list_array.iter().zip(delimiters.iter()) { - if let (Some(arr), Some(delimiter)) = (arr, delimiter) { - arg = String::from(""); - let s = compute_array_to_string( - &mut arg, - arr, - delimiter.to_string(), - null_string.clone(), - with_null_string, - )? - .clone(); - - if let Some(s) = s.strip_suffix(delimiter) { - res.push(Some(s.to_string())); - } else { - res.push(Some(s)); - } + fn generate_string_array( + list_arr: &GenericListArray, + delimiters: Vec>, + null_string: String, + with_null_string: bool, + ) -> Result { + let mut res: Vec> = Vec::new(); + for (arr, &delimiter) in list_arr.iter().zip(delimiters.iter()) { + if let (Some(arr), Some(delimiter)) = (arr, delimiter) { + let mut arg = String::from(""); + let s = compute_array_to_string( + &mut arg, + arr, + delimiter.to_string(), + null_string.clone(), + with_null_string, + )? + .clone(); + + if let Some(s) = s.strip_suffix(delimiter) { + res.push(Some(s.to_string())); } else { - res.push(None); + res.push(Some(s)); } + } else { + res.push(None); } } + + Ok(StringArray::from(res)) + } + + let arr_type = arr.data_type(); + let string_arr = match arr_type { + DataType::List(_) | DataType::FixedSizeList(_, _) => { + let list_array = as_list_array(&arr)?; + generate_string_array::( + list_array, + delimiters, + null_string, + with_null_string, + )? + } + DataType::LargeList(_) => { + let list_array = as_large_list_array(&arr)?; + generate_string_array::( + list_array, + delimiters, + null_string, + with_null_string, + )? + } _ => { + let mut arg = String::from(""); + let mut res: Vec> = Vec::new(); // delimiter length is 1 assert_eq!(delimiters.len(), 1); let delimiter = delimiters[0].unwrap(); @@ -1596,24 +2143,44 @@ pub fn array_to_string(args: &[ArrayRef]) -> Result { } else { res.push(Some(s)); } + StringArray::from(res) } - } + }; - Ok(Arc::new(StringArray::from(res))) + Ok(Arc::new(string_arr)) } /// Cardinality SQL function pub fn cardinality(args: &[ArrayRef]) -> Result { - let list_array = as_list_array(&args[0])?.clone(); + if args.len() != 1 { + return exec_err!("cardinality expects one argument"); + } - let result = list_array + match &args[0].data_type() { + DataType::List(_) => { + let list_array = as_list_array(&args[0])?; + generic_list_cardinality::(list_array) + } + DataType::LargeList(_) => { + let list_array = as_large_list_array(&args[0])?; + generic_list_cardinality::(list_array) + } + other => { + exec_err!("cardinality does not support type '{:?}'", other) + } + } +} + +fn generic_list_cardinality( + array: &GenericListArray, +) -> Result { + let result = array .iter() .map(|arr| match compute_array_dims(arr)? { Some(vector) => Ok(Some(vector.iter().map(|x| x.unwrap()).product::())), None => Ok(None), }) .collect::>()?; - Ok(Arc::new(result) as ArrayRef) } @@ -1632,7 +2199,7 @@ fn flatten_internal( indexes: Option>, ) -> Result { let list_arr = as_list_array(array)?; - let (field, offsets, values, nulls) = list_arr.clone().into_parts(); + let (field, offsets, values, _) = list_arr.clone().into_parts(); let data_type = field.data_type(); match data_type { @@ -1649,7 +2216,7 @@ fn flatten_internal( _ => { if let Some(indexes) = indexes { let offsets = get_offsets_for_flatten(offsets, indexes); - let list_arr = ListArray::new(field, offsets, values, nulls); + let list_arr = ListArray::new(field, offsets, values, None); Ok(list_arr) } else { Ok(list_arr.clone()) @@ -1660,15 +2227,19 @@ fn flatten_internal( /// Flatten SQL function pub fn flatten(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("flatten expects one argument"); + } + let flattened_array = flatten_internal(&args[0], None)?; Ok(Arc::new(flattened_array) as ArrayRef) } -/// Array_length SQL function -pub fn array_length(args: &[ArrayRef]) -> Result { - let list_array = as_list_array(&args[0])?; - let dimension = if args.len() == 2 { - as_int64_array(&args[1])?.clone() +/// Dispatch array length computation based on the offset type. +fn array_length_dispatch(array: &[ArrayRef]) -> Result { + let list_array = as_generic_list_array::(&array[0])?; + let dimension = if array.len() == 2 { + as_int64_array(&array[1])?.clone() } else { Int64Array::from_value(1, list_array.len()) }; @@ -1682,14 +2253,45 @@ pub fn array_length(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } +/// Array_length SQL function +pub fn array_length(args: &[ArrayRef]) -> Result { + if args.len() != 1 && args.len() != 2 { + return exec_err!("array_length expects one or two arguments"); + } + + match &args[0].data_type() { + DataType::List(_) => array_length_dispatch::(args), + DataType::LargeList(_) => array_length_dispatch::(args), + array_type => exec_err!("array_length does not support type '{array_type:?}'"), + } +} + /// Array_dims SQL function pub fn array_dims(args: &[ArrayRef]) -> Result { - let list_array = as_list_array(&args[0])?; + if args.len() != 1 { + return exec_err!("array_dims needs one argument"); + } + + let data = match args[0].data_type() { + DataType::List(_) => { + let array = as_list_array(&args[0])?; + array + .iter() + .map(compute_array_dims) + .collect::>>()? + } + DataType::LargeList(_) => { + let array = as_large_list_array(&args[0])?; + array + .iter() + .map(compute_array_dims) + .collect::>>()? + } + array_type => { + return exec_err!("array_dims does not support type '{array_type:?}'"); + } + }; - let data = list_array - .iter() - .map(compute_array_dims) - .collect::>>()?; let result = ListArray::from_iter_primitive::(data); Ok(Arc::new(result) as ArrayRef) @@ -1697,171 +2299,165 @@ pub fn array_dims(args: &[ArrayRef]) -> Result { /// Array_ndims SQL function pub fn array_ndims(args: &[ArrayRef]) -> Result { - let list_array = as_list_array(&args[0])?; - - let result = list_array - .iter() - .map(compute_array_ndims) - .collect::>()?; - - Ok(Arc::new(result) as ArrayRef) -} + if args.len() != 1 { + return exec_err!("array_ndims needs one argument"); + } -macro_rules! non_list_contains { - ($ARRAY:expr, $SUB_ARRAY:expr, $ARRAY_TYPE:ident) => {{ - let sub_array = downcast_arg!($SUB_ARRAY, $ARRAY_TYPE); - let mut boolean_builder = BooleanArray::builder($ARRAY.len()); + fn general_list_ndims( + array: &GenericListArray, + ) -> Result { + let mut data = Vec::new(); + let ndims = datafusion_common::utils::list_ndims(array.data_type()); - for (arr, elem) in $ARRAY.iter().zip(sub_array.iter()) { - if let (Some(arr), Some(elem)) = (arr, elem) { - let arr = downcast_arg!(arr, $ARRAY_TYPE); - let res = arr.iter().dedup().flatten().any(|x| x == elem); - boolean_builder.append_value(res); + for arr in array.iter() { + if arr.is_some() { + data.push(Some(ndims)) + } else { + data.push(None) } } - Ok(Arc::new(boolean_builder.finish())) - }}; -} -/// Array_has SQL function -pub fn array_has(args: &[ArrayRef]) -> Result { - let array = as_list_array(&args[0])?; - let element = &args[1]; + Ok(Arc::new(UInt64Array::from(data)) as ArrayRef) + } - check_datatypes("array_has", &[array.values(), element])?; - match element.data_type() { + match args[0].data_type() { DataType::List(_) => { - let sub_array = as_list_array(element)?; - let mut boolean_builder = BooleanArray::builder(array.len()); - - for (arr, elem) in array.iter().zip(sub_array.iter()) { - if let (Some(arr), Some(elem)) = (arr, elem) { - let list_arr = as_list_array(&arr)?; - let res = list_arr.iter().dedup().flatten().any(|x| *x == *elem); - boolean_builder.append_value(res); - } - } - Ok(Arc::new(boolean_builder.finish())) + let array = as_list_array(&args[0])?; + general_list_ndims::(array) } - data_type => { - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - non_list_contains!(array, element, $ARRAY_TYPE) - }; - } - call_array_function!(data_type, false) + DataType::LargeList(_) => { + let array = as_large_list_array(&args[0])?; + general_list_ndims::(array) } + _ => Ok(Arc::new(UInt64Array::from(vec![0; args[0].len()])) as ArrayRef), } } -macro_rules! array_has_any_non_list_check { - ($ARRAY:expr, $SUB_ARRAY:expr, $ARRAY_TYPE:ident) => {{ - let arr = downcast_arg!($ARRAY, $ARRAY_TYPE); - let sub_arr = downcast_arg!($SUB_ARRAY, $ARRAY_TYPE); - - let mut res = false; - for elem in sub_arr.iter().dedup() { - if let Some(elem) = elem { - res |= arr.iter().dedup().flatten().any(|x| x == elem); - } else { - return internal_err!( - "array_has_any does not support Null type for element in sub_array" - ); - } - } - res - }}; +/// Represents the type of comparison for array_has. +#[derive(Debug, PartialEq)] +enum ComparisonType { + // array_has_all + All, + // array_has_any + Any, + // array_has + Single, } -/// Array_has_any SQL function -pub fn array_has_any(args: &[ArrayRef]) -> Result { - check_datatypes("array_has_any", &[&args[0], &args[1]])?; - - let array = as_list_array(&args[0])?; - let sub_array = as_list_array(&args[1])?; +fn general_array_has_dispatch( + array: &ArrayRef, + sub_array: &ArrayRef, + comparison_type: ComparisonType, +) -> Result { + let array = if comparison_type == ComparisonType::Single { + let arr = as_generic_list_array::(array)?; + check_datatypes("array_has", &[arr.values(), sub_array])?; + arr + } else { + check_datatypes("array_has", &[array, sub_array])?; + as_generic_list_array::(array)? + }; let mut boolean_builder = BooleanArray::builder(array.len()); - for (arr, sub_arr) in array.iter().zip(sub_array.iter()) { + + let converter = RowConverter::new(vec![SortField::new(array.value_type())])?; + + let element = sub_array.clone(); + let sub_array = if comparison_type != ComparisonType::Single { + as_generic_list_array::(sub_array)? + } else { + array + }; + + for (row_idx, (arr, sub_arr)) in array.iter().zip(sub_array.iter()).enumerate() { if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) { - let res = match arr.data_type() { - DataType::List(_) => { - let arr = downcast_arg!(arr, ListArray); - let sub_arr = downcast_arg!(sub_arr, ListArray); - - let mut res = false; - for elem in sub_arr.iter().dedup().flatten() { - res |= arr.iter().dedup().flatten().any(|x| *x == *elem); - } - res - } - data_type => { - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - array_has_any_non_list_check!(arr, sub_arr, $ARRAY_TYPE) - }; - } - call_array_function!(data_type, false) - } + let arr_values = converter.convert_columns(&[arr])?; + let sub_arr_values = if comparison_type != ComparisonType::Single { + converter.convert_columns(&[sub_arr])? + } else { + converter.convert_columns(&[element.clone()])? + }; + + let mut res = match comparison_type { + ComparisonType::All => sub_arr_values + .iter() + .dedup() + .all(|elem| arr_values.iter().dedup().any(|x| x == elem)), + ComparisonType::Any => sub_arr_values + .iter() + .dedup() + .any(|elem| arr_values.iter().dedup().any(|x| x == elem)), + ComparisonType::Single => arr_values + .iter() + .dedup() + .any(|x| x == sub_arr_values.row(row_idx)), }; + + if comparison_type == ComparisonType::Any { + res |= res; + } + boolean_builder.append_value(res); } } Ok(Arc::new(boolean_builder.finish())) } -macro_rules! array_has_all_non_list_check { - ($ARRAY:expr, $SUB_ARRAY:expr, $ARRAY_TYPE:ident) => {{ - let arr = downcast_arg!($ARRAY, $ARRAY_TYPE); - let sub_arr = downcast_arg!($SUB_ARRAY, $ARRAY_TYPE); +/// Array_has SQL function +pub fn array_has(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_has needs two arguments"); + } - let mut res = true; - for elem in sub_arr.iter().dedup() { - if let Some(elem) = elem { - res &= arr.iter().dedup().flatten().any(|x| x == elem); - } else { - return internal_err!( - "array_has_all does not support Null type for element in sub_array" - ); - } + let array_type = args[0].data_type(); + + match array_type { + DataType::List(_) => { + general_array_has_dispatch::(&args[0], &args[1], ComparisonType::Single) } - res - }}; + DataType::LargeList(_) => { + general_array_has_dispatch::(&args[0], &args[1], ComparisonType::Single) + } + _ => exec_err!("array_has does not support type '{array_type:?}'."), + } } -/// Array_has_all SQL function -pub fn array_has_all(args: &[ArrayRef]) -> Result { - check_datatypes("array_has_all", &[&args[0], &args[1]])?; - - let array = as_list_array(&args[0])?; - let sub_array = as_list_array(&args[1])?; - - let mut boolean_builder = BooleanArray::builder(array.len()); - for (arr, sub_arr) in array.iter().zip(sub_array.iter()) { - if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) { - let res = match arr.data_type() { - DataType::List(_) => { - let arr = downcast_arg!(arr, ListArray); - let sub_arr = downcast_arg!(sub_arr, ListArray); - - let mut res = true; - for elem in sub_arr.iter().dedup().flatten() { - res &= arr.iter().dedup().flatten().any(|x| *x == *elem); - } - res - } - data_type => { - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - array_has_all_non_list_check!(arr, sub_arr, $ARRAY_TYPE) - }; - } - call_array_function!(data_type, false) - } - }; - boolean_builder.append_value(res); +/// Array_has_any SQL function +pub fn array_has_any(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_has_any needs two arguments"); + } + + let array_type = args[0].data_type(); + + match array_type { + DataType::List(_) => { + general_array_has_dispatch::(&args[0], &args[1], ComparisonType::Any) } + DataType::LargeList(_) => { + general_array_has_dispatch::(&args[0], &args[1], ComparisonType::Any) + } + _ => exec_err!("array_has_any does not support type '{array_type:?}'."), + } +} + +/// Array_has_all SQL function +pub fn array_has_all(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_has_all needs two arguments"); + } + + let array_type = args[0].data_type(); + + match array_type { + DataType::List(_) => { + general_array_has_dispatch::(&args[0], &args[1], ComparisonType::All) + } + DataType::LargeList(_) => { + general_array_has_dispatch::(&args[0], &args[1], ComparisonType::All) + } + _ => exec_err!("array_has_all does not support type '{array_type:?}'."), } - Ok(Arc::new(boolean_builder.finish())) } /// Splits string at occurrences of delimiter and returns an array of parts @@ -1943,7 +2539,7 @@ pub fn string_to_array(args: &[ArrayRef]) -> Result { - return internal_err!( + return exec_err!( "Expect string_to_array function to take two or three parameters" ) } @@ -1953,1536 +2549,112 @@ pub fn string_to_array(args: &[ArrayRef]) -> Result() - .unwrap() - .values() - ) - } - - #[test] - fn test_nested_array() { - // make_array([1, 3, 5], [2, 4, 6]) = [[1, 3, 5], [2, 4, 6]] - let args = [ - ColumnarValue::Array(Arc::new(Int64Array::from(vec![1, 2]))), - ColumnarValue::Array(Arc::new(Int64Array::from(vec![3, 4]))), - ColumnarValue::Array(Arc::new(Int64Array::from(vec![5, 6]))), - ]; - let array = array(&args) - .expect("failed to initialize function array") - .into_array(1); - let result = as_list_array(&array).expect("failed to initialize function array"); - assert_eq!(result.len(), 2); - assert_eq!( - &[1, 3, 5], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - assert_eq!( - &[2, 4, 6], - result - .value(1) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_array_element() { - // array_element([1, 2, 3, 4], 1) = 1 - let list_array = return_array().into_array(1); - let arr = array_element(&[list_array, Arc::new(Int64Array::from_value(1, 1))]) - .expect("failed to initialize function array_element"); - let result = - as_int64_array(&arr).expect("failed to initialize function array_element"); - - assert_eq!(result, &Int64Array::from_value(1, 1)); - - // array_element([1, 2, 3, 4], 3) = 3 - let list_array = return_array().into_array(1); - let arr = array_element(&[list_array, Arc::new(Int64Array::from_value(3, 1))]) - .expect("failed to initialize function array_element"); - let result = - as_int64_array(&arr).expect("failed to initialize function array_element"); - - assert_eq!(result, &Int64Array::from_value(3, 1)); - - // array_element([1, 2, 3, 4], 0) = NULL - let list_array = return_array().into_array(1); - let arr = array_element(&[list_array, Arc::new(Int64Array::from_value(0, 1))]) - .expect("failed to initialize function array_element"); - let result = - as_int64_array(&arr).expect("failed to initialize function array_element"); - - assert_eq!(result, &Int64Array::from(vec![None])); - - // array_element([1, 2, 3, 4], NULL) = NULL - let list_array = return_array().into_array(1); - let arr = array_element(&[list_array, Arc::new(Int64Array::from(vec![None]))]) - .expect("failed to initialize function array_element"); - let result = - as_int64_array(&arr).expect("failed to initialize function array_element"); - - assert_eq!(result, &Int64Array::from(vec![None])); - - // array_element([1, 2, 3, 4], -1) = 4 - let list_array = return_array().into_array(1); - let arr = array_element(&[list_array, Arc::new(Int64Array::from_value(-1, 1))]) - .expect("failed to initialize function array_element"); - let result = - as_int64_array(&arr).expect("failed to initialize function array_element"); - - assert_eq!(result, &Int64Array::from_value(4, 1)); - - // array_element([1, 2, 3, 4], -3) = 2 - let list_array = return_array().into_array(1); - let arr = array_element(&[list_array, Arc::new(Int64Array::from_value(-3, 1))]) - .expect("failed to initialize function array_element"); - let result = - as_int64_array(&arr).expect("failed to initialize function array_element"); - - assert_eq!(result, &Int64Array::from_value(2, 1)); - - // array_element([1, 2, 3, 4], 10) = NULL - let list_array = return_array().into_array(1); - let arr = array_element(&[list_array, Arc::new(Int64Array::from_value(10, 1))]) - .expect("failed to initialize function array_element"); - let result = - as_int64_array(&arr).expect("failed to initialize function array_element"); - - assert_eq!(result, &Int64Array::from(vec![None])); - } - - #[test] - fn test_nested_array_element() { - // array_element([[1, 2, 3, 4], [5, 6, 7, 8]], 2) = [5, 6, 7, 8] - let list_array = return_nested_array().into_array(1); - let arr = array_element(&[list_array, Arc::new(Int64Array::from_value(2, 1))]) - .expect("failed to initialize function array_element"); - let result = - as_list_array(&arr).expect("failed to initialize function array_element"); - - assert_eq!( - &[5, 6, 7, 8], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_array_pop_back() { - // array_pop_back([1, 2, 3, 4]) = [1, 2, 3] - let list_array = return_array().into_array(1); - let arr = array_pop_back(&[list_array]) - .expect("failed to initialize function array_pop_back"); - let result = - as_list_array(&arr).expect("failed to initialize function array_pop_back"); - assert_eq!( - &[1, 2, 3], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_pop_back([1, 2, 3]) = [1, 2] - let list_array = Arc::new(result.clone()); - let arr = array_pop_back(&[list_array]) - .expect("failed to initialize function array_pop_back"); - let result = - as_list_array(&arr).expect("failed to initialize function array_pop_back"); - assert_eq!( - &[1, 2], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_pop_back([1, 2]) = [1] - let list_array = Arc::new(result.clone()); - let arr = array_pop_back(&[list_array]) - .expect("failed to initialize function array_pop_back"); - let result = - as_list_array(&arr).expect("failed to initialize function array_pop_back"); - assert_eq!( - &[1], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_pop_back([1]) = [] - let list_array = Arc::new(result.clone()); - let arr = array_pop_back(&[list_array]) - .expect("failed to initialize function array_pop_back"); - let result = - as_list_array(&arr).expect("failed to initialize function array_pop_back"); - assert_eq!( - &[], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - // array_pop_back([]) = [] - let list_array = Arc::new(result.clone()); - let arr = array_pop_back(&[list_array]) - .expect("failed to initialize function array_pop_back"); - let result = - as_list_array(&arr).expect("failed to initialize function array_pop_back"); - assert_eq!( - &[], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_pop_back([1, NULL, 3, NULL]) = [1, NULL, 3] - let list_array = return_array_with_nulls().into_array(1); - let arr = array_pop_back(&[list_array]) - .expect("failed to initialize function array_pop_back"); - let result = - as_list_array(&arr).expect("failed to initialize function array_pop_back"); - assert_eq!(3, result.values().len()); - assert_eq!( - &[false, true, false], - &[ - result.values().is_null(0), - result.values().is_null(1), - result.values().is_null(2) - ] - ); - } - #[test] - fn test_nested_array_pop_back() { - // array_pop_back([[1, 2, 3, 4], [5, 6, 7, 8]]) = [[1, 2, 3, 4]] - let list_array = return_nested_array().into_array(1); - let arr = array_pop_back(&[list_array]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - assert_eq!( - &[1, 2, 3, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_pop_back([[1, 2, 3, 4]]) = [] - let list_array = Arc::new(result.clone()); - let arr = array_pop_back(&[list_array]) - .expect("failed to initialize function array_pop_back"); - let result = - as_list_array(&arr).expect("failed to initialize function array_pop_back"); - assert!(result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .is_empty()); - // array_pop_back([]) = [] - let list_array = Arc::new(result.clone()); - let arr = array_pop_back(&[list_array]) - .expect("failed to initialize function array_pop_back"); - let result = - as_list_array(&arr).expect("failed to initialize function array_pop_back"); - assert!(result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .is_empty()); - } - - #[test] - fn test_array_slice() { - // array_slice([1, 2, 3, 4], 1, 3) = [1, 2, 3] - let list_array = return_array().into_array(1); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(1, 1)), - Arc::new(Int64Array::from_value(3, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert_eq!( - &[1, 2, 3], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_slice([1, 2, 3, 4], 2, 2) = [2] - let list_array = return_array().into_array(1); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(2, 1)), - Arc::new(Int64Array::from_value(2, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert_eq!( - &[2], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_slice([1, 2, 3, 4], 0, 0) = [] - let list_array = return_array().into_array(1); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(0, 1)), - Arc::new(Int64Array::from_value(0, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert!(result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .is_empty()); - - // array_slice([1, 2, 3, 4], 0, 6) = [1, 2, 3, 4] - let list_array = return_array().into_array(1); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(0, 1)), - Arc::new(Int64Array::from_value(6, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert_eq!( - &[1, 2, 3, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_slice([1, 2, 3, 4], -2, -2) = [] - let list_array = return_array().into_array(1); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(-2, 1)), - Arc::new(Int64Array::from_value(-2, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert!(result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .is_empty()); - - // array_slice([1, 2, 3, 4], -3, -1) = [2, 3] - let list_array = return_array().into_array(1); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(-3, 1)), - Arc::new(Int64Array::from_value(-1, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert_eq!( - &[2, 3], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_slice([1, 2, 3, 4], -3, 2) = [2] - let list_array = return_array().into_array(1); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(-3, 1)), - Arc::new(Int64Array::from_value(2, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert_eq!( - &[2], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_slice([1, 2, 3, 4], 2, 11) = [2, 3, 4] - let list_array = return_array().into_array(1); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(2, 1)), - Arc::new(Int64Array::from_value(11, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert_eq!( - &[2, 3, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_slice([1, 2, 3, 4], 3, 1) = [] - let list_array = return_array().into_array(1); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(3, 1)), - Arc::new(Int64Array::from_value(1, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert!(result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .is_empty()); - - // array_slice([1, 2, 3, 4], -7, -2) = NULL - let list_array = return_array().into_array(1); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(-7, 1)), - Arc::new(Int64Array::from_value(-2, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert!(result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .is_null(0)); - } - - #[test] - fn test_nested_array_slice() { - // array_slice([[1, 2, 3, 4], [5, 6, 7, 8]], 1, 1) = [[1, 2, 3, 4]] - let list_array = return_nested_array().into_array(1); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(1, 1)), - Arc::new(Int64Array::from_value(1, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert_eq!( - &[1, 2, 3, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_slice([[1, 2, 3, 4], [5, 6, 7, 8]], -1, -1) = [] - let list_array = return_nested_array().into_array(1); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(-1, 1)), - Arc::new(Int64Array::from_value(-1, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert!(result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .is_empty()); - - // array_slice([[1, 2, 3, 4], [5, 6, 7, 8]], -1, 2) = [[5, 6, 7, 8]] - let list_array = return_nested_array().into_array(1); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(-1, 1)), - Arc::new(Int64Array::from_value(2, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert_eq!( - &[5, 6, 7, 8], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_array_append() { - // array_append([1, 2, 3], 4) = [1, 2, 3, 4] - let data = vec![Some(vec![Some(1), Some(2), Some(3)])]; - let list_array = - Arc::new(ListArray::from_iter_primitive::(data)) as ArrayRef; - let int64_array = Arc::new(Int64Array::from(vec![Some(4)])) as ArrayRef; - - let args = [list_array, int64_array]; - - let array = - array_append(&args).expect("failed to initialize function array_append"); - let result = - as_list_array(&array).expect("failed to initialize function array_append"); - - assert_eq!( - &[1, 2, 3, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_array_prepend() { - // array_prepend(1, [2, 3, 4]) = [1, 2, 3, 4] - let data = vec![Some(vec![Some(2), Some(3), Some(4)])]; - let list_array = - Arc::new(ListArray::from_iter_primitive::(data)) as ArrayRef; - let int64_array = Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef; - - let args = [int64_array, list_array]; - - let array = - array_prepend(&args).expect("failed to initialize function array_append"); - let result = - as_list_array(&array).expect("failed to initialize function array_append"); - - assert_eq!( - &[1, 2, 3, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_array_concat() { - // array_concat([1, 2, 3], [4, 5, 6], [7, 8, 9]) = [1, 2, 3, 4, 5, 6, 7, 8, 9] - let data = vec![Some(vec![Some(1), Some(2), Some(3)])]; - let list_array1 = - Arc::new(ListArray::from_iter_primitive::(data)) as ArrayRef; - let data = vec![Some(vec![Some(4), Some(5), Some(6)])]; - let list_array2 = - Arc::new(ListArray::from_iter_primitive::(data)) as ArrayRef; - let data = vec![Some(vec![Some(7), Some(8), Some(9)])]; - let list_array3 = - Arc::new(ListArray::from_iter_primitive::(data)) as ArrayRef; - - let args = [list_array1, list_array2, list_array3]; - - let array = - array_concat(&args).expect("failed to initialize function array_concat"); - let result = - as_list_array(&array).expect("failed to initialize function array_concat"); - - assert_eq!( - &[1, 2, 3, 4, 5, 6, 7, 8, 9], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_nested_array_concat() { - // array_concat([1, 2, 3, 4], [1, 2, 3, 4]) = [1, 2, 3, 4, 1, 2, 3, 4] - let list_array = return_array().into_array(1); - let arr = array_concat(&[list_array.clone(), list_array.clone()]) - .expect("failed to initialize function array_concat"); - let result = - as_list_array(&arr).expect("failed to initialize function array_concat"); - - assert_eq!( - &[1, 2, 3, 4, 1, 2, 3, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_concat([[1, 2, 3, 4], [5, 6, 7, 8]], [1, 2, 3, 4]) = [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4]] - let list_nested_array = return_nested_array().into_array(1); - let list_array = return_array().into_array(1); - let arr = array_concat(&[list_nested_array, list_array]) - .expect("failed to initialize function array_concat"); - let result = - as_list_array(&arr).expect("failed to initialize function array_concat"); - - assert_eq!( - &[1, 2, 3, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .value(2) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_array_position() { - // array_position([1, 2, 3, 4], 3) = 3 - let list_array = return_array().into_array(1); - let array = array_position(&[list_array, Arc::new(Int64Array::from_value(3, 1))]) - .expect("failed to initialize function array_position"); - let result = as_uint64_array(&array) - .expect("failed to initialize function array_position"); - - assert_eq!(result, &UInt64Array::from(vec![3])); - } - - #[test] - fn test_array_positions() { - // array_positions([1, 2, 3, 4], 3) = [3] - let list_array = return_array().into_array(1); - let array = - array_positions(&[list_array, Arc::new(Int64Array::from_value(3, 1))]) - .expect("failed to initialize function array_position"); - let result = - as_list_array(&array).expect("failed to initialize function array_position"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[3], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_array_remove() { - // array_remove([3, 1, 2, 3, 2, 3], 3) = [1, 2, 3, 2, 3] - let list_array = return_array_with_repeating_elements().into_array(1); - let array = array_remove(&[list_array, Arc::new(Int64Array::from_value(3, 1))]) - .expect("failed to initialize function array_remove"); - let result = - as_list_array(&array).expect("failed to initialize function array_remove"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[1, 2, 3, 2, 3], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_nested_array_remove() { - // array_remove( - // [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [9, 10, 11, 12], [5, 6, 7, 8]], - // [1, 2, 3, 4], - // ) = [[5, 6, 7, 8], [1, 2, 3, 4], [9, 10, 11, 12], [5, 6, 7, 8]] - let list_array = return_nested_array_with_repeating_elements().into_array(1); - let element_array = return_array().into_array(1); - let array = array_remove(&[list_array, element_array]) - .expect("failed to initialize function array_remove"); - let result = - as_list_array(&array).expect("failed to initialize function array_remove"); - - assert_eq!(result.len(), 1); - let data = vec![ - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - Some(vec![Some(1), Some(2), Some(3), Some(4)]), - Some(vec![Some(9), Some(10), Some(11), Some(12)]), - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - ]; - let expected = ListArray::from_iter_primitive::(data); - assert_eq!( - expected, - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .clone() - ); - } - - #[test] - fn test_array_remove_n() { - // array_remove_n([3, 1, 2, 3, 2, 3], 3, 2) = [1, 2, 2, 3] - let list_array = return_array_with_repeating_elements().into_array(1); - let array = array_remove_n(&[ - list_array, - Arc::new(Int64Array::from_value(3, 1)), - Arc::new(Int64Array::from_value(2, 1)), - ]) - .expect("failed to initialize function array_remove_n"); - let result = - as_list_array(&array).expect("failed to initialize function array_remove_n"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[1, 2, 2, 3], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_nested_array_remove_n() { - // array_remove_n( - // [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [9, 10, 11, 12], [5, 6, 7, 8]], - // [1, 2, 3, 4], - // 3, - // ) = [[5, 6, 7, 8], [9, 10, 11, 12], [5, 6, 7, 8]] - let list_array = return_nested_array_with_repeating_elements().into_array(1); - let element_array = return_array().into_array(1); - let array = array_remove_n(&[ - list_array, - element_array, - Arc::new(Int64Array::from_value(3, 1)), - ]) - .expect("failed to initialize function array_remove_n"); - let result = - as_list_array(&array).expect("failed to initialize function array_remove_n"); - - assert_eq!(result.len(), 1); - let data = vec![ - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - Some(vec![Some(9), Some(10), Some(11), Some(12)]), - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - ]; - let expected = ListArray::from_iter_primitive::(data); - assert_eq!( - expected, - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .clone() - ); - } - - #[test] - fn test_array_remove_all() { - // array_remove_all([3, 1, 2, 3, 2, 3], 3) = [1, 2, 2] - let list_array = return_array_with_repeating_elements().into_array(1); - let array = - array_remove_all(&[list_array, Arc::new(Int64Array::from_value(3, 1))]) - .expect("failed to initialize function array_remove_all"); - let result = as_list_array(&array) - .expect("failed to initialize function array_remove_all"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[1, 2, 2], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_nested_array_remove_all() { - // array_remove_all( - // [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [9, 10, 11, 12], [5, 6, 7, 8]], - // [1, 2, 3, 4], - // ) = [[5, 6, 7, 8], [9, 10, 11, 12], [5, 6, 7, 8]] - let list_array = return_nested_array_with_repeating_elements().into_array(1); - let element_array = return_array().into_array(1); - let array = array_remove_all(&[list_array, element_array]) - .expect("failed to initialize function array_remove_all"); - let result = as_list_array(&array) - .expect("failed to initialize function array_remove_all"); - - assert_eq!(result.len(), 1); - let data = vec![ - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - Some(vec![Some(9), Some(10), Some(11), Some(12)]), - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - ]; - let expected = ListArray::from_iter_primitive::(data); - assert_eq!( - expected, - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .clone() - ); - } - - #[test] - fn test_array_replace() { - // array_replace([3, 1, 2, 3, 2, 3], 3, 4) = [4, 1, 2, 3, 2, 3] - let list_array = return_array_with_repeating_elements().into_array(1); - let array = array_replace(&[ - list_array, - Arc::new(Int64Array::from_value(3, 1)), - Arc::new(Int64Array::from_value(4, 1)), - ]) - .expect("failed to initialize function array_replace"); - let result = - as_list_array(&array).expect("failed to initialize function array_replace"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[4, 1, 2, 3, 2, 3], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_nested_array_replace() { - // array_replace( - // [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [9, 10, 11, 12], [5, 6, 7, 8]], - // [1, 2, 3, 4], - // [11, 12, 13, 14], - // ) = [[11, 12, 13, 14], [5, 6, 7, 8], [1, 2, 3, 4], [9, 10, 11, 12], [5, 6, 7, 8]] - let list_array = return_nested_array_with_repeating_elements().into_array(1); - let from_array = return_array().into_array(1); - let to_array = return_extra_array().into_array(1); - let array = array_replace(&[list_array, from_array, to_array]) - .expect("failed to initialize function array_replace"); - let result = - as_list_array(&array).expect("failed to initialize function array_replace"); - - assert_eq!(result.len(), 1); - let data = vec![ - Some(vec![Some(11), Some(12), Some(13), Some(14)]), - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - Some(vec![Some(1), Some(2), Some(3), Some(4)]), - Some(vec![Some(9), Some(10), Some(11), Some(12)]), - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - ]; - let expected = ListArray::from_iter_primitive::(data); - assert_eq!( - expected, - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .clone() - ); - } - - #[test] - fn test_array_replace_n() { - // array_replace_n([3, 1, 2, 3, 2, 3], 3, 4, 2) = [4, 1, 2, 4, 2, 3] - let list_array = return_array_with_repeating_elements().into_array(1); - let array = array_replace_n(&[ - list_array, - Arc::new(Int64Array::from_value(3, 1)), - Arc::new(Int64Array::from_value(4, 1)), - Arc::new(Int64Array::from_value(2, 1)), - ]) - .expect("failed to initialize function array_replace_n"); - let result = - as_list_array(&array).expect("failed to initialize function array_replace_n"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[4, 1, 2, 4, 2, 3], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_nested_array_replace_n() { - // array_replace_n( - // [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [9, 10, 11, 12], [5, 6, 7, 8]], - // [1, 2, 3, 4], - // [11, 12, 13, 14], - // 2, - // ) = [[11, 12, 13, 14], [5, 6, 7, 8], [11, 12, 13, 14], [9, 10, 11, 12], [5, 6, 7, 8]] - let list_array = return_nested_array_with_repeating_elements().into_array(1); - let from_array = return_array().into_array(1); - let to_array = return_extra_array().into_array(1); - let array = array_replace_n(&[ - list_array, - from_array, - to_array, - Arc::new(Int64Array::from_value(2, 1)), - ]) - .expect("failed to initialize function array_replace_n"); - let result = - as_list_array(&array).expect("failed to initialize function array_replace_n"); - - assert_eq!(result.len(), 1); - let data = vec![ - Some(vec![Some(11), Some(12), Some(13), Some(14)]), - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - Some(vec![Some(11), Some(12), Some(13), Some(14)]), - Some(vec![Some(9), Some(10), Some(11), Some(12)]), - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - ]; - let expected = ListArray::from_iter_primitive::(data); - assert_eq!( - expected, - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .clone() - ); - } - - #[test] - fn test_array_replace_all() { - // array_replace_all([3, 1, 2, 3, 2, 3], 3, 4) = [4, 1, 2, 4, 2, 4] - let list_array = return_array_with_repeating_elements().into_array(1); - let array = array_replace_all(&[ - list_array, - Arc::new(Int64Array::from_value(3, 1)), - Arc::new(Int64Array::from_value(4, 1)), - ]) - .expect("failed to initialize function array_replace_all"); - let result = as_list_array(&array) - .expect("failed to initialize function array_replace_all"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[4, 1, 2, 4, 2, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_nested_array_replace_all() { - // array_replace_all( - // [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [9, 10, 11, 12], [5, 6, 7, 8]], - // [1, 2, 3, 4], - // [11, 12, 13, 14], - // ) = [[11, 12, 13, 14], [5, 6, 7, 8], [11, 12, 13, 14], [9, 10, 11, 12], [5, 6, 7, 8]] - let list_array = return_nested_array_with_repeating_elements().into_array(1); - let from_array = return_array().into_array(1); - let to_array = return_extra_array().into_array(1); - let array = array_replace_all(&[list_array, from_array, to_array]) - .expect("failed to initialize function array_replace_all"); - let result = as_list_array(&array) - .expect("failed to initialize function array_replace_all"); - - assert_eq!(result.len(), 1); - let data = vec![ - Some(vec![Some(11), Some(12), Some(13), Some(14)]), - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - Some(vec![Some(11), Some(12), Some(13), Some(14)]), - Some(vec![Some(9), Some(10), Some(11), Some(12)]), - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - ]; - let expected = ListArray::from_iter_primitive::(data); - assert_eq!( - expected, - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .clone() - ); - } - - #[test] - fn test_array_repeat() { - // array_repeat(3, 5) = [3, 3, 3, 3, 3] - let array = array_repeat(&[ - Arc::new(Int64Array::from_value(3, 1)), - Arc::new(Int64Array::from_value(5, 1)), - ]) - .expect("failed to initialize function array_repeat"); - let result = - as_list_array(&array).expect("failed to initialize function array_repeat"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[3, 3, 3, 3, 3], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_nested_array_repeat() { - // array_repeat([1, 2, 3, 4], 3) = [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]] - let element = return_array().into_array(1); - let array = array_repeat(&[element, Arc::new(Int64Array::from_value(3, 1))]) - .expect("failed to initialize function array_repeat"); - let result = - as_list_array(&array).expect("failed to initialize function array_repeat"); - - assert_eq!(result.len(), 1); - let data = vec![ - Some(vec![Some(1), Some(2), Some(3), Some(4)]), - Some(vec![Some(1), Some(2), Some(3), Some(4)]), - Some(vec![Some(1), Some(2), Some(3), Some(4)]), - ]; - let expected = ListArray::from_iter_primitive::(data); - assert_eq!( - expected, - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .clone() - ); - } - #[test] - fn test_array_to_string() { - // array_to_string([1, 2, 3, 4], ',') = 1,2,3,4 - let list_array = return_array().into_array(1); - let array = - array_to_string(&[list_array, Arc::new(StringArray::from(vec![Some(",")]))]) - .expect("failed to initialize function array_to_string"); - let result = as_generic_string_array::(&array) - .expect("failed to initialize function array_to_string"); - - assert_eq!(result.len(), 1); - assert_eq!("1,2,3,4", result.value(0)); - - // array_to_string([1, NULL, 3, NULL], ',', '*') = 1,*,3,* - let list_array = return_array_with_nulls().into_array(1); - let array = array_to_string(&[ - list_array, - Arc::new(StringArray::from(vec![Some(",")])), - Arc::new(StringArray::from(vec![Some("*")])), - ]) - .expect("failed to initialize function array_to_string"); - let result = as_generic_string_array::(&array) - .expect("failed to initialize function array_to_string"); - - assert_eq!(result.len(), 1); - assert_eq!("1,*,3,*", result.value(0)); - } - - #[test] - fn test_nested_array_to_string() { - // array_to_string([[1, 2, 3, 4], [5, 6, 7, 8]], '-') = 1-2-3-4-5-6-7-8 - let list_array = return_nested_array().into_array(1); - let array = - array_to_string(&[list_array, Arc::new(StringArray::from(vec![Some("-")]))]) - .expect("failed to initialize function array_to_string"); - let result = as_generic_string_array::(&array) - .expect("failed to initialize function array_to_string"); - - assert_eq!(result.len(), 1); - assert_eq!("1-2-3-4-5-6-7-8", result.value(0)); - - // array_to_string([[1, NULL, 3, NULL], [NULL, 6, 7, NULL]], '-', '*') = 1-*-3-*-*-6-7-* - let list_array = return_nested_array_with_nulls().into_array(1); - let array = array_to_string(&[ - list_array, - Arc::new(StringArray::from(vec![Some("-")])), - Arc::new(StringArray::from(vec![Some("*")])), - ]) - .expect("failed to initialize function array_to_string"); - let result = as_generic_string_array::(&array) - .expect("failed to initialize function array_to_string"); - - assert_eq!(result.len(), 1); - assert_eq!("1-*-3-*-*-6-7-*", result.value(0)); +pub fn general_array_distinct( + array: &GenericListArray, + field: &FieldRef, +) -> Result { + let dt = array.value_type(); + let mut offsets = Vec::with_capacity(array.len()); + offsets.push(OffsetSize::usize_as(0)); + let mut new_arrays = Vec::with_capacity(array.len()); + let converter = RowConverter::new(vec![SortField::new(dt)])?; + // distinct for each list in ListArray + for arr in array.iter().flatten() { + let values = converter.convert_columns(&[arr])?; + // sort elements in list and remove duplicates + let rows = values.iter().sorted().dedup().collect::>(); + let last_offset: OffsetSize = offsets.last().copied().unwrap(); + offsets.push(last_offset + OffsetSize::usize_as(rows.len())); + let arrays = converter.convert_rows(rows)?; + let array = match arrays.first() { + Some(array) => array.clone(), + None => { + return internal_err!("array_distinct: failed to get array from rows") + } + }; + new_arrays.push(array); } + let offsets = OffsetBuffer::new(offsets.into()); + let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::>(); + let values = compute::concat(&new_arrays_ref)?; + Ok(Arc::new(GenericListArray::::try_new( + field.clone(), + offsets, + values, + None, + )?)) +} - #[test] - fn test_cardinality() { - // cardinality([1, 2, 3, 4]) = 4 - let list_array = return_array().into_array(1); - let arr = cardinality(&[list_array]) - .expect("failed to initialize function cardinality"); - let result = - as_uint64_array(&arr).expect("failed to initialize function cardinality"); - - assert_eq!(result, &UInt64Array::from(vec![4])); +/// array_distinct SQL function +/// example: from list [1, 3, 2, 3, 1, 2, 4] to [1, 2, 3, 4] +pub fn array_distinct(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("array_distinct needs one argument"); } - #[test] - fn test_nested_cardinality() { - // cardinality([[1, 2, 3, 4], [5, 6, 7, 8]]) = 8 - let list_array = return_nested_array().into_array(1); - let arr = cardinality(&[list_array]) - .expect("failed to initialize function cardinality"); - let result = - as_uint64_array(&arr).expect("failed to initialize function cardinality"); - - assert_eq!(result, &UInt64Array::from(vec![8])); + // handle null + if args[0].data_type() == &DataType::Null { + return Ok(args[0].clone()); } - #[test] - fn test_array_length() { - // array_length([1, 2, 3, 4]) = 4 - let list_array = return_array().into_array(1); - let arr = array_length(&[list_array.clone()]) - .expect("failed to initialize function array_ndims"); - let result = - as_uint64_array(&arr).expect("failed to initialize function array_ndims"); - - assert_eq!(result, &UInt64Array::from_value(4, 1)); - - // array_length([1, 2, 3, 4], 1) = 4 - let array = array_length(&[list_array, Arc::new(Int64Array::from_value(1, 1))]) - .expect("failed to initialize function array_ndims"); - let result = - as_uint64_array(&array).expect("failed to initialize function array_ndims"); - - assert_eq!(result, &UInt64Array::from_value(4, 1)); + // handle for list & largelist + match args[0].data_type() { + DataType::List(field) => { + let array = as_list_array(&args[0])?; + general_array_distinct(array, field) + } + DataType::LargeList(field) => { + let array = as_large_list_array(&args[0])?; + general_array_distinct(array, field) + } + array_type => exec_err!("array_distinct does not support type '{array_type:?}'"), } +} - #[test] - fn test_nested_array_length() { - let list_array = return_nested_array().into_array(1); - - // array_length([[1, 2, 3, 4], [5, 6, 7, 8]]) = 2 - let arr = array_length(&[list_array.clone()]) - .expect("failed to initialize function array_length"); - let result = - as_uint64_array(&arr).expect("failed to initialize function array_length"); - - assert_eq!(result, &UInt64Array::from_value(2, 1)); - - // array_length([[1, 2, 3, 4], [5, 6, 7, 8]], 1) = 2 - let arr = - array_length(&[list_array.clone(), Arc::new(Int64Array::from_value(1, 1))]) - .expect("failed to initialize function array_length"); - let result = - as_uint64_array(&arr).expect("failed to initialize function array_length"); - - assert_eq!(result, &UInt64Array::from_value(2, 1)); - - // array_length([[1, 2, 3, 4], [5, 6, 7, 8]], 2) = 4 - let arr = - array_length(&[list_array.clone(), Arc::new(Int64Array::from_value(2, 1))]) - .expect("failed to initialize function array_length"); - let result = - as_uint64_array(&arr).expect("failed to initialize function array_length"); - - assert_eq!(result, &UInt64Array::from_value(4, 1)); - - // array_length([[1, 2, 3, 4], [5, 6, 7, 8]], 3) = NULL - let arr = array_length(&[list_array, Arc::new(Int64Array::from_value(3, 1))]) - .expect("failed to initialize function array_length"); - let result = - as_uint64_array(&arr).expect("failed to initialize function array_length"); - - assert_eq!(result, &UInt64Array::from(vec![None])); - } +#[cfg(test)] +mod tests { + use super::*; + use arrow::datatypes::Int64Type; + /// Only test internal functions, array-related sql functions will be tested in sqllogictest `array.slt` #[test] - fn test_array_dims() { - // array_dims([1, 2, 3, 4]) = [4] - let list_array = return_array().into_array(1); - - let array = - array_dims(&[list_array]).expect("failed to initialize function array_dims"); - let result = - as_list_array(&array).expect("failed to initialize function array_dims"); - + fn test_align_array_dimensions() { + let array1d_1 = + Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + Some(vec![Some(4), Some(5)]), + ])); + let array1d_2 = + Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(6), Some(7), Some(8)]), + ])); + + let array2d_1 = Arc::new(array_into_list_array(array1d_1.clone())) as ArrayRef; + let array2d_2 = Arc::new(array_into_list_array(array1d_2.clone())) as ArrayRef; + + let res = + align_array_dimensions(vec![array1d_1.to_owned(), array2d_2.to_owned()]) + .unwrap(); + + let expected = as_list_array(&array2d_1).unwrap(); + let expected_dim = datafusion_common::utils::list_ndims(array2d_1.data_type()); + assert_ne!(as_list_array(&res[0]).unwrap(), expected); assert_eq!( - &[4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() + datafusion_common::utils::list_ndims(res[0].data_type()), + expected_dim ); - } - #[test] - fn test_nested_array_dims() { - // array_dims([[1, 2, 3, 4], [5, 6, 7, 8]]) = [2, 4] - let list_array = return_nested_array().into_array(1); - - let array = - array_dims(&[list_array]).expect("failed to initialize function array_dims"); - let result = - as_list_array(&array).expect("failed to initialize function array_dims"); + let array3d_1 = Arc::new(array_into_list_array(array2d_1)) as ArrayRef; + let array3d_2 = array_into_list_array(array2d_2.to_owned()); + let res = + align_array_dimensions(vec![array1d_1, Arc::new(array3d_2.clone())]).unwrap(); + let expected = as_list_array(&array3d_1).unwrap(); + let expected_dim = datafusion_common::utils::list_ndims(array3d_1.data_type()); + assert_ne!(as_list_array(&res[0]).unwrap(), expected); assert_eq!( - &[2, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() + datafusion_common::utils::list_ndims(res[0].data_type()), + expected_dim ); } - - #[test] - fn test_array_ndims() { - // array_ndims([1, 2, 3, 4]) = 1 - let list_array = return_array().into_array(1); - - let array = array_ndims(&[list_array]) - .expect("failed to initialize function array_ndims"); - let result = - as_uint64_array(&array).expect("failed to initialize function array_ndims"); - - assert_eq!(result, &UInt64Array::from_value(1, 1)); - } - - #[test] - fn test_nested_array_ndims() { - // array_ndims([[1, 2, 3, 4], [5, 6, 7, 8]]) = 2 - let list_array = return_nested_array().into_array(1); - - let array = array_ndims(&[list_array]) - .expect("failed to initialize function array_ndims"); - let result = - as_uint64_array(&array).expect("failed to initialize function array_ndims"); - - assert_eq!(result, &UInt64Array::from_value(2, 1)); - } - - #[test] - fn test_check_invalid_datatypes() { - let data = vec![Some(vec![Some(1), Some(2), Some(3)])]; - let list_array = - Arc::new(ListArray::from_iter_primitive::(data)) as ArrayRef; - let int64_array = Arc::new(StringArray::from(vec![Some("string")])) as ArrayRef; - - let args = [list_array.clone(), int64_array.clone()]; - - let array = array_append(&args); - - assert_eq!(array.unwrap_err().strip_backtrace(), "Error during planning: array_append received incompatible types: '[Int64, Utf8]'."); - } - - fn return_array() -> ColumnarValue { - // Returns: [1, 2, 3, 4] - let args = [ - ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), - ]; - let result = array(&args) - .expect("failed to initialize function array") - .into_array(1); - ColumnarValue::Array(result.clone()) - } - - fn return_extra_array() -> ColumnarValue { - // Returns: [11, 12, 13, 14] - let args = [ - ColumnarValue::Scalar(ScalarValue::Int64(Some(11))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(12))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(13))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(14))), - ]; - let result = array(&args) - .expect("failed to initialize function array") - .into_array(1); - ColumnarValue::Array(result.clone()) - } - - fn return_nested_array() -> ColumnarValue { - // Returns: [[1, 2, 3, 4], [5, 6, 7, 8]] - let args = [ - ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), - ]; - let arr1 = array(&args) - .expect("failed to initialize function array") - .into_array(1); - - let args = [ - ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(6))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(7))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(8))), - ]; - let arr2 = array(&args) - .expect("failed to initialize function array") - .into_array(1); - - let args = [ColumnarValue::Array(arr1), ColumnarValue::Array(arr2)]; - let result = array(&args) - .expect("failed to initialize function array") - .into_array(1); - ColumnarValue::Array(result.clone()) - } - - fn return_array_with_nulls() -> ColumnarValue { - // Returns: [1, NULL, 3, NULL] - let args = [ - ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), - ColumnarValue::Scalar(ScalarValue::Null), - ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), - ColumnarValue::Scalar(ScalarValue::Null), - ]; - let result = array(&args) - .expect("failed to initialize function array") - .into_array(1); - ColumnarValue::Array(result.clone()) - } - - fn return_nested_array_with_nulls() -> ColumnarValue { - // Returns: [[1, NULL, 3, NULL], [NULL, 6, 7, NULL]] - let args = [ - ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), - ColumnarValue::Scalar(ScalarValue::Null), - ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), - ColumnarValue::Scalar(ScalarValue::Null), - ]; - let arr1 = array(&args) - .expect("failed to initialize function array") - .into_array(1); - - let args = [ - ColumnarValue::Scalar(ScalarValue::Null), - ColumnarValue::Scalar(ScalarValue::Int64(Some(6))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(7))), - ColumnarValue::Scalar(ScalarValue::Null), - ]; - let arr2 = array(&args) - .expect("failed to initialize function array") - .into_array(1); - - let args = [ColumnarValue::Array(arr1), ColumnarValue::Array(arr2)]; - let result = array(&args) - .expect("failed to initialize function array") - .into_array(1); - ColumnarValue::Array(result.clone()) - } - - fn return_array_with_repeating_elements() -> ColumnarValue { - // Returns: [3, 1, 2, 3, 2, 3] - let args = [ - ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), - ]; - let result = array(&args) - .expect("failed to initialize function array") - .into_array(1); - ColumnarValue::Array(result.clone()) - } - - fn return_nested_array_with_repeating_elements() -> ColumnarValue { - // Returns: [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [9, 10, 11, 12], [5, 6, 7, 8]] - let args = [ - ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), - ]; - let arr1 = array(&args) - .expect("failed to initialize function array") - .into_array(1); - - let args = [ - ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(6))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(7))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(8))), - ]; - let arr2 = array(&args) - .expect("failed to initialize function array") - .into_array(1); - - let args = [ - ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), - ]; - let arr3 = array(&args) - .expect("failed to initialize function array") - .into_array(1); - - let args = [ - ColumnarValue::Scalar(ScalarValue::Int64(Some(9))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(10))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(11))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(12))), - ]; - let arr4 = array(&args) - .expect("failed to initialize function array") - .into_array(1); - - let args = [ - ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(6))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(7))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(8))), - ]; - let arr5 = array(&args) - .expect("failed to initialize function array") - .into_array(1); - - let args = [ - ColumnarValue::Array(arr1), - ColumnarValue::Array(arr2), - ColumnarValue::Array(arr3), - ColumnarValue::Array(arr4), - ColumnarValue::Array(arr5), - ]; - let result = array(&args) - .expect("failed to initialize function array") - .into_array(1); - ColumnarValue::Array(result.clone()) - } } diff --git a/datafusion/physical-expr/src/conditional_expressions.rs b/datafusion/physical-expr/src/conditional_expressions.rs index 37adb2d71ce8..a9a25ffe2ec1 100644 --- a/datafusion/physical-expr/src/conditional_expressions.rs +++ b/datafusion/physical-expr/src/conditional_expressions.rs @@ -54,7 +54,7 @@ pub fn coalesce(args: &[ColumnarValue]) -> Result { if value.is_null() { continue; } else { - let last_value = value.to_array_of_size(size); + let last_value = value.to_array_of_size(size)?; current_value = zip(&remainder, &last_value, current_value.as_ref())?; break; diff --git a/datafusion/physical-expr/src/datetime_expressions.rs b/datafusion/physical-expr/src/datetime_expressions.rs index 5ce71f4584bb..589bbc8a952b 100644 --- a/datafusion/physical-expr/src/datetime_expressions.rs +++ b/datafusion/physical-expr/src/datetime_expressions.rs @@ -17,14 +17,9 @@ //! DateTime expressions -use arrow::array::Float64Builder; +use crate::datetime_expressions; +use crate::expressions::cast_column; use arrow::compute::cast; -use arrow::{ - array::TimestampNanosecondArray, - compute::kernels::temporal, - datatypes::TimeUnit, - temporal_conversions::{as_datetime_with_timezone, timestamp_ns_to_datetime}, -}; use arrow::{ array::{Array, ArrayRef, Float64Array, OffsetSizeTrait, PrimitiveArray}, compute::kernels::cast_utils::string_to_timestamp_nanos, @@ -34,14 +29,18 @@ use arrow::{ TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }, }; -use arrow_array::{ - timezone::Tz, TimestampMicrosecondArray, TimestampMillisecondArray, - TimestampSecondArray, +use arrow::{ + compute::kernels::temporal, + datatypes::TimeUnit, + temporal_conversions::{as_datetime_with_timezone, timestamp_ns_to_datetime}, }; +use arrow_array::temporal_conversions::NANOSECONDS; +use arrow_array::timezone::Tz; +use arrow_array::types::ArrowTimestampType; use chrono::prelude::*; use chrono::{Duration, Months, NaiveDate}; use datafusion_common::cast::{ - as_date32_array, as_date64_array, as_generic_string_array, + as_date32_array, as_date64_array, as_generic_string_array, as_primitive_array, as_timestamp_microsecond_array, as_timestamp_millisecond_array, as_timestamp_nanosecond_array, as_timestamp_second_array, }; @@ -128,6 +127,10 @@ fn string_to_timestamp_nanos_shim(s: &str) -> Result { } /// to_timestamp SQL function +/// +/// Note: `to_timestamp` returns `Timestamp(Nanosecond)` though its arguments are interpreted as **seconds**. The supported range for integer input is between `-9223372037` and `9223372036`. +/// Supported range for string input is between `1677-09-21T00:12:44.0` and `2262-04-11T23:47:16.0`. +/// Please use `to_timestamp_seconds` for the input outside of supported bounds. pub fn to_timestamp(args: &[ColumnarValue]) -> Result { handle::( args, @@ -154,6 +157,15 @@ pub fn to_timestamp_micros(args: &[ColumnarValue]) -> Result { ) } +/// to_timestamp_nanos SQL function +pub fn to_timestamp_nanos(args: &[ColumnarValue]) -> Result { + handle::( + args, + string_to_timestamp_nanos_shim, + "to_timestamp_nanos", + ) +} + /// to_timestamp_seconds SQL function pub fn to_timestamp_seconds(args: &[ColumnarValue]) -> Result { handle::( @@ -320,7 +332,7 @@ fn date_trunc_coarse(granularity: &str, value: i64, tz: Option) -> Result, tz: Option, @@ -388,103 +400,61 @@ pub fn date_trunc(args: &[ColumnarValue]) -> Result { return exec_err!("Granularity of `date_trunc` must be non-null scalar Utf8"); }; + fn process_array( + array: &dyn Array, + granularity: String, + tz_opt: &Option>, + ) -> Result { + let parsed_tz = parse_tz(tz_opt)?; + let array = as_primitive_array::(array)?; + let array = array + .iter() + .map(|x| general_date_trunc(T::UNIT, &x, parsed_tz, granularity.as_str())) + .collect::>>()? + .with_timezone_opt(tz_opt.clone()); + Ok(ColumnarValue::Array(Arc::new(array))) + } + + fn process_scalar( + v: &Option, + granularity: String, + tz_opt: &Option>, + ) -> Result { + let parsed_tz = parse_tz(tz_opt)?; + let value = general_date_trunc(T::UNIT, v, parsed_tz, granularity.as_str())?; + let value = ScalarValue::new_timestamp::(value, tz_opt.clone()); + Ok(ColumnarValue::Scalar(value)) + } + Ok(match array { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(v, tz_opt)) => { - let parsed_tz = parse_tz(tz_opt)?; - let value = - _date_trunc(TimeUnit::Nanosecond, v, parsed_tz, granularity.as_str())?; - let value = ScalarValue::TimestampNanosecond(value, tz_opt.clone()); - ColumnarValue::Scalar(value) + process_scalar::(v, granularity, tz_opt)? } ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(v, tz_opt)) => { - let parsed_tz = parse_tz(tz_opt)?; - let value = - _date_trunc(TimeUnit::Microsecond, v, parsed_tz, granularity.as_str())?; - let value = ScalarValue::TimestampMicrosecond(value, tz_opt.clone()); - ColumnarValue::Scalar(value) + process_scalar::(v, granularity, tz_opt)? } ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(v, tz_opt)) => { - let parsed_tz = parse_tz(tz_opt)?; - let value = - _date_trunc(TimeUnit::Millisecond, v, parsed_tz, granularity.as_str())?; - let value = ScalarValue::TimestampMillisecond(value, tz_opt.clone()); - ColumnarValue::Scalar(value) + process_scalar::(v, granularity, tz_opt)? } ColumnarValue::Scalar(ScalarValue::TimestampSecond(v, tz_opt)) => { - let parsed_tz = parse_tz(tz_opt)?; - let value = - _date_trunc(TimeUnit::Second, v, parsed_tz, granularity.as_str())?; - let value = ScalarValue::TimestampSecond(value, tz_opt.clone()); - ColumnarValue::Scalar(value) + process_scalar::(v, granularity, tz_opt)? } ColumnarValue::Array(array) => { let array_type = array.data_type(); match array_type { DataType::Timestamp(TimeUnit::Second, tz_opt) => { - let parsed_tz = parse_tz(tz_opt)?; - let array = as_timestamp_second_array(array)?; - let array = array - .iter() - .map(|x| { - _date_trunc( - TimeUnit::Second, - &x, - parsed_tz, - granularity.as_str(), - ) - }) - .collect::>()?; - ColumnarValue::Array(Arc::new(array)) + process_array::(array, granularity, tz_opt)? } DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { - let parsed_tz = parse_tz(tz_opt)?; - let array = as_timestamp_millisecond_array(array)?; - let array = array - .iter() - .map(|x| { - _date_trunc( - TimeUnit::Millisecond, - &x, - parsed_tz, - granularity.as_str(), - ) - }) - .collect::>()?; - ColumnarValue::Array(Arc::new(array)) + process_array::(array, granularity, tz_opt)? } DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { - let parsed_tz = parse_tz(tz_opt)?; - let array = as_timestamp_microsecond_array(array)?; - let array = array - .iter() - .map(|x| { - _date_trunc( - TimeUnit::Microsecond, - &x, - parsed_tz, - granularity.as_str(), - ) - }) - .collect::>()?; - ColumnarValue::Array(Arc::new(array)) + process_array::(array, granularity, tz_opt)? } - _ => { - let parsed_tz = None; - let array = as_timestamp_nanosecond_array(array)?; - let array = array - .iter() - .map(|x| { - _date_trunc( - TimeUnit::Nanosecond, - &x, - parsed_tz, - granularity.as_str(), - ) - }) - .collect::>()?; - - ColumnarValue::Array(Arc::new(array)) + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { + process_array::(array, granularity, tz_opt)? } + _ => process_array::(array, granularity, &None)?, } } _ => { @@ -673,85 +643,104 @@ fn date_bin_impl( return exec_err!("DATE_BIN stride must be non-zero"); } - let f_nanos = |x: Option| x.map(|x| stride_fn(stride, x, origin)); - let f_micros = |x: Option| { - let scale = 1_000; - x.map(|x| stride_fn(stride, x * scale, origin) / scale) - }; - let f_millis = |x: Option| { - let scale = 1_000_000; - x.map(|x| stride_fn(stride, x * scale, origin) / scale) - }; - let f_secs = |x: Option| { - let scale = 1_000_000_000; - x.map(|x| stride_fn(stride, x * scale, origin) / scale) - }; + fn stride_map_fn( + origin: i64, + stride: i64, + stride_fn: fn(i64, i64, i64) -> i64, + ) -> impl Fn(Option) -> Option { + let scale = match T::UNIT { + TimeUnit::Nanosecond => 1, + TimeUnit::Microsecond => NANOSECONDS / 1_000_000, + TimeUnit::Millisecond => NANOSECONDS / 1_000, + TimeUnit::Second => NANOSECONDS, + }; + move |x: Option| x.map(|x| stride_fn(stride, x * scale, origin) / scale) + } Ok(match array { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(v, tz_opt)) => { + let apply_stride_fn = + stride_map_fn::(origin, stride, stride_fn); ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( - f_nanos(*v), + apply_stride_fn(*v), tz_opt.clone(), )) } ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(v, tz_opt)) => { + let apply_stride_fn = + stride_map_fn::(origin, stride, stride_fn); ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond( - f_micros(*v), + apply_stride_fn(*v), tz_opt.clone(), )) } ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(v, tz_opt)) => { + let apply_stride_fn = + stride_map_fn::(origin, stride, stride_fn); ColumnarValue::Scalar(ScalarValue::TimestampMillisecond( - f_millis(*v), + apply_stride_fn(*v), tz_opt.clone(), )) } ColumnarValue::Scalar(ScalarValue::TimestampSecond(v, tz_opt)) => { + let apply_stride_fn = + stride_map_fn::(origin, stride, stride_fn); ColumnarValue::Scalar(ScalarValue::TimestampSecond( - f_secs(*v), + apply_stride_fn(*v), tz_opt.clone(), )) } - ColumnarValue::Array(array) => match array.data_type() { - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - let array = as_timestamp_nanosecond_array(array)? - .iter() - .map(f_nanos) - .collect::(); - ColumnarValue::Array(Arc::new(array)) - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - let array = as_timestamp_microsecond_array(array)? - .iter() - .map(f_micros) - .collect::(); - - ColumnarValue::Array(Arc::new(array)) - } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - let array = as_timestamp_millisecond_array(array)? - .iter() - .map(f_millis) - .collect::(); - - ColumnarValue::Array(Arc::new(array)) - } - DataType::Timestamp(TimeUnit::Second, _) => { - let array = as_timestamp_second_array(array)? + ColumnarValue::Array(array) => { + fn transform_array_with_stride( + origin: i64, + stride: i64, + stride_fn: fn(i64, i64, i64) -> i64, + array: &ArrayRef, + tz_opt: &Option>, + ) -> Result + where + T: ArrowTimestampType, + { + let array = as_primitive_array::(array)?; + let apply_stride_fn = stride_map_fn::(origin, stride, stride_fn); + let array = array .iter() - .map(f_secs) - .collect::(); + .map(apply_stride_fn) + .collect::>() + .with_timezone_opt(tz_opt.clone()); - ColumnarValue::Array(Arc::new(array)) + Ok(ColumnarValue::Array(Arc::new(array))) } - _ => { - return exec_err!( - "DATE_BIN expects source argument to be a TIMESTAMP but got {}", - array.data_type() - ) + match array.data_type() { + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { + transform_array_with_stride::( + origin, stride, stride_fn, array, tz_opt, + )? + } + DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { + transform_array_with_stride::( + origin, stride, stride_fn, array, tz_opt, + )? + } + DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { + transform_array_with_stride::( + origin, stride, stride_fn, array, tz_opt, + )? + } + DataType::Timestamp(TimeUnit::Second, tz_opt) => { + transform_array_with_stride::( + origin, stride, stride_fn, array, tz_opt, + )? + } + _ => { + return exec_err!( + "DATE_BIN expects source argument to be a TIMESTAMP but got {}", + array.data_type() + ) + } } - }, + } _ => { return exec_err!( "DATE_BIN expects source argument to be a TIMESTAMP scalar or array" @@ -817,7 +806,7 @@ pub fn date_part(args: &[ColumnarValue]) -> Result { let array = match array { ColumnarValue::Array(array) => array.clone(), - ColumnarValue::Scalar(scalar) => scalar.to_array(), + ColumnarValue::Scalar(scalar) => scalar.to_array()?, }; let arr = match date_part.to_lowercase().as_str() { @@ -897,35 +886,198 @@ where T: ArrowTemporalType + ArrowNumericType, i64: From, { - let mut b = Float64Builder::with_capacity(array.len()); - match array.data_type() { + let b = match array.data_type() { DataType::Timestamp(tu, _) => { - for i in 0..array.len() { - if array.is_null(i) { - b.append_null(); - } else { - let scale = match tu { - TimeUnit::Second => 1, - TimeUnit::Millisecond => 1_000, - TimeUnit::Microsecond => 1_000_000, - TimeUnit::Nanosecond => 1_000_000_000, - }; - - let n: i64 = array.value(i).into(); - b.append_value(n as f64 / scale as f64); - } - } + let scale = match tu { + TimeUnit::Second => 1, + TimeUnit::Millisecond => 1_000, + TimeUnit::Microsecond => 1_000_000, + TimeUnit::Nanosecond => 1_000_000_000, + } as f64; + array.unary(|n| { + let n: i64 = n.into(); + n as f64 / scale + }) + } + DataType::Date32 => { + let seconds_in_a_day = 86400_f64; + array.unary(|n| { + let n: i64 = n.into(); + n as f64 * seconds_in_a_day + }) } + DataType::Date64 => array.unary(|n| { + let n: i64 = n.into(); + n as f64 / 1_000_f64 + }), _ => return internal_err!("Can not convert {:?} to epoch", array.data_type()), + }; + Ok(b) +} + +/// to_timestammp() SQL function implementation +pub fn to_timestamp_invoke(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return internal_err!( + "to_timestamp function requires 1 arguments, got {}", + args.len() + ); + } + + match args[0].data_type() { + DataType::Int64 => cast_column( + &cast_column(&args[0], &DataType::Timestamp(TimeUnit::Second, None), None)?, + &DataType::Timestamp(TimeUnit::Nanosecond, None), + None, + ), + DataType::Float64 => cast_column( + &args[0], + &DataType::Timestamp(TimeUnit::Nanosecond, None), + None, + ), + DataType::Timestamp(_, None) => cast_column( + &args[0], + &DataType::Timestamp(TimeUnit::Nanosecond, None), + None, + ), + DataType::Utf8 => datetime_expressions::to_timestamp(args), + other => { + internal_err!( + "Unsupported data type {:?} for function to_timestamp", + other + ) + } + } +} + +/// to_timestamp_millis() SQL function implementation +pub fn to_timestamp_millis_invoke(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return internal_err!( + "to_timestamp_millis function requires 1 argument, got {}", + args.len() + ); + } + + match args[0].data_type() { + DataType::Int64 | DataType::Timestamp(_, None) => cast_column( + &args[0], + &DataType::Timestamp(TimeUnit::Millisecond, None), + None, + ), + DataType::Utf8 => datetime_expressions::to_timestamp_millis(args), + other => { + internal_err!( + "Unsupported data type {:?} for function to_timestamp_millis", + other + ) + } + } +} + +/// to_timestamp_micros() SQL function implementation +pub fn to_timestamp_micros_invoke(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return internal_err!( + "to_timestamp_micros function requires 1 argument, got {}", + args.len() + ); + } + + match args[0].data_type() { + DataType::Int64 | DataType::Timestamp(_, None) => cast_column( + &args[0], + &DataType::Timestamp(TimeUnit::Microsecond, None), + None, + ), + DataType::Utf8 => datetime_expressions::to_timestamp_micros(args), + other => { + internal_err!( + "Unsupported data type {:?} for function to_timestamp_micros", + other + ) + } + } +} + +/// to_timestamp_nanos() SQL function implementation +pub fn to_timestamp_nanos_invoke(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return internal_err!( + "to_timestamp_nanos function requires 1 argument, got {}", + args.len() + ); + } + + match args[0].data_type() { + DataType::Int64 | DataType::Timestamp(_, None) => cast_column( + &args[0], + &DataType::Timestamp(TimeUnit::Nanosecond, None), + None, + ), + DataType::Utf8 => datetime_expressions::to_timestamp_nanos(args), + other => { + internal_err!( + "Unsupported data type {:?} for function to_timestamp_nanos", + other + ) + } + } +} + +/// to_timestamp_seconds() SQL function implementation +pub fn to_timestamp_seconds_invoke(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return internal_err!( + "to_timestamp_seconds function requires 1 argument, got {}", + args.len() + ); + } + + match args[0].data_type() { + DataType::Int64 | DataType::Timestamp(_, None) => { + cast_column(&args[0], &DataType::Timestamp(TimeUnit::Second, None), None) + } + DataType::Utf8 => datetime_expressions::to_timestamp_seconds(args), + other => { + internal_err!( + "Unsupported data type {:?} for function to_timestamp_seconds", + other + ) + } + } +} + +/// from_unixtime() SQL function implementation +pub fn from_unixtime_invoke(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return internal_err!( + "from_unixtime function requires 1 argument, got {}", + args.len() + ); + } + + match args[0].data_type() { + DataType::Int64 => { + cast_column(&args[0], &DataType::Timestamp(TimeUnit::Second, None), None) + } + other => { + internal_err!( + "Unsupported data type {:?} for function from_unixtime", + other + ) + } } - Ok(b.finish()) } #[cfg(test)] mod tests { use std::sync::Arc; - use arrow::array::{ArrayRef, Int64Array, IntervalDayTimeArray, StringBuilder}; + use arrow::array::{ + as_primitive_array, ArrayRef, Int64Array, IntervalDayTimeArray, StringBuilder, + }; + use arrow_array::TimestampNanosecondArray; use super::*; @@ -936,7 +1088,7 @@ mod tests { let mut string_builder = StringBuilder::with_capacity(2, 1024); let mut ts_builder = TimestampNanosecondArray::builder(2); - string_builder.append_value("2020-09-08T13:42:29.190855Z"); + string_builder.append_value("2020-09-08T13:42:29.190855"); ts_builder.append_value(1599572549190855000); string_builder.append_null(); @@ -1051,6 +1203,125 @@ mod tests { }); } + #[test] + fn test_date_trunc_timezones() { + let cases = vec![ + ( + vec![ + "2020-09-08T00:00:00Z", + "2020-09-08T01:00:00Z", + "2020-09-08T02:00:00Z", + "2020-09-08T03:00:00Z", + "2020-09-08T04:00:00Z", + ], + Some("+00".into()), + vec![ + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + ], + ), + ( + vec![ + "2020-09-08T00:00:00Z", + "2020-09-08T01:00:00Z", + "2020-09-08T02:00:00Z", + "2020-09-08T03:00:00Z", + "2020-09-08T04:00:00Z", + ], + None, + vec![ + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + ], + ), + ( + vec![ + "2020-09-08T00:00:00Z", + "2020-09-08T01:00:00Z", + "2020-09-08T02:00:00Z", + "2020-09-08T03:00:00Z", + "2020-09-08T04:00:00Z", + ], + Some("-02".into()), + vec![ + "2020-09-07T02:00:00Z", + "2020-09-07T02:00:00Z", + "2020-09-08T02:00:00Z", + "2020-09-08T02:00:00Z", + "2020-09-08T02:00:00Z", + ], + ), + ( + vec![ + "2020-09-08T00:00:00+05", + "2020-09-08T01:00:00+05", + "2020-09-08T02:00:00+05", + "2020-09-08T03:00:00+05", + "2020-09-08T04:00:00+05", + ], + Some("+05".into()), + vec![ + "2020-09-08T00:00:00+05", + "2020-09-08T00:00:00+05", + "2020-09-08T00:00:00+05", + "2020-09-08T00:00:00+05", + "2020-09-08T00:00:00+05", + ], + ), + ( + vec![ + "2020-09-08T00:00:00+08", + "2020-09-08T01:00:00+08", + "2020-09-08T02:00:00+08", + "2020-09-08T03:00:00+08", + "2020-09-08T04:00:00+08", + ], + Some("+08".into()), + vec![ + "2020-09-08T00:00:00+08", + "2020-09-08T00:00:00+08", + "2020-09-08T00:00:00+08", + "2020-09-08T00:00:00+08", + "2020-09-08T00:00:00+08", + ], + ), + ]; + + cases.iter().for_each(|(original, tz_opt, expected)| { + let input = original + .iter() + .map(|s| Some(string_to_timestamp_nanos(s).unwrap())) + .collect::() + .with_timezone_opt(tz_opt.clone()); + let right = expected + .iter() + .map(|s| Some(string_to_timestamp_nanos(s).unwrap())) + .collect::() + .with_timezone_opt(tz_opt.clone()); + let result = date_trunc(&[ + ColumnarValue::Scalar(ScalarValue::from("day")), + ColumnarValue::Array(Arc::new(input)), + ]) + .unwrap(); + if let ColumnarValue::Array(result) = result { + assert_eq!( + result.data_type(), + &DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()) + ); + let left = as_primitive_array::(&result); + assert_eq!(left, &right); + } else { + panic!("unexpected column type"); + } + }); + } + #[test] fn test_date_bin_single() { use chrono::Duration; @@ -1252,6 +1523,136 @@ mod tests { ); } + #[test] + fn test_date_bin_timezones() { + let cases = vec![ + ( + vec![ + "2020-09-08T00:00:00Z", + "2020-09-08T01:00:00Z", + "2020-09-08T02:00:00Z", + "2020-09-08T03:00:00Z", + "2020-09-08T04:00:00Z", + ], + Some("+00".into()), + "1970-01-01T00:00:00Z", + vec![ + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + ], + ), + ( + vec![ + "2020-09-08T00:00:00Z", + "2020-09-08T01:00:00Z", + "2020-09-08T02:00:00Z", + "2020-09-08T03:00:00Z", + "2020-09-08T04:00:00Z", + ], + None, + "1970-01-01T00:00:00Z", + vec![ + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + ], + ), + ( + vec![ + "2020-09-08T00:00:00Z", + "2020-09-08T01:00:00Z", + "2020-09-08T02:00:00Z", + "2020-09-08T03:00:00Z", + "2020-09-08T04:00:00Z", + ], + Some("-02".into()), + "1970-01-01T00:00:00Z", + vec![ + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + ], + ), + ( + vec![ + "2020-09-08T00:00:00+05", + "2020-09-08T01:00:00+05", + "2020-09-08T02:00:00+05", + "2020-09-08T03:00:00+05", + "2020-09-08T04:00:00+05", + ], + Some("+05".into()), + "1970-01-01T00:00:00+05", + vec![ + "2020-09-08T00:00:00+05", + "2020-09-08T00:00:00+05", + "2020-09-08T00:00:00+05", + "2020-09-08T00:00:00+05", + "2020-09-08T00:00:00+05", + ], + ), + ( + vec![ + "2020-09-08T00:00:00+08", + "2020-09-08T01:00:00+08", + "2020-09-08T02:00:00+08", + "2020-09-08T03:00:00+08", + "2020-09-08T04:00:00+08", + ], + Some("+08".into()), + "1970-01-01T00:00:00+08", + vec![ + "2020-09-08T00:00:00+08", + "2020-09-08T00:00:00+08", + "2020-09-08T00:00:00+08", + "2020-09-08T00:00:00+08", + "2020-09-08T00:00:00+08", + ], + ), + ]; + + cases + .iter() + .for_each(|(original, tz_opt, origin, expected)| { + let input = original + .iter() + .map(|s| Some(string_to_timestamp_nanos(s).unwrap())) + .collect::() + .with_timezone_opt(tz_opt.clone()); + let right = expected + .iter() + .map(|s| Some(string_to_timestamp_nanos(s).unwrap())) + .collect::() + .with_timezone_opt(tz_opt.clone()); + let result = date_bin(&[ + ColumnarValue::Scalar(ScalarValue::new_interval_dt(1, 0)), + ColumnarValue::Array(Arc::new(input)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + Some(string_to_timestamp_nanos(origin).unwrap()), + tz_opt.clone(), + )), + ]) + .unwrap(); + if let ColumnarValue::Array(result) = result { + assert_eq!( + result.data_type(), + &DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()) + ); + let left = as_primitive_array::(&result); + assert_eq!(left, &right); + } else { + panic!("unexpected column type"); + } + }); + } + #[test] fn to_timestamp_invalid_input_type() -> Result<()> { // pass the wrong type of input array to to_timestamp and test diff --git a/datafusion/physical-expr/src/equivalence.rs b/datafusion/physical-expr/src/equivalence.rs deleted file mode 100644 index 369c139aa30b..000000000000 --- a/datafusion/physical-expr/src/equivalence.rs +++ /dev/null @@ -1,1134 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use crate::expressions::{CastExpr, Column}; -use crate::utils::{collect_columns, merge_vectors}; -use crate::{ - LexOrdering, LexOrderingRef, LexOrderingReq, PhysicalExpr, PhysicalSortExpr, - PhysicalSortRequirement, -}; - -use arrow::datatypes::SchemaRef; -use arrow_schema::Fields; - -use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::Result; -use itertools::izip; -use std::collections::{HashMap, HashSet}; -use std::hash::Hash; -use std::ops::Range; -use std::sync::Arc; - -/// Represents a collection of [`EquivalentClass`] (equivalences -/// between columns in relations) -/// -/// This is used to represent: -/// -/// 1. Equality conditions (like `A=B`), when `T` = [`Column`] -#[derive(Debug, Clone)] -pub struct EquivalenceProperties { - classes: Vec>, - schema: SchemaRef, -} - -impl EquivalenceProperties { - pub fn new(schema: SchemaRef) -> Self { - EquivalenceProperties { - classes: vec![], - schema, - } - } - - /// return the set of equivalences - pub fn classes(&self) -> &[EquivalentClass] { - &self.classes - } - - pub fn schema(&self) -> SchemaRef { - self.schema.clone() - } - - /// Add the [`EquivalentClass`] from `iter` to this list - pub fn extend>>(&mut self, iter: I) { - for ec in iter { - self.classes.push(ec) - } - } - - /// Adds new equal conditions into the EquivalenceProperties. New equal - /// conditions usually come from equality predicates in a join/filter. - pub fn add_equal_conditions(&mut self, new_conditions: (&Column, &Column)) { - let mut idx1: Option = None; - let mut idx2: Option = None; - for (idx, class) in self.classes.iter_mut().enumerate() { - let contains_first = class.contains(new_conditions.0); - let contains_second = class.contains(new_conditions.1); - match (contains_first, contains_second) { - (true, false) => { - class.insert(new_conditions.1.clone()); - idx1 = Some(idx); - } - (false, true) => { - class.insert(new_conditions.0.clone()); - idx2 = Some(idx); - } - (true, true) => { - idx1 = Some(idx); - idx2 = Some(idx); - break; - } - (false, false) => {} - } - } - - match (idx1, idx2) { - (Some(idx_1), Some(idx_2)) if idx_1 != idx_2 => { - // need to merge the two existing EquivalentClasses - let second_eq_class = self.classes.get(idx_2).unwrap().clone(); - let first_eq_class = self.classes.get_mut(idx_1).unwrap(); - for prop in second_eq_class.iter() { - if !first_eq_class.contains(prop) { - first_eq_class.insert(prop.clone()); - } - } - self.classes.remove(idx_2); - } - (None, None) => { - // adding new pairs - self.classes.push(EquivalentClass::::new( - new_conditions.0.clone(), - vec![new_conditions.1.clone()], - )); - } - _ => {} - } - } - - /// Normalizes physical expression according to `EquivalentClass`es inside `self.classes`. - /// expression is replaced with `EquivalentClass::head` expression if it is among `EquivalentClass::others`. - pub fn normalize_expr(&self, expr: Arc) -> Arc { - expr.clone() - .transform(&|expr| { - let normalized_form = - expr.as_any().downcast_ref::().and_then(|column| { - for class in &self.classes { - if class.contains(column) { - return Some(Arc::new(class.head().clone()) as _); - } - } - None - }); - Ok(if let Some(normalized_form) = normalized_form { - Transformed::Yes(normalized_form) - } else { - Transformed::No(expr) - }) - }) - .unwrap_or(expr) - } - - /// This function applies the \[`normalize_expr`] - /// function for all expression in `exprs` and returns a vector of - /// normalized physical expressions. - pub fn normalize_exprs( - &self, - exprs: &[Arc], - ) -> Vec> { - exprs - .iter() - .map(|expr| self.normalize_expr(expr.clone())) - .collect::>() - } - - /// This function normalizes `sort_requirement` according to `EquivalenceClasses` in the `self`. - /// If the given sort requirement doesn't belong to equivalence set inside - /// `self`, it returns `sort_requirement` as is. - pub fn normalize_sort_requirement( - &self, - mut sort_requirement: PhysicalSortRequirement, - ) -> PhysicalSortRequirement { - sort_requirement.expr = self.normalize_expr(sort_requirement.expr); - sort_requirement - } - - /// This function applies the \[`normalize_sort_requirement`] - /// function for all sort requirements in `sort_reqs` and returns a vector of - /// normalized sort expressions. - pub fn normalize_sort_requirements( - &self, - sort_reqs: &[PhysicalSortRequirement], - ) -> Vec { - let normalized_sort_reqs = sort_reqs - .iter() - .map(|sort_req| self.normalize_sort_requirement(sort_req.clone())) - .collect::>(); - collapse_vec(normalized_sort_reqs) - } - - /// Similar to the \[`normalize_sort_requirements`] this function normalizes - /// sort expressions in `sort_exprs` and returns a vector of - /// normalized sort expressions. - pub fn normalize_sort_exprs( - &self, - sort_exprs: &[PhysicalSortExpr], - ) -> Vec { - let sort_requirements = - PhysicalSortRequirement::from_sort_exprs(sort_exprs.iter()); - let normalized_sort_requirement = - self.normalize_sort_requirements(&sort_requirements); - PhysicalSortRequirement::to_sort_exprs(normalized_sort_requirement) - } -} - -/// `OrderingEquivalenceProperties` keeps track of columns that describe the -/// global ordering of the schema. These columns are not necessarily same; e.g. -/// ```text -/// ┌-------┐ -/// | a | b | -/// |---|---| -/// | 1 | 9 | -/// | 2 | 8 | -/// | 3 | 7 | -/// | 5 | 5 | -/// └---┴---┘ -/// ``` -/// where both `a ASC` and `b DESC` can describe the table ordering. With -/// `OrderingEquivalenceProperties`, we can keep track of these equivalences -/// and treat `a ASC` and `b DESC` as the same ordering requirement. -#[derive(Debug, Clone)] -pub struct OrderingEquivalenceProperties { - oeq_class: Option, - /// Keeps track of expressions that have constant value. - constants: Vec>, - schema: SchemaRef, -} - -impl OrderingEquivalenceProperties { - /// Create an empty `OrderingEquivalenceProperties` - pub fn new(schema: SchemaRef) -> Self { - Self { - oeq_class: None, - constants: vec![], - schema, - } - } - - /// Extends `OrderingEquivalenceProperties` by adding ordering inside the `other` - /// to the `self.oeq_class`. - pub fn extend(&mut self, other: Option) { - if let Some(other) = other { - if let Some(class) = &mut self.oeq_class { - class.others.insert(other.head); - class.others.extend(other.others); - } else { - self.oeq_class = Some(other); - } - } - } - - pub fn oeq_class(&self) -> Option<&OrderingEquivalentClass> { - self.oeq_class.as_ref() - } - - /// Adds new equal conditions into the EquivalenceProperties. New equal - /// conditions usually come from equality predicates in a join/filter. - pub fn add_equal_conditions(&mut self, new_conditions: (&LexOrdering, &LexOrdering)) { - if let Some(class) = &mut self.oeq_class { - class.insert(new_conditions.0.clone()); - class.insert(new_conditions.1.clone()); - } else { - let head = new_conditions.0.clone(); - let others = vec![new_conditions.1.clone()]; - self.oeq_class = Some(OrderingEquivalentClass::new(head, others)) - } - } - - /// Add physical expression that have constant value to the `self.constants` - pub fn with_constants(mut self, constants: Vec>) -> Self { - constants.into_iter().for_each(|constant| { - if !physical_exprs_contains(&self.constants, &constant) { - self.constants.push(constant); - } - }); - self - } - - pub fn schema(&self) -> SchemaRef { - self.schema.clone() - } - - /// This function normalizes `sort_reqs` by - /// - removing expressions that have constant value from requirement - /// - replacing sections that are in the `self.oeq_class.others` with `self.oeq_class.head` - /// - removing sections that satisfies global ordering that are in the post fix of requirement - pub fn normalize_sort_requirements( - &self, - sort_reqs: &[PhysicalSortRequirement], - ) -> Vec { - let normalized_sort_reqs = - prune_sort_reqs_with_constants(sort_reqs, &self.constants); - let mut normalized_sort_reqs = collapse_lex_req(normalized_sort_reqs); - if let Some(oeq_class) = &self.oeq_class { - for item in oeq_class.others() { - let item = PhysicalSortRequirement::from_sort_exprs(item); - let item = prune_sort_reqs_with_constants(&item, &self.constants); - let ranges = get_compatible_ranges(&normalized_sort_reqs, &item); - let mut offset: i64 = 0; - for Range { start, end } in ranges { - let head = PhysicalSortRequirement::from_sort_exprs(oeq_class.head()); - let mut head = prune_sort_reqs_with_constants(&head, &self.constants); - let updated_start = (start as i64 + offset) as usize; - let updated_end = (end as i64 + offset) as usize; - let range = end - start; - offset += head.len() as i64 - range as i64; - let all_none = normalized_sort_reqs[updated_start..updated_end] - .iter() - .all(|req| req.options.is_none()); - if all_none { - for req in head.iter_mut() { - req.options = None; - } - } - normalized_sort_reqs.splice(updated_start..updated_end, head); - } - } - normalized_sort_reqs = simplify_lex_req(normalized_sort_reqs, oeq_class); - } - collapse_lex_req(normalized_sort_reqs) - } - - /// Checks whether `leading_ordering` is contained in any of the ordering - /// equivalence classes. - pub fn satisfies_leading_ordering( - &self, - leading_ordering: &PhysicalSortExpr, - ) -> bool { - if let Some(oeq_class) = &self.oeq_class { - for ordering in oeq_class - .others - .iter() - .chain(std::iter::once(&oeq_class.head)) - { - if ordering[0].eq(leading_ordering) { - return true; - } - } - } - false - } -} - -/// EquivalentClass is a set of [`Column`]s or [`PhysicalSortExpr`]s that are known -/// to have the same value in all tuples in a relation. `EquivalentClass` -/// is generated by equality predicates, typically equijoin conditions and equality -/// conditions in filters. `EquivalentClass` is generated by the -/// `ROW_NUMBER` window function. -#[derive(Debug, Clone)] -pub struct EquivalentClass { - /// First element in the EquivalentClass - head: T, - /// Other equal columns - others: HashSet, -} - -impl EquivalentClass { - pub fn new(head: T, others: Vec) -> EquivalentClass { - EquivalentClass { - head, - others: HashSet::from_iter(others), - } - } - - pub fn head(&self) -> &T { - &self.head - } - - pub fn others(&self) -> &HashSet { - &self.others - } - - pub fn contains(&self, col: &T) -> bool { - self.head == *col || self.others.contains(col) - } - - pub fn insert(&mut self, col: T) -> bool { - self.head != col && self.others.insert(col) - } - - pub fn remove(&mut self, col: &T) -> bool { - let removed = self.others.remove(col); - // If we are removing the head, adjust others so that its first entry becomes the new head. - if !removed && *col == self.head { - if let Some(col) = self.others.iter().next().cloned() { - let removed = self.others.remove(&col); - self.head = col; - removed - } else { - // We don't allow empty equivalence classes, reject removal if one tries removing - // the only element in an equivalence class. - false - } - } else { - removed - } - } - - pub fn iter(&self) -> impl Iterator { - std::iter::once(&self.head).chain(self.others.iter()) - } - - pub fn len(&self) -> usize { - self.others.len() + 1 - } - - pub fn is_empty(&self) -> bool { - self.len() == 0 - } -} - -/// `LexOrdering` stores the lexicographical ordering for a schema. -/// OrderingEquivalentClass keeps track of different alternative orderings than can -/// describe the schema. -/// For instance, for the table below -/// |a|b|c|d| -/// |1|4|3|1| -/// |2|3|3|2| -/// |3|1|2|2| -/// |3|2|1|3| -/// both `vec![a ASC, b ASC]` and `vec![c DESC, d ASC]` describe the ordering of the table. -/// For this case, we say that `vec![a ASC, b ASC]`, and `vec![c DESC, d ASC]` are ordering equivalent. -pub type OrderingEquivalentClass = EquivalentClass; - -/// Update each expression in `ordering` with alias expressions. Assume -/// `ordering` is `a ASC, b ASC` and `c` is alias of `b`. Then, the result -/// will be `a ASC, c ASC`. -fn update_with_alias( - mut ordering: LexOrdering, - oeq_alias_map: &[(Column, Column)], -) -> LexOrdering { - for (source_col, target_col) in oeq_alias_map { - let source_col: Arc = Arc::new(source_col.clone()); - // Replace invalidated columns with its alias in the ordering expression. - let target_col: Arc = Arc::new(target_col.clone()); - for item in ordering.iter_mut() { - if item.expr.eq(&source_col) { - // Change the corresponding entry with alias expression - item.expr = target_col.clone(); - } - } - } - ordering -} - -impl OrderingEquivalentClass { - /// This function updates ordering equivalences with alias information. - /// For instance, assume columns `a` and `b` are aliases (a as b), and - /// orderings `a ASC` and `c DESC` are equivalent. Here, we replace column - /// `a` with `b` in ordering equivalence expressions. After this function, - /// `a ASC`, `c DESC` will be converted to the `b ASC`, `c DESC`. - fn update_with_aliases( - &mut self, - oeq_alias_map: &[(Column, Column)], - fields: &Fields, - ) { - let is_head_invalid = self.head.iter().any(|sort_expr| { - collect_columns(&sort_expr.expr) - .iter() - .any(|col| is_column_invalid_in_new_schema(col, fields)) - }); - // If head is invalidated, update head with alias expressions - if is_head_invalid { - self.head = update_with_alias(self.head.clone(), oeq_alias_map); - } else { - let new_oeq_expr = update_with_alias(self.head.clone(), oeq_alias_map); - self.insert(new_oeq_expr); - } - for ordering in self.others.clone().into_iter() { - self.insert(update_with_alias(ordering, oeq_alias_map)); - } - } - - /// Adds `offset` value to the index of each expression inside `self.head` and `self.others`. - pub fn add_offset(&self, offset: usize) -> Result { - let head = add_offset_to_lex_ordering(self.head(), offset)?; - let others = self - .others() - .iter() - .map(|ordering| add_offset_to_lex_ordering(ordering, offset)) - .collect::>>()?; - Ok(OrderingEquivalentClass::new(head, others)) - } - - /// This function normalizes `OrderingEquivalenceProperties` according to `eq_properties`. - /// More explicitly, it makes sure that expressions in `oeq_class` are head entries - /// in `eq_properties`, replacing any non-head entries with head entries if necessary. - pub fn normalize_with_equivalence_properties( - &self, - eq_properties: &EquivalenceProperties, - ) -> OrderingEquivalentClass { - let head = eq_properties.normalize_sort_exprs(self.head()); - - let others = self - .others() - .iter() - .map(|other| eq_properties.normalize_sort_exprs(other)) - .collect(); - - EquivalentClass::new(head, others) - } - - /// Prefix with existing ordering. - pub fn prefix_ordering_equivalent_class_with_existing_ordering( - &self, - existing_ordering: &[PhysicalSortExpr], - eq_properties: &EquivalenceProperties, - ) -> OrderingEquivalentClass { - let existing_ordering = eq_properties.normalize_sort_exprs(existing_ordering); - let normalized_head = eq_properties.normalize_sort_exprs(self.head()); - let updated_head = merge_vectors(&existing_ordering, &normalized_head); - let updated_others = self - .others() - .iter() - .map(|ordering| { - let normalized_ordering = eq_properties.normalize_sort_exprs(ordering); - merge_vectors(&existing_ordering, &normalized_ordering) - }) - .collect(); - OrderingEquivalentClass::new(updated_head, updated_others) - } -} - -/// This is a builder object facilitating incremental construction -/// for ordering equivalences. -pub struct OrderingEquivalenceBuilder { - eq_properties: EquivalenceProperties, - ordering_eq_properties: OrderingEquivalenceProperties, - existing_ordering: Vec, - schema: SchemaRef, -} - -impl OrderingEquivalenceBuilder { - pub fn new(schema: SchemaRef) -> Self { - let eq_properties = EquivalenceProperties::new(schema.clone()); - let ordering_eq_properties = OrderingEquivalenceProperties::new(schema.clone()); - Self { - eq_properties, - ordering_eq_properties, - existing_ordering: vec![], - schema, - } - } - - pub fn extend( - mut self, - new_ordering_eq_properties: OrderingEquivalenceProperties, - ) -> Self { - self.ordering_eq_properties - .extend(new_ordering_eq_properties.oeq_class().cloned()); - self - } - - pub fn with_existing_ordering( - mut self, - existing_ordering: Option>, - ) -> Self { - if let Some(existing_ordering) = existing_ordering { - self.existing_ordering = existing_ordering; - } - self - } - - pub fn with_equivalences(mut self, new_eq_properties: EquivalenceProperties) -> Self { - self.eq_properties = new_eq_properties; - self - } - - pub fn add_equal_conditions( - &mut self, - new_equivalent_ordering: Vec, - ) { - let mut normalized_out_ordering = vec![]; - for item in &self.existing_ordering { - // To account for ordering equivalences, first normalize the expression: - let normalized = self.eq_properties.normalize_expr(item.expr.clone()); - normalized_out_ordering.push(PhysicalSortExpr { - expr: normalized, - options: item.options, - }); - } - // If there is an existing ordering, add new ordering as an equivalence: - if !normalized_out_ordering.is_empty() { - self.ordering_eq_properties.add_equal_conditions(( - &normalized_out_ordering, - &new_equivalent_ordering, - )); - } - } - - /// Return a reference to the schema with which this builder was constructed with - pub fn schema(&self) -> &SchemaRef { - &self.schema - } - - /// Return a reference to the existing ordering - pub fn existing_ordering(&self) -> &LexOrdering { - &self.existing_ordering - } - - pub fn build(self) -> OrderingEquivalenceProperties { - self.ordering_eq_properties - } -} - -/// Checks whether column is still valid after projection. -fn is_column_invalid_in_new_schema(column: &Column, fields: &Fields) -> bool { - let idx = column.index(); - idx >= fields.len() || fields[idx].name() != column.name() -} - -/// Gets first aliased version of `col` found in `alias_map`. -fn get_alias_column( - col: &Column, - alias_map: &HashMap>, -) -> Option { - alias_map - .iter() - .find_map(|(column, columns)| column.eq(col).then(|| columns[0].clone())) -} - -/// This function applies the given projection to the given equivalence -/// properties to compute the resulting (projected) equivalence properties; e.g. -/// 1) Adding an alias, which can introduce additional equivalence properties, -/// as in Projection(a, a as a1, a as a2). -/// 2) Truncate the [`EquivalentClass`]es that are not in the output schema. -pub fn project_equivalence_properties( - input_eq: EquivalenceProperties, - alias_map: &HashMap>, - output_eq: &mut EquivalenceProperties, -) { - // Get schema and fields of projection output - let schema = output_eq.schema(); - let fields = schema.fields(); - - let mut eq_classes = input_eq.classes().to_vec(); - for (column, columns) in alias_map { - let mut find_match = false; - for class in eq_classes.iter_mut() { - // If `self.head` is invalidated in the new schema, update head - // with this change `self.head` is not randomly assigned by one of the entries from `self.others` - if is_column_invalid_in_new_schema(&class.head, fields) { - if let Some(alias_col) = get_alias_column(&class.head, alias_map) { - class.head = alias_col; - } - } - if class.contains(column) { - for col in columns { - class.insert(col.clone()); - } - find_match = true; - break; - } - } - if !find_match { - eq_classes.push(EquivalentClass::new(column.clone(), columns.clone())); - } - } - - // Prune columns that are no longer in the schema from equivalences. - for class in eq_classes.iter_mut() { - let columns_to_remove = class - .iter() - .filter(|column| is_column_invalid_in_new_schema(column, fields)) - .cloned() - .collect::>(); - for column in columns_to_remove { - class.remove(&column); - } - } - - eq_classes.retain(|props| { - props.len() > 1 - && - // A column should not give an equivalence with itself. - !(props.len() == 2 && props.head.eq(props.others().iter().next().unwrap())) - }); - - output_eq.extend(eq_classes); -} - -/// This function applies the given projection to the given ordering -/// equivalence properties to compute the resulting (projected) ordering -/// equivalence properties; e.g. -/// 1) Adding an alias, which can introduce additional ordering equivalence -/// properties, as in Projection(a, a as a1, a as a2) extends global ordering -/// of a to a1 and a2. -/// 2) Truncate the [`OrderingEquivalentClass`]es that are not in the output schema. -pub fn project_ordering_equivalence_properties( - input_eq: OrderingEquivalenceProperties, - columns_map: &HashMap>, - output_eq: &mut OrderingEquivalenceProperties, -) { - // Get schema and fields of projection output - let schema = output_eq.schema(); - let fields = schema.fields(); - - let oeq_class = input_eq.oeq_class(); - let mut oeq_class = if let Some(oeq_class) = oeq_class { - oeq_class.clone() - } else { - return; - }; - let mut oeq_alias_map = vec![]; - for (column, columns) in columns_map { - if is_column_invalid_in_new_schema(column, fields) { - oeq_alias_map.push((column.clone(), columns[0].clone())); - } - } - oeq_class.update_with_aliases(&oeq_alias_map, fields); - - // Prune columns that no longer is in the schema from from the OrderingEquivalenceProperties. - let sort_exprs_to_remove = oeq_class - .iter() - .filter(|sort_exprs| { - sort_exprs.iter().any(|sort_expr| { - let cols_in_expr = collect_columns(&sort_expr.expr); - // If any one of the columns, used in Expression is invalid, remove expression - // from ordering equivalences - cols_in_expr - .iter() - .any(|col| is_column_invalid_in_new_schema(col, fields)) - }) - }) - .cloned() - .collect::>(); - for sort_exprs in sort_exprs_to_remove { - oeq_class.remove(&sort_exprs); - } - if oeq_class.len() > 1 { - output_eq.extend(Some(oeq_class)); - } -} - -/// Update `ordering` if it contains cast expression with target column -/// after projection, if there is no cast expression among `ordering` expressions, -/// returns `None`. -fn update_with_cast_exprs( - cast_exprs: &[(CastExpr, Column)], - mut ordering: LexOrdering, -) -> Option { - let mut is_changed = false; - for sort_expr in ordering.iter_mut() { - for (cast_expr, target_col) in cast_exprs.iter() { - if sort_expr.expr.eq(cast_expr.expr()) { - sort_expr.expr = Arc::new(target_col.clone()) as _; - is_changed = true; - } - } - } - is_changed.then_some(ordering) -} - -/// Update cast expressions inside ordering equivalence -/// properties with its target column after projection -pub fn update_ordering_equivalence_with_cast( - cast_exprs: &[(CastExpr, Column)], - input_oeq: &mut OrderingEquivalenceProperties, -) { - if let Some(cls) = &mut input_oeq.oeq_class { - for ordering in - std::iter::once(cls.head().clone()).chain(cls.others().clone().into_iter()) - { - if let Some(updated_ordering) = update_with_cast_exprs(cast_exprs, ordering) { - cls.insert(updated_ordering); - } - } - } -} - -/// Retrieves the ordering equivalence properties for a given schema and output ordering. -pub fn ordering_equivalence_properties_helper( - schema: SchemaRef, - eq_orderings: &[LexOrdering], -) -> OrderingEquivalenceProperties { - let mut oep = OrderingEquivalenceProperties::new(schema); - let first_ordering = if let Some(first) = eq_orderings.first() { - first - } else { - // Return an empty OrderingEquivalenceProperties: - return oep; - }; - // First entry among eq_orderings is the head, skip it: - for ordering in eq_orderings.iter().skip(1) { - if !ordering.is_empty() { - oep.add_equal_conditions((first_ordering, ordering)) - } - } - oep -} - -/// This function constructs a duplicate-free vector by filtering out duplicate -/// entries inside the given vector `input`. -fn collapse_vec(input: Vec) -> Vec { - let mut output = vec![]; - for item in input { - if !output.contains(&item) { - output.push(item); - } - } - output -} - -/// This function constructs a duplicate-free `LexOrderingReq` by filtering out duplicate -/// entries that have same physical expression inside the given vector `input`. -/// `vec![a Some(Asc), a Some(Desc)]` is collapsed to the `vec![a Some(Asc)]`. Since -/// when same expression is already seen before, following expressions are redundant. -fn collapse_lex_req(input: LexOrderingReq) -> LexOrderingReq { - let mut output = vec![]; - for item in input { - if !lex_req_contains(&output, &item) { - output.push(item); - } - } - output -} - -/// Check whether `sort_req.expr` is among the expressions of `lex_req`. -fn lex_req_contains( - lex_req: &[PhysicalSortRequirement], - sort_req: &PhysicalSortRequirement, -) -> bool { - for constant in lex_req { - if constant.expr.eq(&sort_req.expr) { - return true; - } - } - false -} - -/// This function simplifies lexicographical ordering requirement -/// inside `input` by removing postfix lexicographical requirements -/// that satisfy global ordering (occurs inside the ordering equivalent class) -fn simplify_lex_req( - input: LexOrderingReq, - oeq_class: &OrderingEquivalentClass, -) -> LexOrderingReq { - let mut section = &input[..]; - loop { - let n_prune = prune_last_n_that_is_in_oeq(section, oeq_class); - // Cannot prune entries from the end of requirement - if n_prune == 0 { - break; - } - section = §ion[0..section.len() - n_prune]; - } - if section.is_empty() { - PhysicalSortRequirement::from_sort_exprs(oeq_class.head()) - } else { - section.to_vec() - } -} - -/// Determines how many entries from the end can be deleted. -/// Last n entry satisfies global ordering, hence having them -/// as postfix in the lexicographical requirement is unnecessary. -/// Assume requirement is [a ASC, b ASC, c ASC], also assume that -/// existing ordering is [c ASC, d ASC]. In this case, since [c ASC] -/// is satisfied by the existing ordering (e.g corresponding section is global ordering), -/// [c ASC] can be pruned from the requirement: [a ASC, b ASC, c ASC]. In this case, -/// this function will return 1, to indicate last element can be removed from the requirement -fn prune_last_n_that_is_in_oeq( - input: &[PhysicalSortRequirement], - oeq_class: &OrderingEquivalentClass, -) -> usize { - let input_len = input.len(); - for ordering in std::iter::once(oeq_class.head()).chain(oeq_class.others().iter()) { - let mut search_range = std::cmp::min(ordering.len(), input_len); - while search_range > 0 { - let req_section = &input[input_len - search_range..]; - // let given_section = &ordering[0..search_range]; - if req_satisfied(ordering, req_section) { - return search_range; - } else { - search_range -= 1; - } - } - } - 0 -} - -/// Checks whether given section satisfies req. -fn req_satisfied(given: LexOrderingRef, req: &[PhysicalSortRequirement]) -> bool { - for (given, req) in izip!(given.iter(), req.iter()) { - let PhysicalSortRequirement { expr, options } = req; - if let Some(options) = options { - if options != &given.options || !expr.eq(&given.expr) { - return false; - } - } else if !expr.eq(&given.expr) { - return false; - } - } - true -} - -/// This function searches for the slice `section` inside the slice `given`. -/// It returns each range where `section` is compatible with the corresponding -/// slice in `given`. -fn get_compatible_ranges( - given: &[PhysicalSortRequirement], - section: &[PhysicalSortRequirement], -) -> Vec> { - let n_section = section.len(); - let n_end = if given.len() >= n_section { - given.len() - n_section + 1 - } else { - 0 - }; - (0..n_end) - .filter_map(|idx| { - let end = idx + n_section; - given[idx..end] - .iter() - .zip(section) - .all(|(req, given)| given.compatible(req)) - .then_some(Range { start: idx, end }) - }) - .collect() -} - -/// It is similar to contains method of vector. -/// Finds whether `expr` is among `physical_exprs`. -pub fn physical_exprs_contains( - physical_exprs: &[Arc], - expr: &Arc, -) -> bool { - physical_exprs - .iter() - .any(|physical_expr| physical_expr.eq(expr)) -} - -/// Remove ordering requirements that have constant value -fn prune_sort_reqs_with_constants( - ordering: &[PhysicalSortRequirement], - constants: &[Arc], -) -> Vec { - ordering - .iter() - .filter(|&order| !physical_exprs_contains(constants, &order.expr)) - .cloned() - .collect() -} - -/// Adds the `offset` value to `Column` indices inside `expr`. This function is -/// generally used during the update of the right table schema in join operations. -pub(crate) fn add_offset_to_expr( - expr: Arc, - offset: usize, -) -> Result> { - expr.transform_down(&|e| match e.as_any().downcast_ref::() { - Some(col) => Ok(Transformed::Yes(Arc::new(Column::new( - col.name(), - offset + col.index(), - )))), - None => Ok(Transformed::No(e)), - }) -} - -/// Adds the `offset` value to `Column` indices inside `sort_expr.expr`. -pub(crate) fn add_offset_to_sort_expr( - sort_expr: &PhysicalSortExpr, - offset: usize, -) -> Result { - Ok(PhysicalSortExpr { - expr: add_offset_to_expr(sort_expr.expr.clone(), offset)?, - options: sort_expr.options, - }) -} - -/// Adds the `offset` value to `Column` indices for each `sort_expr.expr` -/// inside `sort_exprs`. -pub fn add_offset_to_lex_ordering( - sort_exprs: LexOrderingRef, - offset: usize, -) -> Result { - sort_exprs - .iter() - .map(|sort_expr| add_offset_to_sort_expr(sort_expr, offset)) - .collect() -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::expressions::Column; - use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::Result; - - use arrow_schema::SortOptions; - use std::sync::Arc; - - fn convert_to_requirement( - in_data: &[(&Column, Option)], - ) -> Vec { - in_data - .iter() - .map(|(col, options)| { - PhysicalSortRequirement::new(Arc::new((*col).clone()) as _, *options) - }) - .collect::>() - } - - #[test] - fn add_equal_conditions_test() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int64, true), - Field::new("b", DataType::Int64, true), - Field::new("c", DataType::Int64, true), - Field::new("x", DataType::Int64, true), - Field::new("y", DataType::Int64, true), - ])); - - let mut eq_properties = EquivalenceProperties::new(schema); - let new_condition = (&Column::new("a", 0), &Column::new("b", 1)); - eq_properties.add_equal_conditions(new_condition); - assert_eq!(eq_properties.classes().len(), 1); - - let new_condition = (&Column::new("b", 1), &Column::new("a", 0)); - eq_properties.add_equal_conditions(new_condition); - assert_eq!(eq_properties.classes().len(), 1); - assert_eq!(eq_properties.classes()[0].len(), 2); - assert!(eq_properties.classes()[0].contains(&Column::new("a", 0))); - assert!(eq_properties.classes()[0].contains(&Column::new("b", 1))); - - let new_condition = (&Column::new("b", 1), &Column::new("c", 2)); - eq_properties.add_equal_conditions(new_condition); - assert_eq!(eq_properties.classes().len(), 1); - assert_eq!(eq_properties.classes()[0].len(), 3); - assert!(eq_properties.classes()[0].contains(&Column::new("a", 0))); - assert!(eq_properties.classes()[0].contains(&Column::new("b", 1))); - assert!(eq_properties.classes()[0].contains(&Column::new("c", 2))); - - let new_condition = (&Column::new("x", 3), &Column::new("y", 4)); - eq_properties.add_equal_conditions(new_condition); - assert_eq!(eq_properties.classes().len(), 2); - - let new_condition = (&Column::new("x", 3), &Column::new("a", 0)); - eq_properties.add_equal_conditions(new_condition); - assert_eq!(eq_properties.classes().len(), 1); - assert_eq!(eq_properties.classes()[0].len(), 5); - assert!(eq_properties.classes()[0].contains(&Column::new("a", 0))); - assert!(eq_properties.classes()[0].contains(&Column::new("b", 1))); - assert!(eq_properties.classes()[0].contains(&Column::new("c", 2))); - assert!(eq_properties.classes()[0].contains(&Column::new("x", 3))); - assert!(eq_properties.classes()[0].contains(&Column::new("y", 4))); - - Ok(()) - } - - #[test] - fn project_equivalence_properties_test() -> Result<()> { - let input_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int64, true), - Field::new("b", DataType::Int64, true), - Field::new("c", DataType::Int64, true), - ])); - - let mut input_properties = EquivalenceProperties::new(input_schema); - let new_condition = (&Column::new("a", 0), &Column::new("b", 1)); - input_properties.add_equal_conditions(new_condition); - let new_condition = (&Column::new("b", 1), &Column::new("c", 2)); - input_properties.add_equal_conditions(new_condition); - - let out_schema = Arc::new(Schema::new(vec![ - Field::new("a1", DataType::Int64, true), - Field::new("a2", DataType::Int64, true), - Field::new("a3", DataType::Int64, true), - Field::new("a4", DataType::Int64, true), - ])); - - let mut alias_map = HashMap::new(); - alias_map.insert( - Column::new("a", 0), - vec![ - Column::new("a1", 0), - Column::new("a2", 1), - Column::new("a3", 2), - Column::new("a4", 3), - ], - ); - let mut out_properties = EquivalenceProperties::new(out_schema); - - project_equivalence_properties(input_properties, &alias_map, &mut out_properties); - assert_eq!(out_properties.classes().len(), 1); - assert_eq!(out_properties.classes()[0].len(), 4); - assert!(out_properties.classes()[0].contains(&Column::new("a1", 0))); - assert!(out_properties.classes()[0].contains(&Column::new("a2", 1))); - assert!(out_properties.classes()[0].contains(&Column::new("a3", 2))); - assert!(out_properties.classes()[0].contains(&Column::new("a4", 3))); - - Ok(()) - } - - #[test] - fn test_collapse_vec() -> Result<()> { - assert_eq!(collapse_vec(vec![1, 2, 3]), vec![1, 2, 3]); - assert_eq!(collapse_vec(vec![1, 2, 3, 2, 3]), vec![1, 2, 3]); - assert_eq!(collapse_vec(vec![3, 1, 2, 3, 2, 3]), vec![3, 1, 2]); - Ok(()) - } - - #[test] - fn test_get_compatible_ranges() -> Result<()> { - let col_a = &Column::new("a", 0); - let col_b = &Column::new("b", 1); - let option1 = SortOptions { - descending: false, - nulls_first: false, - }; - let test_data = vec![ - ( - vec![(col_a, Some(option1)), (col_b, Some(option1))], - vec![(col_a, Some(option1))], - vec![(0, 1)], - ), - ( - vec![(col_a, None), (col_b, Some(option1))], - vec![(col_a, Some(option1))], - vec![(0, 1)], - ), - ( - vec![ - (col_a, None), - (col_b, Some(option1)), - (col_a, Some(option1)), - ], - vec![(col_a, Some(option1))], - vec![(0, 1), (2, 3)], - ), - ]; - for (searched, to_search, expected) in test_data { - let searched = convert_to_requirement(&searched); - let to_search = convert_to_requirement(&to_search); - let expected = expected - .into_iter() - .map(|(start, end)| Range { start, end }) - .collect::>(); - assert_eq!(get_compatible_ranges(&searched, &to_search), expected); - } - Ok(()) - } -} diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs new file mode 100644 index 000000000000..f0bd1740d5d2 --- /dev/null +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -0,0 +1,598 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::{add_offset_to_expr, collapse_lex_req, ProjectionMapping}; +use crate::{ + expressions::Column, physical_expr::deduplicate_physical_exprs, + physical_exprs_bag_equal, physical_exprs_contains, LexOrdering, LexOrderingRef, + LexRequirement, LexRequirementRef, PhysicalExpr, PhysicalSortExpr, + PhysicalSortRequirement, +}; +use datafusion_common::tree_node::TreeNode; +use datafusion_common::{tree_node::Transformed, JoinType}; +use std::sync::Arc; + +/// An `EquivalenceClass` is a set of [`Arc`]s that are known +/// to have the same value for all tuples in a relation. These are generated by +/// equality predicates (e.g. `a = b`), typically equi-join conditions and +/// equality conditions in filters. +/// +/// Two `EquivalenceClass`es are equal if they contains the same expressions in +/// without any ordering. +#[derive(Debug, Clone)] +pub struct EquivalenceClass { + /// The expressions in this equivalence class. The order doesn't + /// matter for equivalence purposes + /// + /// TODO: use a HashSet for this instead of a Vec + exprs: Vec>, +} + +impl PartialEq for EquivalenceClass { + /// Returns true if other is equal in the sense + /// of bags (multi-sets), disregarding their orderings. + fn eq(&self, other: &Self) -> bool { + physical_exprs_bag_equal(&self.exprs, &other.exprs) + } +} + +impl EquivalenceClass { + /// Create a new empty equivalence class + pub fn new_empty() -> Self { + Self { exprs: vec![] } + } + + // Create a new equivalence class from a pre-existing `Vec` + pub fn new(mut exprs: Vec>) -> Self { + deduplicate_physical_exprs(&mut exprs); + Self { exprs } + } + + /// Return the inner vector of expressions + pub fn into_vec(self) -> Vec> { + self.exprs + } + + /// Return the "canonical" expression for this class (the first element) + /// if any + fn canonical_expr(&self) -> Option> { + self.exprs.first().cloned() + } + + /// Insert the expression into this class, meaning it is known to be equal to + /// all other expressions in this class + pub fn push(&mut self, expr: Arc) { + if !self.contains(&expr) { + self.exprs.push(expr); + } + } + + /// Inserts all the expressions from other into this class + pub fn extend(&mut self, other: Self) { + for expr in other.exprs { + // use push so entries are deduplicated + self.push(expr); + } + } + + /// Returns true if this equivalence class contains t expression + pub fn contains(&self, expr: &Arc) -> bool { + physical_exprs_contains(&self.exprs, expr) + } + + /// Returns true if this equivalence class has any entries in common with `other` + pub fn contains_any(&self, other: &Self) -> bool { + self.exprs.iter().any(|e| other.contains(e)) + } + + /// return the number of items in this class + pub fn len(&self) -> usize { + self.exprs.len() + } + + /// return true if this class is empty + pub fn is_empty(&self) -> bool { + self.exprs.is_empty() + } + + /// Iterate over all elements in this class, in some arbitrary order + pub fn iter(&self) -> impl Iterator> { + self.exprs.iter() + } + + /// Return a new equivalence class that have the specified offset added to + /// each expression (used when schemas are appended such as in joins) + pub fn with_offset(&self, offset: usize) -> Self { + let new_exprs = self + .exprs + .iter() + .cloned() + .map(|e| add_offset_to_expr(e, offset)) + .collect(); + Self::new(new_exprs) + } +} + +/// An `EquivalenceGroup` is a collection of `EquivalenceClass`es where each +/// class represents a distinct equivalence class in a relation. +#[derive(Debug, Clone)] +pub struct EquivalenceGroup { + pub classes: Vec, +} + +impl EquivalenceGroup { + /// Creates an empty equivalence group. + pub fn empty() -> Self { + Self { classes: vec![] } + } + + /// Creates an equivalence group from the given equivalence classes. + pub fn new(classes: Vec) -> Self { + let mut result = Self { classes }; + result.remove_redundant_entries(); + result + } + + /// Returns how many equivalence classes there are in this group. + pub fn len(&self) -> usize { + self.classes.len() + } + + /// Checks whether this equivalence group is empty. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns an iterator over the equivalence classes in this group. + pub fn iter(&self) -> impl Iterator { + self.classes.iter() + } + + /// Adds the equality `left` = `right` to this equivalence group. + /// New equality conditions often arise after steps like `Filter(a = b)`, + /// `Alias(a, a as b)` etc. + pub fn add_equal_conditions( + &mut self, + left: &Arc, + right: &Arc, + ) { + let mut first_class = None; + let mut second_class = None; + for (idx, cls) in self.classes.iter().enumerate() { + if cls.contains(left) { + first_class = Some(idx); + } + if cls.contains(right) { + second_class = Some(idx); + } + } + match (first_class, second_class) { + (Some(mut first_idx), Some(mut second_idx)) => { + // If the given left and right sides belong to different classes, + // we should unify/bridge these classes. + if first_idx != second_idx { + // By convention, make sure `second_idx` is larger than `first_idx`. + if first_idx > second_idx { + (first_idx, second_idx) = (second_idx, first_idx); + } + // Remove the class at `second_idx` and merge its values with + // the class at `first_idx`. The convention above makes sure + // that `first_idx` is still valid after removing `second_idx`. + let other_class = self.classes.swap_remove(second_idx); + self.classes[first_idx].extend(other_class); + } + } + (Some(group_idx), None) => { + // Right side is new, extend left side's class: + self.classes[group_idx].push(right.clone()); + } + (None, Some(group_idx)) => { + // Left side is new, extend right side's class: + self.classes[group_idx].push(left.clone()); + } + (None, None) => { + // None of the expressions is among existing classes. + // Create a new equivalence class and extend the group. + self.classes + .push(EquivalenceClass::new(vec![left.clone(), right.clone()])); + } + } + } + + /// Removes redundant entries from this group. + fn remove_redundant_entries(&mut self) { + // Remove duplicate entries from each equivalence class: + self.classes.retain_mut(|cls| { + // Keep groups that have at least two entries as singleton class is + // meaningless (i.e. it contains no non-trivial information): + cls.len() > 1 + }); + // Unify/bridge groups that have common expressions: + self.bridge_classes() + } + + /// This utility function unifies/bridges classes that have common expressions. + /// For example, assume that we have [`EquivalenceClass`]es `[a, b]` and `[b, c]`. + /// Since both classes contain `b`, columns `a`, `b` and `c` are actually all + /// equal and belong to one class. This utility converts merges such classes. + fn bridge_classes(&mut self) { + let mut idx = 0; + while idx < self.classes.len() { + let mut next_idx = idx + 1; + let start_size = self.classes[idx].len(); + while next_idx < self.classes.len() { + if self.classes[idx].contains_any(&self.classes[next_idx]) { + let extension = self.classes.swap_remove(next_idx); + self.classes[idx].extend(extension); + } else { + next_idx += 1; + } + } + if self.classes[idx].len() > start_size { + continue; + } + idx += 1; + } + } + + /// Extends this equivalence group with the `other` equivalence group. + pub fn extend(&mut self, other: Self) { + self.classes.extend(other.classes); + self.remove_redundant_entries(); + } + + /// Normalizes the given physical expression according to this group. + /// The expression is replaced with the first expression in the equivalence + /// class it matches with (if any). + pub fn normalize_expr(&self, expr: Arc) -> Arc { + expr.clone() + .transform(&|expr| { + for cls in self.iter() { + if cls.contains(&expr) { + return Ok(Transformed::Yes(cls.canonical_expr().unwrap())); + } + } + Ok(Transformed::No(expr)) + }) + .unwrap_or(expr) + } + + /// Normalizes the given sort expression according to this group. + /// The underlying physical expression is replaced with the first expression + /// in the equivalence class it matches with (if any). If the underlying + /// expression does not belong to any equivalence class in this group, returns + /// the sort expression as is. + pub fn normalize_sort_expr( + &self, + mut sort_expr: PhysicalSortExpr, + ) -> PhysicalSortExpr { + sort_expr.expr = self.normalize_expr(sort_expr.expr); + sort_expr + } + + /// Normalizes the given sort requirement according to this group. + /// The underlying physical expression is replaced with the first expression + /// in the equivalence class it matches with (if any). If the underlying + /// expression does not belong to any equivalence class in this group, returns + /// the given sort requirement as is. + pub fn normalize_sort_requirement( + &self, + mut sort_requirement: PhysicalSortRequirement, + ) -> PhysicalSortRequirement { + sort_requirement.expr = self.normalize_expr(sort_requirement.expr); + sort_requirement + } + + /// This function applies the `normalize_expr` function for all expressions + /// in `exprs` and returns the corresponding normalized physical expressions. + pub fn normalize_exprs( + &self, + exprs: impl IntoIterator>, + ) -> Vec> { + exprs + .into_iter() + .map(|expr| self.normalize_expr(expr)) + .collect() + } + + /// This function applies the `normalize_sort_expr` function for all sort + /// expressions in `sort_exprs` and returns the corresponding normalized + /// sort expressions. + pub fn normalize_sort_exprs(&self, sort_exprs: LexOrderingRef) -> LexOrdering { + // Convert sort expressions to sort requirements: + let sort_reqs = PhysicalSortRequirement::from_sort_exprs(sort_exprs.iter()); + // Normalize the requirements: + let normalized_sort_reqs = self.normalize_sort_requirements(&sort_reqs); + // Convert sort requirements back to sort expressions: + PhysicalSortRequirement::to_sort_exprs(normalized_sort_reqs) + } + + /// This function applies the `normalize_sort_requirement` function for all + /// requirements in `sort_reqs` and returns the corresponding normalized + /// sort requirements. + pub fn normalize_sort_requirements( + &self, + sort_reqs: LexRequirementRef, + ) -> LexRequirement { + collapse_lex_req( + sort_reqs + .iter() + .map(|sort_req| self.normalize_sort_requirement(sort_req.clone())) + .collect(), + ) + } + + /// Projects `expr` according to the given projection mapping. + /// If the resulting expression is invalid after projection, returns `None`. + pub fn project_expr( + &self, + mapping: &ProjectionMapping, + expr: &Arc, + ) -> Option> { + // First, we try to project expressions with an exact match. If we are + // unable to do this, we consult equivalence classes. + if let Some(target) = mapping.target_expr(expr) { + // If we match the source, we can project directly: + return Some(target); + } else { + // If the given expression is not inside the mapping, try to project + // expressions considering the equivalence classes. + for (source, target) in mapping.iter() { + // If we match an equivalent expression to `source`, then we can + // project. For example, if we have the mapping `(a as a1, a + c)` + // and the equivalence class `(a, b)`, expression `b` projects to `a1`. + if self + .get_equivalence_class(source) + .map_or(false, |group| group.contains(expr)) + { + return Some(target.clone()); + } + } + } + // Project a non-leaf expression by projecting its children. + let children = expr.children(); + if children.is_empty() { + // Leaf expression should be inside mapping. + return None; + } + children + .into_iter() + .map(|child| self.project_expr(mapping, &child)) + .collect::>>() + .map(|children| expr.clone().with_new_children(children).unwrap()) + } + + /// Projects this equivalence group according to the given projection mapping. + pub fn project(&self, mapping: &ProjectionMapping) -> Self { + let projected_classes = self.iter().filter_map(|cls| { + let new_class = cls + .iter() + .filter_map(|expr| self.project_expr(mapping, expr)) + .collect::>(); + (new_class.len() > 1).then_some(EquivalenceClass::new(new_class)) + }); + // TODO: Convert the algorithm below to a version that uses `HashMap`. + // once `Arc` can be stored in `HashMap`. + // See issue: https://github.com/apache/arrow-datafusion/issues/8027 + let mut new_classes = vec![]; + for (source, target) in mapping.iter() { + if new_classes.is_empty() { + new_classes.push((source, vec![target.clone()])); + } + if let Some((_, values)) = + new_classes.iter_mut().find(|(key, _)| key.eq(source)) + { + if !physical_exprs_contains(values, target) { + values.push(target.clone()); + } + } + } + // Only add equivalence classes with at least two members as singleton + // equivalence classes are meaningless. + let new_classes = new_classes + .into_iter() + .filter_map(|(_, values)| (values.len() > 1).then_some(values)) + .map(EquivalenceClass::new); + + let classes = projected_classes.chain(new_classes).collect(); + Self::new(classes) + } + + /// Returns the equivalence class containing `expr`. If no equivalence class + /// contains `expr`, returns `None`. + fn get_equivalence_class( + &self, + expr: &Arc, + ) -> Option<&EquivalenceClass> { + self.iter().find(|cls| cls.contains(expr)) + } + + /// Combine equivalence groups of the given join children. + pub fn join( + &self, + right_equivalences: &Self, + join_type: &JoinType, + left_size: usize, + on: &[(Column, Column)], + ) -> Self { + match join_type { + JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { + let mut result = Self::new( + self.iter() + .cloned() + .chain( + right_equivalences + .iter() + .map(|cls| cls.with_offset(left_size)), + ) + .collect(), + ); + // In we have an inner join, expressions in the "on" condition + // are equal in the resulting table. + if join_type == &JoinType::Inner { + for (lhs, rhs) in on.iter() { + let index = rhs.index() + left_size; + let new_lhs = Arc::new(lhs.clone()) as _; + let new_rhs = Arc::new(Column::new(rhs.name(), index)) as _; + result.add_equal_conditions(&new_lhs, &new_rhs); + } + } + result + } + JoinType::LeftSemi | JoinType::LeftAnti => self.clone(), + JoinType::RightSemi | JoinType::RightAnti => right_equivalences.clone(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::equivalence::tests::create_test_params; + use crate::equivalence::{EquivalenceClass, EquivalenceGroup}; + use crate::expressions::lit; + use crate::expressions::Column; + use crate::expressions::Literal; + use datafusion_common::Result; + use datafusion_common::ScalarValue; + use std::sync::Arc; + + #[test] + fn test_bridge_groups() -> Result<()> { + // First entry in the tuple is argument, second entry is the bridged result + let test_cases = vec![ + // ------- TEST CASE 1 -----------// + ( + vec![vec![1, 2, 3], vec![2, 4, 5], vec![11, 12, 9], vec![7, 6, 5]], + // Expected is compared with set equality. Order of the specific results may change. + vec![vec![1, 2, 3, 4, 5, 6, 7], vec![9, 11, 12]], + ), + // ------- TEST CASE 2 -----------// + ( + vec![vec![1, 2, 3], vec![3, 4, 5], vec![9, 8, 7], vec![7, 6, 5]], + // Expected + vec![vec![1, 2, 3, 4, 5, 6, 7, 8, 9]], + ), + ]; + for (entries, expected) in test_cases { + let entries = entries + .into_iter() + .map(|entry| entry.into_iter().map(lit).collect::>()) + .map(EquivalenceClass::new) + .collect::>(); + let expected = expected + .into_iter() + .map(|entry| entry.into_iter().map(lit).collect::>()) + .map(EquivalenceClass::new) + .collect::>(); + let mut eq_groups = EquivalenceGroup::new(entries.clone()); + eq_groups.bridge_classes(); + let eq_groups = eq_groups.classes; + let err_msg = format!( + "error in test entries: {:?}, expected: {:?}, actual:{:?}", + entries, expected, eq_groups + ); + assert_eq!(eq_groups.len(), expected.len(), "{}", err_msg); + for idx in 0..eq_groups.len() { + assert_eq!(&eq_groups[idx], &expected[idx], "{}", err_msg); + } + } + Ok(()) + } + + #[test] + fn test_remove_redundant_entries_eq_group() -> Result<()> { + let entries = vec![ + EquivalenceClass::new(vec![lit(1), lit(1), lit(2)]), + // This group is meaningless should be removed + EquivalenceClass::new(vec![lit(3), lit(3)]), + EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]), + ]; + // Given equivalences classes are not in succinct form. + // Expected form is the most plain representation that is functionally same. + let expected = vec![ + EquivalenceClass::new(vec![lit(1), lit(2)]), + EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]), + ]; + let mut eq_groups = EquivalenceGroup::new(entries); + eq_groups.remove_redundant_entries(); + + let eq_groups = eq_groups.classes; + assert_eq!(eq_groups.len(), expected.len()); + assert_eq!(eq_groups.len(), 2); + + assert_eq!(eq_groups[0], expected[0]); + assert_eq!(eq_groups[1], expected[1]); + Ok(()) + } + + #[test] + fn test_schema_normalize_expr_with_equivalence() -> Result<()> { + let col_a = &Column::new("a", 0); + let col_b = &Column::new("b", 1); + let col_c = &Column::new("c", 2); + // Assume that column a and c are aliases. + let (_test_schema, eq_properties) = create_test_params()?; + + let col_a_expr = Arc::new(col_a.clone()) as Arc; + let col_b_expr = Arc::new(col_b.clone()) as Arc; + let col_c_expr = Arc::new(col_c.clone()) as Arc; + // Test cases for equivalence normalization, + // First entry in the tuple is argument, second entry is expected result after normalization. + let expressions = vec![ + // Normalized version of the column a and c should go to a + // (by convention all the expressions inside equivalence class are mapped to the first entry + // in this case a is the first entry in the equivalence class.) + (&col_a_expr, &col_a_expr), + (&col_c_expr, &col_a_expr), + // Cannot normalize column b + (&col_b_expr, &col_b_expr), + ]; + let eq_group = eq_properties.eq_group(); + for (expr, expected_eq) in expressions { + assert!( + expected_eq.eq(&eq_group.normalize_expr(expr.clone())), + "error in test: expr: {expr:?}" + ); + } + + Ok(()) + } + + #[test] + fn test_contains_any() { + let lit_true = Arc::new(Literal::new(ScalarValue::Boolean(Some(true)))) + as Arc; + let lit_false = Arc::new(Literal::new(ScalarValue::Boolean(Some(false)))) + as Arc; + let lit2 = + Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc; + let lit1 = + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc; + let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; + + let cls1 = EquivalenceClass::new(vec![lit_true.clone(), lit_false.clone()]); + let cls2 = EquivalenceClass::new(vec![lit_true.clone(), col_b_expr.clone()]); + let cls3 = EquivalenceClass::new(vec![lit2.clone(), lit1.clone()]); + + // lit_true is common + assert!(cls1.contains_any(&cls2)); + // there is no common entry + assert!(!cls1.contains_any(&cls3)); + assert!(!cls2.contains_any(&cls3)); + } +} diff --git a/datafusion/physical-expr/src/equivalence/mod.rs b/datafusion/physical-expr/src/equivalence/mod.rs new file mode 100644 index 000000000000..387dce2cdc8b --- /dev/null +++ b/datafusion/physical-expr/src/equivalence/mod.rs @@ -0,0 +1,533 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod class; +mod ordering; +mod projection; +mod properties; +use crate::expressions::Column; +use crate::{LexRequirement, PhysicalExpr, PhysicalSortRequirement}; +pub use class::{EquivalenceClass, EquivalenceGroup}; +use datafusion_common::tree_node::{Transformed, TreeNode}; +pub use ordering::OrderingEquivalenceClass; +pub use projection::ProjectionMapping; +pub use properties::{join_equivalence_properties, EquivalenceProperties}; +use std::sync::Arc; + +/// This function constructs a duplicate-free `LexOrderingReq` by filtering out +/// duplicate entries that have same physical expression inside. For example, +/// `vec![a Some(ASC), a Some(DESC)]` collapses to `vec![a Some(ASC)]`. +pub fn collapse_lex_req(input: LexRequirement) -> LexRequirement { + let mut output = Vec::::new(); + for item in input { + if !output.iter().any(|req| req.expr.eq(&item.expr)) { + output.push(item); + } + } + output +} + +/// Adds the `offset` value to `Column` indices inside `expr`. This function is +/// generally used during the update of the right table schema in join operations. +pub fn add_offset_to_expr( + expr: Arc, + offset: usize, +) -> Arc { + expr.transform_down(&|e| match e.as_any().downcast_ref::() { + Some(col) => Ok(Transformed::Yes(Arc::new(Column::new( + col.name(), + offset + col.index(), + )))), + None => Ok(Transformed::No(e)), + }) + .unwrap() + // Note that we can safely unwrap here since our transform always returns + // an `Ok` value. +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::expressions::{col, Column}; + use crate::PhysicalSortExpr; + use arrow::compute::{lexsort_to_indices, SortColumn}; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_array::{ArrayRef, Float64Array, RecordBatch, UInt32Array}; + use arrow_schema::{SchemaRef, SortOptions}; + use datafusion_common::{plan_datafusion_err, DataFusionError, Result}; + use itertools::izip; + use rand::rngs::StdRng; + use rand::seq::SliceRandom; + use rand::{Rng, SeedableRng}; + use std::sync::Arc; + + pub fn output_schema( + mapping: &ProjectionMapping, + input_schema: &Arc, + ) -> Result { + // Calculate output schema + let fields: Result> = mapping + .iter() + .map(|(source, target)| { + let name = target + .as_any() + .downcast_ref::() + .ok_or_else(|| plan_datafusion_err!("Expects to have column"))? + .name(); + let field = Field::new( + name, + source.data_type(input_schema)?, + source.nullable(input_schema)?, + ); + + Ok(field) + }) + .collect(); + + let output_schema = Arc::new(Schema::new_with_metadata( + fields?, + input_schema.metadata().clone(), + )); + + Ok(output_schema) + } + + // Generate a schema which consists of 8 columns (a, b, c, d, e, f, g, h) + pub fn create_test_schema() -> Result { + let a = Field::new("a", DataType::Int32, true); + let b = Field::new("b", DataType::Int32, true); + let c = Field::new("c", DataType::Int32, true); + let d = Field::new("d", DataType::Int32, true); + let e = Field::new("e", DataType::Int32, true); + let f = Field::new("f", DataType::Int32, true); + let g = Field::new("g", DataType::Int32, true); + let h = Field::new("h", DataType::Int32, true); + let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f, g, h])); + + Ok(schema) + } + + /// Construct a schema with following properties + /// Schema satisfies following orderings: + /// [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] + /// and + /// Column [a=c] (e.g they are aliases). + pub fn create_test_params() -> Result<(SchemaRef, EquivalenceProperties)> { + let test_schema = create_test_schema()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let col_g = &col("g", &test_schema)?; + let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); + eq_properties.add_equal_conditions(col_a, col_c); + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + let orderings = vec![ + // [a ASC] + vec![(col_a, option_asc)], + // [d ASC, b ASC] + vec![(col_d, option_asc), (col_b, option_asc)], + // [e DESC, f ASC, g ASC] + vec![ + (col_e, option_desc), + (col_f, option_asc), + (col_g, option_asc), + ], + ]; + let orderings = convert_to_orderings(&orderings); + eq_properties.add_new_orderings(orderings); + Ok((test_schema, eq_properties)) + } + + // Generate a schema which consists of 6 columns (a, b, c, d, e, f) + fn create_test_schema_2() -> Result { + let a = Field::new("a", DataType::Float64, true); + let b = Field::new("b", DataType::Float64, true); + let c = Field::new("c", DataType::Float64, true); + let d = Field::new("d", DataType::Float64, true); + let e = Field::new("e", DataType::Float64, true); + let f = Field::new("f", DataType::Float64, true); + let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f])); + + Ok(schema) + } + + /// Construct a schema with random ordering + /// among column a, b, c, d + /// where + /// Column [a=f] (e.g they are aliases). + /// Column e is constant. + pub fn create_random_schema(seed: u64) -> Result<(SchemaRef, EquivalenceProperties)> { + let test_schema = create_test_schema_2()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let col_exprs = [col_a, col_b, col_c, col_d, col_e, col_f]; + + let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); + // Define a and f are aliases + eq_properties.add_equal_conditions(col_a, col_f); + // Column e has constant value. + eq_properties = eq_properties.add_constants([col_e.clone()]); + + // Randomly order columns for sorting + let mut rng = StdRng::seed_from_u64(seed); + let mut remaining_exprs = col_exprs[0..4].to_vec(); // only a, b, c, d are sorted + + let options_asc = SortOptions { + descending: false, + nulls_first: false, + }; + + while !remaining_exprs.is_empty() { + let n_sort_expr = rng.gen_range(0..remaining_exprs.len() + 1); + remaining_exprs.shuffle(&mut rng); + + let ordering = remaining_exprs + .drain(0..n_sort_expr) + .map(|expr| PhysicalSortExpr { + expr: expr.clone(), + options: options_asc, + }) + .collect(); + + eq_properties.add_new_orderings([ordering]); + } + + Ok((test_schema, eq_properties)) + } + + // Convert each tuple to PhysicalSortRequirement + pub fn convert_to_sort_reqs( + in_data: &[(&Arc, Option)], + ) -> Vec { + in_data + .iter() + .map(|(expr, options)| { + PhysicalSortRequirement::new((*expr).clone(), *options) + }) + .collect() + } + + // Convert each tuple to PhysicalSortExpr + pub fn convert_to_sort_exprs( + in_data: &[(&Arc, SortOptions)], + ) -> Vec { + in_data + .iter() + .map(|(expr, options)| PhysicalSortExpr { + expr: (*expr).clone(), + options: *options, + }) + .collect() + } + + // Convert each inner tuple to PhysicalSortExpr + pub fn convert_to_orderings( + orderings: &[Vec<(&Arc, SortOptions)>], + ) -> Vec> { + orderings + .iter() + .map(|sort_exprs| convert_to_sort_exprs(sort_exprs)) + .collect() + } + + // Convert each tuple to PhysicalSortExpr + pub fn convert_to_sort_exprs_owned( + in_data: &[(Arc, SortOptions)], + ) -> Vec { + in_data + .iter() + .map(|(expr, options)| PhysicalSortExpr { + expr: (*expr).clone(), + options: *options, + }) + .collect() + } + + // Convert each inner tuple to PhysicalSortExpr + pub fn convert_to_orderings_owned( + orderings: &[Vec<(Arc, SortOptions)>], + ) -> Vec> { + orderings + .iter() + .map(|sort_exprs| convert_to_sort_exprs_owned(sort_exprs)) + .collect() + } + + // Apply projection to the input_data, return projected equivalence properties and record batch + pub fn apply_projection( + proj_exprs: Vec<(Arc, String)>, + input_data: &RecordBatch, + input_eq_properties: &EquivalenceProperties, + ) -> Result<(RecordBatch, EquivalenceProperties)> { + let input_schema = input_data.schema(); + let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; + + let output_schema = output_schema(&projection_mapping, &input_schema)?; + let num_rows = input_data.num_rows(); + // Apply projection to the input record batch. + let projected_values = projection_mapping + .iter() + .map(|(source, _target)| source.evaluate(input_data)?.into_array(num_rows)) + .collect::>>()?; + let projected_batch = if projected_values.is_empty() { + RecordBatch::new_empty(output_schema.clone()) + } else { + RecordBatch::try_new(output_schema.clone(), projected_values)? + }; + + let projected_eq = + input_eq_properties.project(&projection_mapping, output_schema); + Ok((projected_batch, projected_eq)) + } + + #[test] + fn add_equal_conditions_test() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Int64, true), + Field::new("c", DataType::Int64, true), + Field::new("x", DataType::Int64, true), + Field::new("y", DataType::Int64, true), + ])); + + let mut eq_properties = EquivalenceProperties::new(schema); + let col_a_expr = Arc::new(Column::new("a", 0)) as Arc; + let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; + let col_c_expr = Arc::new(Column::new("c", 2)) as Arc; + let col_x_expr = Arc::new(Column::new("x", 3)) as Arc; + let col_y_expr = Arc::new(Column::new("y", 4)) as Arc; + + // a and b are aliases + eq_properties.add_equal_conditions(&col_a_expr, &col_b_expr); + assert_eq!(eq_properties.eq_group().len(), 1); + + // This new entry is redundant, size shouldn't increase + eq_properties.add_equal_conditions(&col_b_expr, &col_a_expr); + assert_eq!(eq_properties.eq_group().len(), 1); + let eq_groups = &eq_properties.eq_group().classes[0]; + assert_eq!(eq_groups.len(), 2); + assert!(eq_groups.contains(&col_a_expr)); + assert!(eq_groups.contains(&col_b_expr)); + + // b and c are aliases. Exising equivalence class should expand, + // however there shouldn't be any new equivalence class + eq_properties.add_equal_conditions(&col_b_expr, &col_c_expr); + assert_eq!(eq_properties.eq_group().len(), 1); + let eq_groups = &eq_properties.eq_group().classes[0]; + assert_eq!(eq_groups.len(), 3); + assert!(eq_groups.contains(&col_a_expr)); + assert!(eq_groups.contains(&col_b_expr)); + assert!(eq_groups.contains(&col_c_expr)); + + // This is a new set of equality. Hence equivalent class count should be 2. + eq_properties.add_equal_conditions(&col_x_expr, &col_y_expr); + assert_eq!(eq_properties.eq_group().len(), 2); + + // This equality bridges distinct equality sets. + // Hence equivalent class count should decrease from 2 to 1. + eq_properties.add_equal_conditions(&col_x_expr, &col_a_expr); + assert_eq!(eq_properties.eq_group().len(), 1); + let eq_groups = &eq_properties.eq_group().classes[0]; + assert_eq!(eq_groups.len(), 5); + assert!(eq_groups.contains(&col_a_expr)); + assert!(eq_groups.contains(&col_b_expr)); + assert!(eq_groups.contains(&col_c_expr)); + assert!(eq_groups.contains(&col_x_expr)); + assert!(eq_groups.contains(&col_y_expr)); + + Ok(()) + } + + /// Checks if the table (RecordBatch) remains unchanged when sorted according to the provided `required_ordering`. + /// + /// The function works by adding a unique column of ascending integers to the original table. This column ensures + /// that rows that are otherwise indistinguishable (e.g., if they have the same values in all other columns) can + /// still be differentiated. When sorting the extended table, the unique column acts as a tie-breaker to produce + /// deterministic sorting results. + /// + /// If the table remains the same after sorting with the added unique column, it indicates that the table was + /// already sorted according to `required_ordering` to begin with. + pub fn is_table_same_after_sort( + mut required_ordering: Vec, + batch: RecordBatch, + ) -> Result { + // Clone the original schema and columns + let original_schema = batch.schema(); + let mut columns = batch.columns().to_vec(); + + // Create a new unique column + let n_row = batch.num_rows(); + let vals: Vec = (0..n_row).collect::>(); + let vals: Vec = vals.into_iter().map(|val| val as f64).collect(); + let unique_col = Arc::new(Float64Array::from_iter_values(vals)) as ArrayRef; + columns.push(unique_col.clone()); + + // Create a new schema with the added unique column + let unique_col_name = "unique"; + let unique_field = + Arc::new(Field::new(unique_col_name, DataType::Float64, false)); + let fields: Vec<_> = original_schema + .fields() + .iter() + .cloned() + .chain(std::iter::once(unique_field)) + .collect(); + let schema = Arc::new(Schema::new(fields)); + + // Create a new batch with the added column + let new_batch = RecordBatch::try_new(schema.clone(), columns)?; + + // Add the unique column to the required ordering to ensure deterministic results + required_ordering.push(PhysicalSortExpr { + expr: Arc::new(Column::new(unique_col_name, original_schema.fields().len())), + options: Default::default(), + }); + + // Convert the required ordering to a list of SortColumn + let sort_columns = required_ordering + .iter() + .map(|order_expr| { + let expr_result = order_expr.expr.evaluate(&new_batch)?; + let values = expr_result.into_array(new_batch.num_rows())?; + Ok(SortColumn { + values, + options: Some(order_expr.options), + }) + }) + .collect::>>()?; + + // Check if the indices after sorting match the initial ordering + let sorted_indices = lexsort_to_indices(&sort_columns, None)?; + let original_indices = UInt32Array::from_iter_values(0..n_row as u32); + + Ok(sorted_indices == original_indices) + } + + // If we already generated a random result for one of the + // expressions in the equivalence classes. For other expressions in the same + // equivalence class use same result. This util gets already calculated result, when available. + fn get_representative_arr( + eq_group: &EquivalenceClass, + existing_vec: &[Option], + schema: SchemaRef, + ) -> Option { + for expr in eq_group.iter() { + let col = expr.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + if let Some(res) = &existing_vec[idx] { + return Some(res.clone()); + } + } + None + } + + // Generate a table that satisfies the given equivalence properties; i.e. + // equivalences, ordering equivalences, and constants. + pub fn generate_table_for_eq_properties( + eq_properties: &EquivalenceProperties, + n_elem: usize, + n_distinct: usize, + ) -> Result { + let mut rng = StdRng::seed_from_u64(23); + + let schema = eq_properties.schema(); + let mut schema_vec = vec![None; schema.fields.len()]; + + // Utility closure to generate random array + let mut generate_random_array = |num_elems: usize, max_val: usize| -> ArrayRef { + let values: Vec = (0..num_elems) + .map(|_| rng.gen_range(0..max_val) as f64 / 2.0) + .collect(); + Arc::new(Float64Array::from_iter_values(values)) + }; + + // Fill constant columns + for constant in &eq_properties.constants { + let col = constant.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + let arr = Arc::new(Float64Array::from_iter_values(vec![0 as f64; n_elem])) + as ArrayRef; + schema_vec[idx] = Some(arr); + } + + // Fill columns based on ordering equivalences + for ordering in eq_properties.oeq_class.iter() { + let (sort_columns, indices): (Vec<_>, Vec<_>) = ordering + .iter() + .map(|PhysicalSortExpr { expr, options }| { + let col = expr.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + let arr = generate_random_array(n_elem, n_distinct); + ( + SortColumn { + values: arr, + options: Some(*options), + }, + idx, + ) + }) + .unzip(); + + let sort_arrs = arrow::compute::lexsort(&sort_columns, None)?; + for (idx, arr) in izip!(indices, sort_arrs) { + schema_vec[idx] = Some(arr); + } + } + + // Fill columns based on equivalence groups + for eq_group in eq_properties.eq_group.iter() { + let representative_array = + get_representative_arr(eq_group, &schema_vec, schema.clone()) + .unwrap_or_else(|| generate_random_array(n_elem, n_distinct)); + + for expr in eq_group.iter() { + let col = expr.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + schema_vec[idx] = Some(representative_array.clone()); + } + } + + let res: Vec<_> = schema_vec + .into_iter() + .zip(schema.fields.iter()) + .map(|(elem, field)| { + ( + field.name(), + // Generate random values for columns that do not occur in any of the groups (equivalence, ordering equivalence, constants) + elem.unwrap_or_else(|| generate_random_array(n_elem, n_distinct)), + ) + }) + .collect(); + + Ok(RecordBatch::try_from_iter(res)?) + } +} diff --git a/datafusion/physical-expr/src/equivalence/ordering.rs b/datafusion/physical-expr/src/equivalence/ordering.rs new file mode 100644 index 000000000000..1a414592ce4c --- /dev/null +++ b/datafusion/physical-expr/src/equivalence/ordering.rs @@ -0,0 +1,1159 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_schema::SortOptions; +use std::hash::Hash; +use std::sync::Arc; + +use crate::equivalence::add_offset_to_expr; +use crate::{LexOrdering, PhysicalExpr, PhysicalSortExpr}; + +/// An `OrderingEquivalenceClass` object keeps track of different alternative +/// orderings than can describe a schema. For example, consider the following table: +/// +/// ```text +/// |a|b|c|d| +/// |1|4|3|1| +/// |2|3|3|2| +/// |3|1|2|2| +/// |3|2|1|3| +/// ``` +/// +/// Here, both `vec![a ASC, b ASC]` and `vec![c DESC, d ASC]` describe the table +/// ordering. In this case, we say that these orderings are equivalent. +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +pub struct OrderingEquivalenceClass { + pub orderings: Vec, +} + +impl OrderingEquivalenceClass { + /// Creates new empty ordering equivalence class. + pub fn empty() -> Self { + Self { orderings: vec![] } + } + + /// Clears (empties) this ordering equivalence class. + pub fn clear(&mut self) { + self.orderings.clear(); + } + + /// Creates new ordering equivalence class from the given orderings. + pub fn new(orderings: Vec) -> Self { + let mut result = Self { orderings }; + result.remove_redundant_entries(); + result + } + + /// Checks whether `ordering` is a member of this equivalence class. + pub fn contains(&self, ordering: &LexOrdering) -> bool { + self.orderings.contains(ordering) + } + + /// Adds `ordering` to this equivalence class. + #[allow(dead_code)] + fn push(&mut self, ordering: LexOrdering) { + self.orderings.push(ordering); + // Make sure that there are no redundant orderings: + self.remove_redundant_entries(); + } + + /// Checks whether this ordering equivalence class is empty. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns an iterator over the equivalent orderings in this class. + pub fn iter(&self) -> impl Iterator { + self.orderings.iter() + } + + /// Returns how many equivalent orderings there are in this class. + pub fn len(&self) -> usize { + self.orderings.len() + } + + /// Extend this ordering equivalence class with the `other` class. + pub fn extend(&mut self, other: Self) { + self.orderings.extend(other.orderings); + // Make sure that there are no redundant orderings: + self.remove_redundant_entries(); + } + + /// Adds new orderings into this ordering equivalence class. + pub fn add_new_orderings( + &mut self, + orderings: impl IntoIterator, + ) { + self.orderings.extend(orderings); + // Make sure that there are no redundant orderings: + self.remove_redundant_entries(); + } + + /// Removes redundant orderings from this equivalence class. For instance, + /// if we already have the ordering `[a ASC, b ASC, c DESC]`, then there is + /// no need to keep ordering `[a ASC, b ASC]` in the state. + fn remove_redundant_entries(&mut self) { + let mut work = true; + while work { + work = false; + let mut idx = 0; + while idx < self.orderings.len() { + let mut ordering_idx = idx + 1; + let mut removal = self.orderings[idx].is_empty(); + while ordering_idx < self.orderings.len() { + work |= resolve_overlap(&mut self.orderings, idx, ordering_idx); + if self.orderings[idx].is_empty() { + removal = true; + break; + } + work |= resolve_overlap(&mut self.orderings, ordering_idx, idx); + if self.orderings[ordering_idx].is_empty() { + self.orderings.swap_remove(ordering_idx); + } else { + ordering_idx += 1; + } + } + if removal { + self.orderings.swap_remove(idx); + } else { + idx += 1; + } + } + } + } + + /// Returns the concatenation of all the orderings. This enables merge + /// operations to preserve all equivalent orderings simultaneously. + pub fn output_ordering(&self) -> Option { + let output_ordering = self.orderings.iter().flatten().cloned().collect(); + let output_ordering = collapse_lex_ordering(output_ordering); + (!output_ordering.is_empty()).then_some(output_ordering) + } + + // Append orderings in `other` to all existing orderings in this equivalence + // class. + pub fn join_suffix(mut self, other: &Self) -> Self { + let n_ordering = self.orderings.len(); + // Replicate entries before cross product + let n_cross = std::cmp::max(n_ordering, other.len() * n_ordering); + self.orderings = self + .orderings + .iter() + .cloned() + .cycle() + .take(n_cross) + .collect(); + // Suffix orderings of other to the current orderings. + for (outer_idx, ordering) in other.iter().enumerate() { + for idx in 0..n_ordering { + // Calculate cross product index + let idx = outer_idx * n_ordering + idx; + self.orderings[idx].extend(ordering.iter().cloned()); + } + } + self + } + + /// Adds `offset` value to the index of each expression inside this + /// ordering equivalence class. + pub fn add_offset(&mut self, offset: usize) { + for ordering in self.orderings.iter_mut() { + for sort_expr in ordering { + sort_expr.expr = add_offset_to_expr(sort_expr.expr.clone(), offset); + } + } + } + + /// Gets sort options associated with this expression if it is a leading + /// ordering expression. Otherwise, returns `None`. + pub fn get_options(&self, expr: &Arc) -> Option { + for ordering in self.iter() { + let leading_ordering = &ordering[0]; + if leading_ordering.expr.eq(expr) { + return Some(leading_ordering.options); + } + } + None + } +} + +/// This function constructs a duplicate-free `LexOrdering` by filtering out +/// duplicate entries that have same physical expression inside. For example, +/// `vec![a ASC, a DESC]` collapses to `vec![a ASC]`. +pub fn collapse_lex_ordering(input: LexOrdering) -> LexOrdering { + let mut output = Vec::::new(); + for item in input { + if !output.iter().any(|req| req.expr.eq(&item.expr)) { + output.push(item); + } + } + output +} + +/// Trims `orderings[idx]` if some suffix of it overlaps with a prefix of +/// `orderings[pre_idx]`. Returns `true` if there is any overlap, `false` otherwise. +fn resolve_overlap(orderings: &mut [LexOrdering], idx: usize, pre_idx: usize) -> bool { + let length = orderings[idx].len(); + let other_length = orderings[pre_idx].len(); + for overlap in 1..=length.min(other_length) { + if orderings[idx][length - overlap..] == orderings[pre_idx][..overlap] { + orderings[idx].truncate(length - overlap); + return true; + } + } + false +} + +#[cfg(test)] +mod tests { + use crate::equivalence::tests::{ + convert_to_orderings, convert_to_sort_exprs, create_random_schema, + create_test_params, generate_table_for_eq_properties, is_table_same_after_sort, + }; + use crate::equivalence::{tests::create_test_schema, EquivalenceProperties}; + use crate::equivalence::{ + EquivalenceClass, EquivalenceGroup, OrderingEquivalenceClass, + }; + use crate::execution_props::ExecutionProps; + use crate::expressions::Column; + use crate::expressions::{col, BinaryExpr}; + use crate::functions::create_physical_expr; + use crate::{PhysicalExpr, PhysicalSortExpr}; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_schema::SortOptions; + use datafusion_common::Result; + use datafusion_expr::{BuiltinScalarFunction, Operator}; + use itertools::Itertools; + use std::sync::Arc; + + #[test] + fn test_ordering_satisfy() -> Result<()> { + let input_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Int64, true), + ])); + let crude = vec![PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: SortOptions::default(), + }]; + let finer = vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("b", 1)), + options: SortOptions::default(), + }, + ]; + // finer ordering satisfies, crude ordering should return true + let mut eq_properties_finer = EquivalenceProperties::new(input_schema.clone()); + eq_properties_finer.oeq_class.push(finer.clone()); + assert!(eq_properties_finer.ordering_satisfy(&crude)); + + // Crude ordering doesn't satisfy finer ordering. should return false + let mut eq_properties_crude = EquivalenceProperties::new(input_schema.clone()); + eq_properties_crude.oeq_class.push(crude.clone()); + assert!(!eq_properties_crude.ordering_satisfy(&finer)); + Ok(()) + } + + #[test] + fn test_ordering_satisfy_with_equivalence2() -> Result<()> { + let test_schema = create_test_schema()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let floor_a = &create_physical_expr( + &BuiltinScalarFunction::Floor, + &[col("a", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + let floor_f = &create_physical_expr( + &BuiltinScalarFunction::Floor, + &[col("f", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + let exp_a = &create_physical_expr( + &BuiltinScalarFunction::Exp, + &[col("a", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + let a_plus_b = Arc::new(BinaryExpr::new( + col_a.clone(), + Operator::Plus, + col_b.clone(), + )) as Arc; + let options = SortOptions { + descending: false, + nulls_first: false, + }; + + let test_cases = vec![ + // ------------ TEST CASE 1 ------------ + ( + // orderings + vec![ + // [a ASC, d ASC, b ASC] + vec![(col_a, options), (col_d, options), (col_b, options)], + // [c ASC] + vec![(col_c, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [a ASC, b ASC], requirement is not satisfied. + vec![(col_a, options), (col_b, options)], + // expected: requirement is not satisfied. + false, + ), + // ------------ TEST CASE 2 ------------ + ( + // orderings + vec![ + // [a ASC, c ASC, b ASC] + vec![(col_a, options), (col_c, options), (col_b, options)], + // [d ASC] + vec![(col_d, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [floor(a) ASC], + vec![(floor_a, options)], + // expected: requirement is satisfied. + true, + ), + // ------------ TEST CASE 2.1 ------------ + ( + // orderings + vec![ + // [a ASC, c ASC, b ASC] + vec![(col_a, options), (col_c, options), (col_b, options)], + // [d ASC] + vec![(col_d, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [floor(f) ASC], (Please note that a=f) + vec![(floor_f, options)], + // expected: requirement is satisfied. + true, + ), + // ------------ TEST CASE 3 ------------ + ( + // orderings + vec![ + // [a ASC, c ASC, b ASC] + vec![(col_a, options), (col_c, options), (col_b, options)], + // [d ASC] + vec![(col_d, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [a ASC, c ASC, a+b ASC], + vec![(col_a, options), (col_c, options), (&a_plus_b, options)], + // expected: requirement is satisfied. + true, + ), + // ------------ TEST CASE 4 ------------ + ( + // orderings + vec![ + // [a ASC, b ASC, c ASC, d ASC] + vec![ + (col_a, options), + (col_b, options), + (col_c, options), + (col_d, options), + ], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [floor(a) ASC, a+b ASC], + vec![(floor_a, options), (&a_plus_b, options)], + // expected: requirement is satisfied. + false, + ), + // ------------ TEST CASE 5 ------------ + ( + // orderings + vec![ + // [a ASC, b ASC, c ASC, d ASC] + vec![ + (col_a, options), + (col_b, options), + (col_c, options), + (col_d, options), + ], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [exp(a) ASC, a+b ASC], + vec![(exp_a, options), (&a_plus_b, options)], + // expected: requirement is not satisfied. + // TODO: If we know that exp function is 1-to-1 function. + // we could have deduced that above requirement is satisfied. + false, + ), + // ------------ TEST CASE 6 ------------ + ( + // orderings + vec![ + // [a ASC, d ASC, b ASC] + vec![(col_a, options), (col_d, options), (col_b, options)], + // [c ASC] + vec![(col_c, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [a ASC, d ASC, floor(a) ASC], + vec![(col_a, options), (col_d, options), (floor_a, options)], + // expected: requirement is satisfied. + true, + ), + // ------------ TEST CASE 7 ------------ + ( + // orderings + vec![ + // [a ASC, c ASC, b ASC] + vec![(col_a, options), (col_c, options), (col_b, options)], + // [d ASC] + vec![(col_d, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [a ASC, floor(a) ASC, a + b ASC], + vec![(col_a, options), (floor_a, options), (&a_plus_b, options)], + // expected: requirement is not satisfied. + false, + ), + // ------------ TEST CASE 8 ------------ + ( + // orderings + vec![ + // [a ASC, b ASC, c ASC] + vec![(col_a, options), (col_b, options), (col_c, options)], + // [d ASC] + vec![(col_d, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [a ASC, c ASC, floor(a) ASC, a + b ASC], + vec![ + (col_a, options), + (col_c, options), + (&floor_a, options), + (&a_plus_b, options), + ], + // expected: requirement is not satisfied. + false, + ), + // ------------ TEST CASE 9 ------------ + ( + // orderings + vec![ + // [a ASC, b ASC, c ASC, d ASC] + vec![ + (col_a, options), + (col_b, options), + (col_c, options), + (col_d, options), + ], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [a ASC, b ASC, c ASC, floor(a) ASC], + vec![ + (col_a, options), + (col_b, options), + (&col_c, options), + (&floor_a, options), + ], + // expected: requirement is satisfied. + true, + ), + // ------------ TEST CASE 10 ------------ + ( + // orderings + vec![ + // [d ASC, b ASC] + vec![(col_d, options), (col_b, options)], + // [c ASC, a ASC] + vec![(col_c, options), (col_a, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [c ASC, d ASC, a + b ASC], + vec![(col_c, options), (col_d, options), (&a_plus_b, options)], + // expected: requirement is satisfied. + true, + ), + ]; + + for (orderings, eq_group, constants, reqs, expected) in test_cases { + let err_msg = + format!("error in test orderings: {orderings:?}, eq_group: {eq_group:?}, constants: {constants:?}, reqs: {reqs:?}, expected: {expected:?}"); + let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); + let orderings = convert_to_orderings(&orderings); + eq_properties.add_new_orderings(orderings); + let eq_group = eq_group + .into_iter() + .map(|eq_class| { + let eq_classes = eq_class.into_iter().cloned().collect::>(); + EquivalenceClass::new(eq_classes) + }) + .collect::>(); + let eq_group = EquivalenceGroup::new(eq_group); + eq_properties.add_equivalence_group(eq_group); + + let constants = constants.into_iter().cloned(); + eq_properties = eq_properties.add_constants(constants); + + let reqs = convert_to_sort_exprs(&reqs); + assert_eq!( + eq_properties.ordering_satisfy(&reqs), + expected, + "{}", + err_msg + ); + } + + Ok(()) + } + + #[test] + fn test_ordering_satisfy_with_equivalence() -> Result<()> { + // Schema satisfies following orderings: + // [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] + // and + // Column [a=c] (e.g they are aliases). + let (test_schema, eq_properties) = create_test_params()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let col_g = &col("g", &test_schema)?; + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, 625, 5)?; + + // First element in the tuple stores vector of requirement, second element is the expected return value for ordering_satisfy function + let requirements = vec![ + // `a ASC NULLS LAST`, expects `ordering_satisfy` to be `true`, since existing ordering `a ASC NULLS LAST, b ASC NULLS LAST` satisfies it + (vec![(col_a, option_asc)], true), + (vec![(col_a, option_desc)], false), + // Test whether equivalence works as expected + (vec![(col_c, option_asc)], true), + (vec![(col_c, option_desc)], false), + // Test whether ordering equivalence works as expected + (vec![(col_d, option_asc)], true), + (vec![(col_d, option_asc), (col_b, option_asc)], true), + (vec![(col_d, option_desc), (col_b, option_asc)], false), + ( + vec![ + (col_e, option_desc), + (col_f, option_asc), + (col_g, option_asc), + ], + true, + ), + (vec![(col_e, option_desc), (col_f, option_asc)], true), + (vec![(col_e, option_asc), (col_f, option_asc)], false), + (vec![(col_e, option_desc), (col_b, option_asc)], false), + (vec![(col_e, option_asc), (col_b, option_asc)], false), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_d, option_asc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_desc), + (col_f, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_desc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_d, option_desc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_asc), + (col_f, option_asc), + ], + false, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_asc), + (col_b, option_asc), + ], + false, + ), + (vec![(col_d, option_asc), (col_e, option_desc)], true), + ( + vec![ + (col_d, option_asc), + (col_c, option_asc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_e, option_desc), + (col_f, option_asc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_e, option_desc), + (col_c, option_asc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_e, option_desc), + (col_b, option_asc), + (col_f, option_asc), + ], + true, + ), + ]; + + for (cols, expected) in requirements { + let err_msg = format!("Error in test case:{cols:?}"); + let required = cols + .into_iter() + .map(|(expr, options)| PhysicalSortExpr { + expr: expr.clone(), + options, + }) + .collect::>(); + + // Check expected result with experimental result. + assert_eq!( + is_table_same_after_sort( + required.clone(), + table_data_with_properties.clone() + )?, + expected + ); + assert_eq!( + eq_properties.ordering_satisfy(&required), + expected, + "{err_msg}" + ); + } + Ok(()) + } + + #[test] + fn test_ordering_satisfy_with_equivalence_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 5; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + const SORT_OPTIONS: SortOptions = SortOptions { + descending: false, + nulls_first: false, + }; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + let col_exprs = vec![ + col("a", &test_schema)?, + col("b", &test_schema)?, + col("c", &test_schema)?, + col("d", &test_schema)?, + col("e", &test_schema)?, + col("f", &test_schema)?, + ]; + + for n_req in 0..=col_exprs.len() { + for exprs in col_exprs.iter().combinations(n_req) { + let requirement = exprs + .into_iter() + .map(|expr| PhysicalSortExpr { + expr: expr.clone(), + options: SORT_OPTIONS, + }) + .collect::>(); + let expected = is_table_same_after_sort( + requirement.clone(), + table_data_with_properties.clone(), + )?; + let err_msg = format!( + "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", + requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants + ); + // Check whether ordering_satisfy API result and + // experimental result matches. + assert_eq!( + eq_properties.ordering_satisfy(&requirement), + expected, + "{}", + err_msg + ); + } + } + } + + Ok(()) + } + + #[test] + fn test_ordering_satisfy_with_equivalence_complex_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 100; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + const SORT_OPTIONS: SortOptions = SortOptions { + descending: false, + nulls_first: false, + }; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + + let floor_a = create_physical_expr( + &BuiltinScalarFunction::Floor, + &[col("a", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc; + let exprs = vec![ + col("a", &test_schema)?, + col("b", &test_schema)?, + col("c", &test_schema)?, + col("d", &test_schema)?, + col("e", &test_schema)?, + col("f", &test_schema)?, + floor_a, + a_plus_b, + ]; + + for n_req in 0..=exprs.len() { + for exprs in exprs.iter().combinations(n_req) { + let requirement = exprs + .into_iter() + .map(|expr| PhysicalSortExpr { + expr: expr.clone(), + options: SORT_OPTIONS, + }) + .collect::>(); + let expected = is_table_same_after_sort( + requirement.clone(), + table_data_with_properties.clone(), + )?; + let err_msg = format!( + "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", + requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants + ); + // Check whether ordering_satisfy API result and + // experimental result matches. + + assert_eq!( + eq_properties.ordering_satisfy(&requirement), + (expected | false), + "{}", + err_msg + ); + } + } + } + + Ok(()) + } + + #[test] + fn test_ordering_satisfy_different_lengths() -> Result<()> { + let test_schema = create_test_schema()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let options = SortOptions { + descending: false, + nulls_first: false, + }; + // a=c (e.g they are aliases). + let mut eq_properties = EquivalenceProperties::new(test_schema); + eq_properties.add_equal_conditions(col_a, col_c); + + let orderings = vec![ + vec![(col_a, options)], + vec![(col_e, options)], + vec![(col_d, options), (col_f, options)], + ]; + let orderings = convert_to_orderings(&orderings); + + // Column [a ASC], [e ASC], [d ASC, f ASC] are all valid orderings for the schema. + eq_properties.add_new_orderings(orderings); + + // First entry in the tuple is required ordering, second entry is the expected flag + // that indicates whether this required ordering is satisfied. + // ([a ASC], true) indicate a ASC requirement is already satisfied by existing orderings. + let test_cases = vec![ + // [c ASC, a ASC, e ASC], expected represents this requirement is satisfied + ( + vec![(col_c, options), (col_a, options), (col_e, options)], + true, + ), + (vec![(col_c, options), (col_b, options)], false), + (vec![(col_c, options), (col_d, options)], true), + ( + vec![(col_d, options), (col_f, options), (col_b, options)], + false, + ), + (vec![(col_d, options), (col_f, options)], true), + ]; + + for (reqs, expected) in test_cases { + let err_msg = + format!("error in test reqs: {:?}, expected: {:?}", reqs, expected,); + let reqs = convert_to_sort_exprs(&reqs); + assert_eq!( + eq_properties.ordering_satisfy(&reqs), + expected, + "{}", + err_msg + ); + } + + Ok(()) + } + + #[test] + fn test_remove_redundant_entries_oeq_class() -> Result<()> { + let schema = create_test_schema()?; + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let col_d = &col("d", &schema)?; + let col_e = &col("e", &schema)?; + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + + // First entry in the tuple is the given orderings for the table + // Second entry is the simplest version of the given orderings that is functionally equivalent. + let test_cases = vec![ + // ------- TEST CASE 1 --------- + ( + // ORDERINGS GIVEN + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + ], + // EXPECTED orderings that is succinct. + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + ], + ), + // ------- TEST CASE 2 --------- + ( + // ORDERINGS GIVEN + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + ], + // EXPECTED orderings that is succinct. + vec![ + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + ], + ), + // ------- TEST CASE 3 --------- + ( + // ORDERINGS GIVEN + vec![ + // [a ASC, b DESC] + vec![(col_a, option_asc), (col_b, option_desc)], + // [a ASC] + vec![(col_a, option_asc)], + // [a ASC, c ASC] + vec![(col_a, option_asc), (col_c, option_asc)], + ], + // EXPECTED orderings that is succinct. + vec![ + // [a ASC, b DESC] + vec![(col_a, option_asc), (col_b, option_desc)], + // [a ASC, c ASC] + vec![(col_a, option_asc), (col_c, option_asc)], + ], + ), + // ------- TEST CASE 4 --------- + ( + // ORDERINGS GIVEN + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + // [a ASC] + vec![(col_a, option_asc)], + ], + // EXPECTED orderings that is succinct. + vec![ + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + ], + ), + // ------- TEST CASE 5 --------- + // Empty ordering + ( + vec![vec![]], + // No ordering in the state (empty ordering is ignored). + vec![], + ), + // ------- TEST CASE 6 --------- + ( + // ORDERINGS GIVEN + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + // [b ASC] + vec![(col_b, option_asc)], + ], + // EXPECTED orderings that is succinct. + vec![ + // [a ASC] + vec![(col_a, option_asc)], + // [b ASC] + vec![(col_b, option_asc)], + ], + ), + // ------- TEST CASE 7 --------- + // b, a + // c, a + // d, b, c + ( + // ORDERINGS GIVEN + vec![ + // [b ASC, a ASC] + vec![(col_b, option_asc), (col_a, option_asc)], + // [c ASC, a ASC] + vec![(col_c, option_asc), (col_a, option_asc)], + // [d ASC, b ASC, c ASC] + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + ], + // EXPECTED orderings that is succinct. + vec![ + // [b ASC, a ASC] + vec![(col_b, option_asc), (col_a, option_asc)], + // [c ASC, a ASC] + vec![(col_c, option_asc), (col_a, option_asc)], + // [d ASC] + vec![(col_d, option_asc)], + ], + ), + // ------- TEST CASE 8 --------- + // b, e + // c, a + // d, b, e, c, a + ( + // ORDERINGS GIVEN + vec![ + // [b ASC, e ASC] + vec![(col_b, option_asc), (col_e, option_asc)], + // [c ASC, a ASC] + vec![(col_c, option_asc), (col_a, option_asc)], + // [d ASC, b ASC, e ASC, c ASC, a ASC] + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_asc), + (col_c, option_asc), + (col_a, option_asc), + ], + ], + // EXPECTED orderings that is succinct. + vec![ + // [b ASC, e ASC] + vec![(col_b, option_asc), (col_e, option_asc)], + // [c ASC, a ASC] + vec![(col_c, option_asc), (col_a, option_asc)], + // [d ASC] + vec![(col_d, option_asc)], + ], + ), + // ------- TEST CASE 9 --------- + // b + // a, b, c + // d, a, b + ( + // ORDERINGS GIVEN + vec![ + // [b ASC] + vec![(col_b, option_asc)], + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + // [d ASC, a ASC, b ASC] + vec![ + (col_d, option_asc), + (col_a, option_asc), + (col_b, option_asc), + ], + ], + // EXPECTED orderings that is succinct. + vec![ + // [b ASC] + vec![(col_b, option_asc)], + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + // [d ASC] + vec![(col_d, option_asc)], + ], + ), + ]; + for (orderings, expected) in test_cases { + let orderings = convert_to_orderings(&orderings); + let expected = convert_to_orderings(&expected); + let actual = OrderingEquivalenceClass::new(orderings.clone()); + let actual = actual.orderings; + let err_msg = format!( + "orderings: {:?}, expected: {:?}, actual :{:?}", + orderings, expected, actual + ); + assert_eq!(actual.len(), expected.len(), "{}", err_msg); + for elem in actual { + assert!(expected.contains(&elem), "{}", err_msg); + } + } + + Ok(()) + } +} diff --git a/datafusion/physical-expr/src/equivalence/projection.rs b/datafusion/physical-expr/src/equivalence/projection.rs new file mode 100644 index 000000000000..0f92b2c2f431 --- /dev/null +++ b/datafusion/physical-expr/src/equivalence/projection.rs @@ -0,0 +1,1153 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use crate::expressions::Column; +use crate::PhysicalExpr; + +use arrow::datatypes::SchemaRef; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::Result; + +/// Stores the mapping between source expressions and target expressions for a +/// projection. +#[derive(Debug, Clone)] +pub struct ProjectionMapping { + /// Mapping between source expressions and target expressions. + /// Vector indices correspond to the indices after projection. + pub map: Vec<(Arc, Arc)>, +} + +impl ProjectionMapping { + /// Constructs the mapping between a projection's input and output + /// expressions. + /// + /// For example, given the input projection expressions (`a + b`, `c + d`) + /// and an output schema with two columns `"c + d"` and `"a + b"`, the + /// projection mapping would be: + /// + /// ```text + /// [0]: (c + d, col("c + d")) + /// [1]: (a + b, col("a + b")) + /// ``` + /// + /// where `col("c + d")` means the column named `"c + d"`. + pub fn try_new( + expr: &[(Arc, String)], + input_schema: &SchemaRef, + ) -> Result { + // Construct a map from the input expressions to the output expression of the projection: + expr.iter() + .enumerate() + .map(|(expr_idx, (expression, name))| { + let target_expr = Arc::new(Column::new(name, expr_idx)) as _; + expression + .clone() + .transform_down(&|e| match e.as_any().downcast_ref::() { + Some(col) => { + // Sometimes, an expression and its name in the input_schema + // doesn't match. This can cause problems, so we make sure + // that the expression name matches with the name in `input_schema`. + // Conceptually, `source_expr` and `expression` should be the same. + let idx = col.index(); + let matching_input_field = input_schema.field(idx); + let matching_input_column = + Column::new(matching_input_field.name(), idx); + Ok(Transformed::Yes(Arc::new(matching_input_column))) + } + None => Ok(Transformed::No(e)), + }) + .map(|source_expr| (source_expr, target_expr)) + }) + .collect::>>() + .map(|map| Self { map }) + } + + /// Iterate over pairs of (source, target) expressions + pub fn iter( + &self, + ) -> impl Iterator, Arc)> + '_ { + self.map.iter() + } + + /// This function returns the target expression for a given source expression. + /// + /// # Arguments + /// + /// * `expr` - Source physical expression. + /// + /// # Returns + /// + /// An `Option` containing the target for the given source expression, + /// where a `None` value means that `expr` is not inside the mapping. + pub fn target_expr( + &self, + expr: &Arc, + ) -> Option> { + self.map + .iter() + .find(|(source, _)| source.eq(expr)) + .map(|(_, target)| target.clone()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::equivalence::tests::{ + apply_projection, convert_to_orderings, convert_to_orderings_owned, + create_random_schema, generate_table_for_eq_properties, is_table_same_after_sort, + output_schema, + }; + use crate::equivalence::EquivalenceProperties; + use crate::execution_props::ExecutionProps; + use crate::expressions::{col, BinaryExpr, Literal}; + use crate::functions::create_physical_expr; + use crate::PhysicalSortExpr; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_schema::{SortOptions, TimeUnit}; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{BuiltinScalarFunction, Operator}; + use itertools::Itertools; + use std::sync::Arc; + + #[test] + fn project_orderings() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Int32, true), + Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true), + ])); + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let col_d = &col("d", &schema)?; + let col_e = &col("e", &schema)?; + let col_ts = &col("ts", &schema)?; + let interval = Arc::new(Literal::new(ScalarValue::IntervalDayTime(Some(2)))) + as Arc; + let date_bin_func = &create_physical_expr( + &BuiltinScalarFunction::DateBin, + &[interval, col_ts.clone()], + &schema, + &ExecutionProps::default(), + )?; + let a_plus_b = Arc::new(BinaryExpr::new( + col_a.clone(), + Operator::Plus, + col_b.clone(), + )) as Arc; + let b_plus_d = Arc::new(BinaryExpr::new( + col_b.clone(), + Operator::Plus, + col_d.clone(), + )) as Arc; + let b_plus_e = Arc::new(BinaryExpr::new( + col_b.clone(), + Operator::Plus, + col_e.clone(), + )) as Arc; + let c_plus_d = Arc::new(BinaryExpr::new( + col_c.clone(), + Operator::Plus, + col_d.clone(), + )) as Arc; + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + + let test_cases = vec![ + // ---------- TEST CASE 1 ------------ + ( + // orderings + vec![ + // [b ASC] + vec![(col_b, option_asc)], + ], + // projection exprs + vec![(col_b, "b_new".to_string()), (col_a, "a_new".to_string())], + // expected + vec![ + // [b_new ASC] + vec![("b_new", option_asc)], + ], + ), + // ---------- TEST CASE 2 ------------ + ( + // orderings + vec![ + // empty ordering + ], + // projection exprs + vec![(col_c, "c_new".to_string()), (col_b, "b_new".to_string())], + // expected + vec![ + // no ordering at the output + ], + ), + // ---------- TEST CASE 3 ------------ + ( + // orderings + vec![ + // [ts ASC] + vec![(col_ts, option_asc)], + ], + // projection exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_ts, "ts_new".to_string()), + (date_bin_func, "date_bin_res".to_string()), + ], + // expected + vec![ + // [date_bin_res ASC] + vec![("date_bin_res", option_asc)], + // [ts_new ASC] + vec![("ts_new", option_asc)], + ], + ), + // ---------- TEST CASE 4 ------------ + ( + // orderings + vec![ + // [a ASC, ts ASC] + vec![(col_a, option_asc), (col_ts, option_asc)], + // [b ASC, ts ASC] + vec![(col_b, option_asc), (col_ts, option_asc)], + ], + // projection exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_ts, "ts_new".to_string()), + (date_bin_func, "date_bin_res".to_string()), + ], + // expected + vec![ + // [a_new ASC, ts_new ASC] + vec![("a_new", option_asc), ("ts_new", option_asc)], + // [a_new ASC, date_bin_res ASC] + vec![("a_new", option_asc), ("date_bin_res", option_asc)], + // [b_new ASC, ts_new ASC] + vec![("b_new", option_asc), ("ts_new", option_asc)], + // [b_new ASC, date_bin_res ASC] + vec![("b_new", option_asc), ("date_bin_res", option_asc)], + ], + ), + // ---------- TEST CASE 5 ------------ + ( + // orderings + vec![ + // [a + b ASC] + vec![(&a_plus_b, option_asc)], + ], + // projection exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (&a_plus_b, "a+b".to_string()), + ], + // expected + vec![ + // [a + b ASC] + vec![("a+b", option_asc)], + ], + ), + // ---------- TEST CASE 6 ------------ + ( + // orderings + vec![ + // [a + b ASC, c ASC] + vec![(&a_plus_b, option_asc), (&col_c, option_asc)], + ], + // projection exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_c, "c_new".to_string()), + (&a_plus_b, "a+b".to_string()), + ], + // expected + vec![ + // [a + b ASC, c_new ASC] + vec![("a+b", option_asc), ("c_new", option_asc)], + ], + ), + // ------- TEST CASE 7 ---------- + ( + vec![ + // [a ASC, b ASC, c ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + // [a ASC, d ASC] + vec![(col_a, option_asc), (col_d, option_asc)], + ], + // b as b_new, a as a_new, d as d_new b+d + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_d, "d_new".to_string()), + (&b_plus_d, "b+d".to_string()), + ], + // expected + vec![ + // [a_new ASC, b_new ASC] + vec![("a_new", option_asc), ("b_new", option_asc)], + // [a_new ASC, d_new ASC] + vec![("a_new", option_asc), ("d_new", option_asc)], + // [a_new ASC, b+d ASC] + vec![("a_new", option_asc), ("b+d", option_asc)], + ], + ), + // ------- TEST CASE 8 ---------- + ( + // orderings + vec![ + // [b+d ASC] + vec![(&b_plus_d, option_asc)], + ], + // proj exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_d, "d_new".to_string()), + (&b_plus_d, "b+d".to_string()), + ], + // expected + vec![ + // [b+d ASC] + vec![("b+d", option_asc)], + ], + ), + // ------- TEST CASE 9 ---------- + ( + // orderings + vec![ + // [a ASC, d ASC, b ASC] + vec![ + (col_a, option_asc), + (col_d, option_asc), + (col_b, option_asc), + ], + // [c ASC] + vec![(col_c, option_asc)], + ], + // proj exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_d, "d_new".to_string()), + (col_c, "c_new".to_string()), + ], + // expected + vec![ + // [a_new ASC, d_new ASC, b_new ASC] + vec![ + ("a_new", option_asc), + ("d_new", option_asc), + ("b_new", option_asc), + ], + // [c_new ASC], + vec![("c_new", option_asc)], + ], + ), + // ------- TEST CASE 10 ---------- + ( + vec![ + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + // [a ASC, d ASC] + vec![(col_a, option_asc), (col_d, option_asc)], + ], + // proj exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_c, "c_new".to_string()), + (&c_plus_d, "c+d".to_string()), + ], + // expected + vec![ + // [a_new ASC, b_new ASC, c_new ASC] + vec![ + ("a_new", option_asc), + ("b_new", option_asc), + ("c_new", option_asc), + ], + // [a_new ASC, b_new ASC, c+d ASC] + vec![ + ("a_new", option_asc), + ("b_new", option_asc), + ("c+d", option_asc), + ], + ], + ), + // ------- TEST CASE 11 ---------- + ( + // orderings + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + // [a ASC, d ASC] + vec![(col_a, option_asc), (col_d, option_asc)], + ], + // proj exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (&b_plus_d, "b+d".to_string()), + ], + // expected + vec![ + // [a_new ASC, b_new ASC] + vec![("a_new", option_asc), ("b_new", option_asc)], + // [a_new ASC, b + d ASC] + vec![("a_new", option_asc), ("b+d", option_asc)], + ], + ), + // ------- TEST CASE 12 ---------- + ( + // orderings + vec![ + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + ], + // proj exprs + vec![(col_c, "c_new".to_string()), (col_a, "a_new".to_string())], + // expected + vec![ + // [a_new ASC] + vec![("a_new", option_asc)], + ], + ), + // ------- TEST CASE 13 ---------- + ( + // orderings + vec![ + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + // [a ASC, a + b ASC, c ASC] + vec![ + (col_a, option_asc), + (&a_plus_b, option_asc), + (col_c, option_asc), + ], + ], + // proj exprs + vec![ + (col_c, "c_new".to_string()), + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (&a_plus_b, "a+b".to_string()), + ], + // expected + vec![ + // [a_new ASC, b_new ASC, c_new ASC] + vec![ + ("a_new", option_asc), + ("b_new", option_asc), + ("c_new", option_asc), + ], + // [a_new ASC, a+b ASC, c_new ASC] + vec![ + ("a_new", option_asc), + ("a+b", option_asc), + ("c_new", option_asc), + ], + ], + ), + // ------- TEST CASE 14 ---------- + ( + // orderings + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + // [c ASC, b ASC] + vec![(col_c, option_asc), (col_b, option_asc)], + // [d ASC, e ASC] + vec![(col_d, option_asc), (col_e, option_asc)], + ], + // proj exprs + vec![ + (col_c, "c_new".to_string()), + (col_d, "d_new".to_string()), + (col_a, "a_new".to_string()), + (&b_plus_e, "b+e".to_string()), + ], + // expected + vec![ + // [a_new ASC, d_new ASC, b+e ASC] + vec![ + ("a_new", option_asc), + ("d_new", option_asc), + ("b+e", option_asc), + ], + // [d_new ASC, a_new ASC, b+e ASC] + vec![ + ("d_new", option_asc), + ("a_new", option_asc), + ("b+e", option_asc), + ], + // [c_new ASC, d_new ASC, b+e ASC] + vec![ + ("c_new", option_asc), + ("d_new", option_asc), + ("b+e", option_asc), + ], + // [d_new ASC, c_new ASC, b+e ASC] + vec![ + ("d_new", option_asc), + ("c_new", option_asc), + ("b+e", option_asc), + ], + ], + ), + // ------- TEST CASE 15 ---------- + ( + // orderings + vec![ + // [a ASC, c ASC, b ASC] + vec![ + (col_a, option_asc), + (col_c, option_asc), + (&col_b, option_asc), + ], + ], + // proj exprs + vec![ + (col_c, "c_new".to_string()), + (col_a, "a_new".to_string()), + (&a_plus_b, "a+b".to_string()), + ], + // expected + vec![ + // [a_new ASC, d_new ASC, b+e ASC] + vec![ + ("a_new", option_asc), + ("c_new", option_asc), + ("a+b", option_asc), + ], + ], + ), + // ------- TEST CASE 16 ---------- + ( + // orderings + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + // [c ASC, b DESC] + vec![(col_c, option_asc), (col_b, option_desc)], + // [e ASC] + vec![(col_e, option_asc)], + ], + // proj exprs + vec![ + (col_c, "c_new".to_string()), + (col_a, "a_new".to_string()), + (col_b, "b_new".to_string()), + (&b_plus_e, "b+e".to_string()), + ], + // expected + vec![ + // [a_new ASC, b_new ASC] + vec![("a_new", option_asc), ("b_new", option_asc)], + // [a_new ASC, b_new ASC] + vec![("a_new", option_asc), ("b+e", option_asc)], + // [c_new ASC, b_new DESC] + vec![("c_new", option_asc), ("b_new", option_desc)], + ], + ), + ]; + + for (idx, (orderings, proj_exprs, expected)) in test_cases.into_iter().enumerate() + { + let mut eq_properties = EquivalenceProperties::new(schema.clone()); + + let orderings = convert_to_orderings(&orderings); + eq_properties.add_new_orderings(orderings); + + let proj_exprs = proj_exprs + .into_iter() + .map(|(expr, name)| (expr.clone(), name)) + .collect::>(); + let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; + let output_schema = output_schema(&projection_mapping, &schema)?; + + let expected = expected + .into_iter() + .map(|ordering| { + ordering + .into_iter() + .map(|(name, options)| { + (col(name, &output_schema).unwrap(), options) + }) + .collect::>() + }) + .collect::>(); + let expected = convert_to_orderings_owned(&expected); + + let projected_eq = eq_properties.project(&projection_mapping, output_schema); + let orderings = projected_eq.oeq_class(); + + let err_msg = format!( + "test_idx: {:?}, actual: {:?}, expected: {:?}, projection_mapping: {:?}", + idx, orderings.orderings, expected, projection_mapping + ); + + assert_eq!(orderings.len(), expected.len(), "{}", err_msg); + for expected_ordering in &expected { + assert!(orderings.contains(expected_ordering), "{}", err_msg) + } + } + + Ok(()) + } + + #[test] + fn project_orderings2() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true), + ])); + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let col_ts = &col("ts", &schema)?; + let a_plus_b = Arc::new(BinaryExpr::new( + col_a.clone(), + Operator::Plus, + col_b.clone(), + )) as Arc; + let interval = Arc::new(Literal::new(ScalarValue::IntervalDayTime(Some(2)))) + as Arc; + let date_bin_ts = &create_physical_expr( + &BuiltinScalarFunction::DateBin, + &[interval, col_ts.clone()], + &schema, + &ExecutionProps::default(), + )?; + + let round_c = &create_physical_expr( + &BuiltinScalarFunction::Round, + &[col_c.clone()], + &schema, + &ExecutionProps::default(), + )?; + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + + let proj_exprs = vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_c, "c_new".to_string()), + (date_bin_ts, "date_bin_res".to_string()), + (round_c, "round_c_res".to_string()), + ]; + let proj_exprs = proj_exprs + .into_iter() + .map(|(expr, name)| (expr.clone(), name)) + .collect::>(); + let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; + let output_schema = output_schema(&projection_mapping, &schema)?; + + let col_a_new = &col("a_new", &output_schema)?; + let col_b_new = &col("b_new", &output_schema)?; + let col_c_new = &col("c_new", &output_schema)?; + let col_date_bin_res = &col("date_bin_res", &output_schema)?; + let col_round_c_res = &col("round_c_res", &output_schema)?; + let a_new_plus_b_new = Arc::new(BinaryExpr::new( + col_a_new.clone(), + Operator::Plus, + col_b_new.clone(), + )) as Arc; + + let test_cases = vec![ + // ---------- TEST CASE 1 ------------ + ( + // orderings + vec![ + // [a ASC] + vec![(col_a, option_asc)], + ], + // expected + vec![ + // [b_new ASC] + vec![(col_a_new, option_asc)], + ], + ), + // ---------- TEST CASE 2 ------------ + ( + // orderings + vec![ + // [a+b ASC] + vec![(&a_plus_b, option_asc)], + ], + // expected + vec![ + // [b_new ASC] + vec![(&a_new_plus_b_new, option_asc)], + ], + ), + // ---------- TEST CASE 3 ------------ + ( + // orderings + vec![ + // [a ASC, ts ASC] + vec![(col_a, option_asc), (col_ts, option_asc)], + ], + // expected + vec![ + // [a_new ASC, date_bin_res ASC] + vec![(col_a_new, option_asc), (col_date_bin_res, option_asc)], + ], + ), + // ---------- TEST CASE 4 ------------ + ( + // orderings + vec![ + // [a ASC, ts ASC, b ASC] + vec![ + (col_a, option_asc), + (col_ts, option_asc), + (col_b, option_asc), + ], + ], + // expected + vec![ + // [a_new ASC, date_bin_res ASC] + // Please note that result is not [a_new ASC, date_bin_res ASC, b_new ASC] + // because, datebin_res may not be 1-1 function. Hence without introducing ts + // dependency we cannot guarantee any ordering after date_bin_res column. + vec![(col_a_new, option_asc), (col_date_bin_res, option_asc)], + ], + ), + // ---------- TEST CASE 5 ------------ + ( + // orderings + vec![ + // [a ASC, c ASC] + vec![(col_a, option_asc), (col_c, option_asc)], + ], + // expected + vec![ + // [a_new ASC, round_c_res ASC, c_new ASC] + vec![(col_a_new, option_asc), (col_round_c_res, option_asc)], + // [a_new ASC, c_new ASC] + vec![(col_a_new, option_asc), (col_c_new, option_asc)], + ], + ), + // ---------- TEST CASE 6 ------------ + ( + // orderings + vec![ + // [c ASC, b ASC] + vec![(col_c, option_asc), (col_b, option_asc)], + ], + // expected + vec![ + // [round_c_res ASC] + vec![(col_round_c_res, option_asc)], + // [c_new ASC, b_new ASC] + vec![(col_c_new, option_asc), (col_b_new, option_asc)], + ], + ), + // ---------- TEST CASE 7 ------------ + ( + // orderings + vec![ + // [a+b ASC, c ASC] + vec![(&a_plus_b, option_asc), (col_c, option_asc)], + ], + // expected + vec![ + // [a+b ASC, round(c) ASC, c_new ASC] + vec![ + (&a_new_plus_b_new, option_asc), + (&col_round_c_res, option_asc), + ], + // [a+b ASC, c_new ASC] + vec![(&a_new_plus_b_new, option_asc), (col_c_new, option_asc)], + ], + ), + ]; + + for (idx, (orderings, expected)) in test_cases.iter().enumerate() { + let mut eq_properties = EquivalenceProperties::new(schema.clone()); + + let orderings = convert_to_orderings(orderings); + eq_properties.add_new_orderings(orderings); + + let expected = convert_to_orderings(expected); + + let projected_eq = + eq_properties.project(&projection_mapping, output_schema.clone()); + let orderings = projected_eq.oeq_class(); + + let err_msg = format!( + "test idx: {:?}, actual: {:?}, expected: {:?}, projection_mapping: {:?}", + idx, orderings.orderings, expected, projection_mapping + ); + + assert_eq!(orderings.len(), expected.len(), "{}", err_msg); + for expected_ordering in &expected { + assert!(orderings.contains(expected_ordering), "{}", err_msg) + } + } + Ok(()) + } + + #[test] + fn project_orderings3() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Int32, true), + Field::new("f", DataType::Int32, true), + ])); + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let col_d = &col("d", &schema)?; + let col_e = &col("e", &schema)?; + let col_f = &col("f", &schema)?; + let a_plus_b = Arc::new(BinaryExpr::new( + col_a.clone(), + Operator::Plus, + col_b.clone(), + )) as Arc; + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + + let proj_exprs = vec![ + (col_c, "c_new".to_string()), + (col_d, "d_new".to_string()), + (&a_plus_b, "a+b".to_string()), + ]; + let proj_exprs = proj_exprs + .into_iter() + .map(|(expr, name)| (expr.clone(), name)) + .collect::>(); + let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; + let output_schema = output_schema(&projection_mapping, &schema)?; + + let col_a_plus_b_new = &col("a+b", &output_schema)?; + let col_c_new = &col("c_new", &output_schema)?; + let col_d_new = &col("d_new", &output_schema)?; + + let test_cases = vec![ + // ---------- TEST CASE 1 ------------ + ( + // orderings + vec![ + // [d ASC, b ASC] + vec![(col_d, option_asc), (col_b, option_asc)], + // [c ASC, a ASC] + vec![(col_c, option_asc), (col_a, option_asc)], + ], + // equal conditions + vec![], + // expected + vec![ + // [d_new ASC, c_new ASC, a+b ASC] + vec![ + (col_d_new, option_asc), + (col_c_new, option_asc), + (col_a_plus_b_new, option_asc), + ], + // [c_new ASC, d_new ASC, a+b ASC] + vec![ + (col_c_new, option_asc), + (col_d_new, option_asc), + (col_a_plus_b_new, option_asc), + ], + ], + ), + // ---------- TEST CASE 2 ------------ + ( + // orderings + vec![ + // [d ASC, b ASC] + vec![(col_d, option_asc), (col_b, option_asc)], + // [c ASC, e ASC], Please note that a=e + vec![(col_c, option_asc), (col_e, option_asc)], + ], + // equal conditions + vec![(col_e, col_a)], + // expected + vec![ + // [d_new ASC, c_new ASC, a+b ASC] + vec![ + (col_d_new, option_asc), + (col_c_new, option_asc), + (col_a_plus_b_new, option_asc), + ], + // [c_new ASC, d_new ASC, a+b ASC] + vec![ + (col_c_new, option_asc), + (col_d_new, option_asc), + (col_a_plus_b_new, option_asc), + ], + ], + ), + // ---------- TEST CASE 3 ------------ + ( + // orderings + vec![ + // [d ASC, b ASC] + vec![(col_d, option_asc), (col_b, option_asc)], + // [c ASC, e ASC], Please note that a=f + vec![(col_c, option_asc), (col_e, option_asc)], + ], + // equal conditions + vec![(col_a, col_f)], + // expected + vec![ + // [d_new ASC] + vec![(col_d_new, option_asc)], + // [c_new ASC] + vec![(col_c_new, option_asc)], + ], + ), + ]; + for (orderings, equal_columns, expected) in test_cases { + let mut eq_properties = EquivalenceProperties::new(schema.clone()); + for (lhs, rhs) in equal_columns { + eq_properties.add_equal_conditions(lhs, rhs); + } + + let orderings = convert_to_orderings(&orderings); + eq_properties.add_new_orderings(orderings); + + let expected = convert_to_orderings(&expected); + + let projected_eq = + eq_properties.project(&projection_mapping, output_schema.clone()); + let orderings = projected_eq.oeq_class(); + + let err_msg = format!( + "actual: {:?}, expected: {:?}, projection_mapping: {:?}", + orderings.orderings, expected, projection_mapping + ); + + assert_eq!(orderings.len(), expected.len(), "{}", err_msg); + for expected_ordering in &expected { + assert!(orderings.contains(expected_ordering), "{}", err_msg) + } + } + + Ok(()) + } + + #[test] + fn project_orderings_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 20; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + // Floor(a) + let floor_a = create_physical_expr( + &BuiltinScalarFunction::Floor, + &[col("a", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + // a + b + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc; + let proj_exprs = vec![ + (col("a", &test_schema)?, "a_new"), + (col("b", &test_schema)?, "b_new"), + (col("c", &test_schema)?, "c_new"), + (col("d", &test_schema)?, "d_new"), + (col("e", &test_schema)?, "e_new"), + (col("f", &test_schema)?, "f_new"), + (floor_a, "floor(a)"), + (a_plus_b, "a+b"), + ]; + + for n_req in 0..=proj_exprs.len() { + for proj_exprs in proj_exprs.iter().combinations(n_req) { + let proj_exprs = proj_exprs + .into_iter() + .map(|(expr, name)| (expr.clone(), name.to_string())) + .collect::>(); + let (projected_batch, projected_eq) = apply_projection( + proj_exprs.clone(), + &table_data_with_properties, + &eq_properties, + )?; + + // Make sure each ordering after projection is valid. + for ordering in projected_eq.oeq_class().iter() { + let err_msg = format!( + "Error in test case ordering:{:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}, proj_exprs: {:?}", + ordering, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants, proj_exprs + ); + // Since ordered section satisfies schema, we expect + // that result will be same after sort (e.g sort was unnecessary). + assert!( + is_table_same_after_sort( + ordering.clone(), + projected_batch.clone(), + )?, + "{}", + err_msg + ); + } + } + } + } + + Ok(()) + } + + #[test] + fn ordering_satisfy_after_projection_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 20; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + const SORT_OPTIONS: SortOptions = SortOptions { + descending: false, + nulls_first: false, + }; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + // Floor(a) + let floor_a = create_physical_expr( + &BuiltinScalarFunction::Floor, + &[col("a", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + // a + b + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc; + let proj_exprs = vec![ + (col("a", &test_schema)?, "a_new"), + (col("b", &test_schema)?, "b_new"), + (col("c", &test_schema)?, "c_new"), + (col("d", &test_schema)?, "d_new"), + (col("e", &test_schema)?, "e_new"), + (col("f", &test_schema)?, "f_new"), + (floor_a, "floor(a)"), + (a_plus_b, "a+b"), + ]; + + for n_req in 0..=proj_exprs.len() { + for proj_exprs in proj_exprs.iter().combinations(n_req) { + let proj_exprs = proj_exprs + .into_iter() + .map(|(expr, name)| (expr.clone(), name.to_string())) + .collect::>(); + let (projected_batch, projected_eq) = apply_projection( + proj_exprs.clone(), + &table_data_with_properties, + &eq_properties, + )?; + + let projection_mapping = + ProjectionMapping::try_new(&proj_exprs, &test_schema)?; + + let projected_exprs = projection_mapping + .iter() + .map(|(_source, target)| target.clone()) + .collect::>(); + + for n_req in 0..=projected_exprs.len() { + for exprs in projected_exprs.iter().combinations(n_req) { + let requirement = exprs + .into_iter() + .map(|expr| PhysicalSortExpr { + expr: expr.clone(), + options: SORT_OPTIONS, + }) + .collect::>(); + let expected = is_table_same_after_sort( + requirement.clone(), + projected_batch.clone(), + )?; + let err_msg = format!( + "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}, projected_eq.oeq_class: {:?}, projected_eq.eq_group: {:?}, projected_eq.constants: {:?}, projection_mapping: {:?}", + requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants, projected_eq.oeq_class, projected_eq.eq_group, projected_eq.constants, projection_mapping + ); + // Check whether ordering_satisfy API result and + // experimental result matches. + assert_eq!( + projected_eq.ordering_satisfy(&requirement), + expected, + "{}", + err_msg + ); + } + } + } + } + } + + Ok(()) + } +} diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs new file mode 100644 index 000000000000..31c1cf61193a --- /dev/null +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -0,0 +1,2062 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::expressions::Column; +use arrow_schema::SchemaRef; +use datafusion_common::{JoinSide, JoinType}; +use indexmap::IndexSet; +use itertools::Itertools; +use std::collections::{HashMap, HashSet}; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +use crate::equivalence::{ + collapse_lex_req, EquivalenceGroup, OrderingEquivalenceClass, ProjectionMapping, +}; + +use crate::expressions::Literal; +use crate::sort_properties::{ExprOrdering, SortProperties}; +use crate::{ + physical_exprs_contains, LexOrdering, LexOrderingRef, LexRequirement, + LexRequirementRef, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, +}; +use datafusion_common::tree_node::{Transformed, TreeNode}; + +use super::ordering::collapse_lex_ordering; + +/// A `EquivalenceProperties` object stores useful information related to a schema. +/// Currently, it keeps track of: +/// - Equivalent expressions, e.g expressions that have same value. +/// - Valid sort expressions (orderings) for the schema. +/// - Constants expressions (e.g expressions that are known to have constant values). +/// +/// Consider table below: +/// +/// ```text +/// ┌-------┐ +/// | a | b | +/// |---|---| +/// | 1 | 9 | +/// | 2 | 8 | +/// | 3 | 7 | +/// | 5 | 5 | +/// └---┴---┘ +/// ``` +/// +/// where both `a ASC` and `b DESC` can describe the table ordering. With +/// `EquivalenceProperties`, we can keep track of these different valid sort +/// expressions and treat `a ASC` and `b DESC` on an equal footing. +/// +/// Similarly, consider the table below: +/// +/// ```text +/// ┌-------┐ +/// | a | b | +/// |---|---| +/// | 1 | 1 | +/// | 2 | 2 | +/// | 3 | 3 | +/// | 5 | 5 | +/// └---┴---┘ +/// ``` +/// +/// where columns `a` and `b` always have the same value. We keep track of such +/// equivalences inside this object. With this information, we can optimize +/// things like partitioning. For example, if the partition requirement is +/// `Hash(a)` and output partitioning is `Hash(b)`, then we can deduce that +/// the existing partitioning satisfies the requirement. +#[derive(Debug, Clone)] +pub struct EquivalenceProperties { + /// Collection of equivalence classes that store expressions with the same + /// value. + pub eq_group: EquivalenceGroup, + /// Equivalent sort expressions for this table. + pub oeq_class: OrderingEquivalenceClass, + /// Expressions whose values are constant throughout the table. + /// TODO: We do not need to track constants separately, they can be tracked + /// inside `eq_groups` as `Literal` expressions. + pub constants: Vec>, + /// Schema associated with this object. + schema: SchemaRef, +} + +impl EquivalenceProperties { + /// Creates an empty `EquivalenceProperties` object. + pub fn new(schema: SchemaRef) -> Self { + Self { + eq_group: EquivalenceGroup::empty(), + oeq_class: OrderingEquivalenceClass::empty(), + constants: vec![], + schema, + } + } + + /// Creates a new `EquivalenceProperties` object with the given orderings. + pub fn new_with_orderings(schema: SchemaRef, orderings: &[LexOrdering]) -> Self { + Self { + eq_group: EquivalenceGroup::empty(), + oeq_class: OrderingEquivalenceClass::new(orderings.to_vec()), + constants: vec![], + schema, + } + } + + /// Returns the associated schema. + pub fn schema(&self) -> &SchemaRef { + &self.schema + } + + /// Returns a reference to the ordering equivalence class within. + pub fn oeq_class(&self) -> &OrderingEquivalenceClass { + &self.oeq_class + } + + /// Returns a reference to the equivalence group within. + pub fn eq_group(&self) -> &EquivalenceGroup { + &self.eq_group + } + + /// Returns a reference to the constant expressions + pub fn constants(&self) -> &[Arc] { + &self.constants + } + + /// Returns the normalized version of the ordering equivalence class within. + /// Normalization removes constants and duplicates as well as standardizing + /// expressions according to the equivalence group within. + pub fn normalized_oeq_class(&self) -> OrderingEquivalenceClass { + OrderingEquivalenceClass::new( + self.oeq_class + .iter() + .map(|ordering| self.normalize_sort_exprs(ordering)) + .collect(), + ) + } + + /// Extends this `EquivalenceProperties` with the `other` object. + pub fn extend(mut self, other: Self) -> Self { + self.eq_group.extend(other.eq_group); + self.oeq_class.extend(other.oeq_class); + self.add_constants(other.constants) + } + + /// Clears (empties) the ordering equivalence class within this object. + /// Call this method when existing orderings are invalidated. + pub fn clear_orderings(&mut self) { + self.oeq_class.clear(); + } + + /// Extends this `EquivalenceProperties` by adding the orderings inside the + /// ordering equivalence class `other`. + pub fn add_ordering_equivalence_class(&mut self, other: OrderingEquivalenceClass) { + self.oeq_class.extend(other); + } + + /// Adds new orderings into the existing ordering equivalence class. + pub fn add_new_orderings( + &mut self, + orderings: impl IntoIterator, + ) { + self.oeq_class.add_new_orderings(orderings); + } + + /// Incorporates the given equivalence group to into the existing + /// equivalence group within. + pub fn add_equivalence_group(&mut self, other_eq_group: EquivalenceGroup) { + self.eq_group.extend(other_eq_group); + } + + /// Adds a new equality condition into the existing equivalence group. + /// If the given equality defines a new equivalence class, adds this new + /// equivalence class to the equivalence group. + pub fn add_equal_conditions( + &mut self, + left: &Arc, + right: &Arc, + ) { + self.eq_group.add_equal_conditions(left, right); + } + + /// Track/register physical expressions with constant values. + pub fn add_constants( + mut self, + constants: impl IntoIterator>, + ) -> Self { + for expr in self.eq_group.normalize_exprs(constants) { + if !physical_exprs_contains(&self.constants, &expr) { + self.constants.push(expr); + } + } + self + } + + /// Updates the ordering equivalence group within assuming that the table + /// is re-sorted according to the argument `sort_exprs`. Note that constants + /// and equivalence classes are unchanged as they are unaffected by a re-sort. + pub fn with_reorder(mut self, sort_exprs: Vec) -> Self { + // TODO: In some cases, existing ordering equivalences may still be valid add this analysis. + self.oeq_class = OrderingEquivalenceClass::new(vec![sort_exprs]); + self + } + + /// Normalizes the given sort expressions (i.e. `sort_exprs`) using the + /// equivalence group and the ordering equivalence class within. + /// + /// Assume that `self.eq_group` states column `a` and `b` are aliases. + /// Also assume that `self.oeq_class` states orderings `d ASC` and `a ASC, c ASC` + /// are equivalent (in the sense that both describe the ordering of the table). + /// If the `sort_exprs` argument were `vec![b ASC, c ASC, a ASC]`, then this + /// function would return `vec![a ASC, c ASC]`. Internally, it would first + /// normalize to `vec![a ASC, c ASC, a ASC]` and end up with the final result + /// after deduplication. + fn normalize_sort_exprs(&self, sort_exprs: LexOrderingRef) -> LexOrdering { + // Convert sort expressions to sort requirements: + let sort_reqs = PhysicalSortRequirement::from_sort_exprs(sort_exprs.iter()); + // Normalize the requirements: + let normalized_sort_reqs = self.normalize_sort_requirements(&sort_reqs); + // Convert sort requirements back to sort expressions: + PhysicalSortRequirement::to_sort_exprs(normalized_sort_reqs) + } + + /// Normalizes the given sort requirements (i.e. `sort_reqs`) using the + /// equivalence group and the ordering equivalence class within. It works by: + /// - Removing expressions that have a constant value from the given requirement. + /// - Replacing sections that belong to some equivalence class in the equivalence + /// group with the first entry in the matching equivalence class. + /// + /// Assume that `self.eq_group` states column `a` and `b` are aliases. + /// Also assume that `self.oeq_class` states orderings `d ASC` and `a ASC, c ASC` + /// are equivalent (in the sense that both describe the ordering of the table). + /// If the `sort_reqs` argument were `vec![b ASC, c ASC, a ASC]`, then this + /// function would return `vec![a ASC, c ASC]`. Internally, it would first + /// normalize to `vec![a ASC, c ASC, a ASC]` and end up with the final result + /// after deduplication. + fn normalize_sort_requirements( + &self, + sort_reqs: LexRequirementRef, + ) -> LexRequirement { + let normalized_sort_reqs = self.eq_group.normalize_sort_requirements(sort_reqs); + let constants_normalized = self.eq_group.normalize_exprs(self.constants.clone()); + // Prune redundant sections in the requirement: + collapse_lex_req( + normalized_sort_reqs + .iter() + .filter(|&order| { + !physical_exprs_contains(&constants_normalized, &order.expr) + }) + .cloned() + .collect(), + ) + } + + /// Checks whether the given ordering is satisfied by any of the existing + /// orderings. + pub fn ordering_satisfy(&self, given: LexOrderingRef) -> bool { + // Convert the given sort expressions to sort requirements: + let sort_requirements = PhysicalSortRequirement::from_sort_exprs(given.iter()); + self.ordering_satisfy_requirement(&sort_requirements) + } + + /// Checks whether the given sort requirements are satisfied by any of the + /// existing orderings. + pub fn ordering_satisfy_requirement(&self, reqs: LexRequirementRef) -> bool { + let mut eq_properties = self.clone(); + // First, standardize the given requirement: + let normalized_reqs = eq_properties.normalize_sort_requirements(reqs); + for normalized_req in normalized_reqs { + // Check whether given ordering is satisfied + if !eq_properties.ordering_satisfy_single(&normalized_req) { + return false; + } + // Treat satisfied keys as constants in subsequent iterations. We + // can do this because the "next" key only matters in a lexicographical + // ordering when the keys to its left have the same values. + // + // Note that these expressions are not properly "constants". This is just + // an implementation strategy confined to this function. + // + // For example, assume that the requirement is `[a ASC, (b + c) ASC]`, + // and existing equivalent orderings are `[a ASC, b ASC]` and `[c ASC]`. + // From the analysis above, we know that `[a ASC]` is satisfied. Then, + // we add column `a` as constant to the algorithm state. This enables us + // to deduce that `(b + c) ASC` is satisfied, given `a` is constant. + eq_properties = + eq_properties.add_constants(std::iter::once(normalized_req.expr)); + } + true + } + + /// Determines whether the ordering specified by the given sort requirement + /// is satisfied based on the orderings within, equivalence classes, and + /// constant expressions. + /// + /// # Arguments + /// + /// - `req`: A reference to a `PhysicalSortRequirement` for which the ordering + /// satisfaction check will be done. + /// + /// # Returns + /// + /// Returns `true` if the specified ordering is satisfied, `false` otherwise. + fn ordering_satisfy_single(&self, req: &PhysicalSortRequirement) -> bool { + let expr_ordering = self.get_expr_ordering(req.expr.clone()); + let ExprOrdering { expr, state, .. } = expr_ordering; + match state { + SortProperties::Ordered(options) => { + let sort_expr = PhysicalSortExpr { expr, options }; + sort_expr.satisfy(req, self.schema()) + } + // Singleton expressions satisfies any ordering. + SortProperties::Singleton => true, + SortProperties::Unordered => false, + } + } + + /// Checks whether the `given`` sort requirements are equal or more specific + /// than the `reference` sort requirements. + pub fn requirements_compatible( + &self, + given: LexRequirementRef, + reference: LexRequirementRef, + ) -> bool { + let normalized_given = self.normalize_sort_requirements(given); + let normalized_reference = self.normalize_sort_requirements(reference); + + (normalized_reference.len() <= normalized_given.len()) + && normalized_reference + .into_iter() + .zip(normalized_given) + .all(|(reference, given)| given.compatible(&reference)) + } + + /// Returns the finer ordering among the orderings `lhs` and `rhs`, breaking + /// any ties by choosing `lhs`. + /// + /// The finer ordering is the ordering that satisfies both of the orderings. + /// If the orderings are incomparable, returns `None`. + /// + /// For example, the finer ordering among `[a ASC]` and `[a ASC, b ASC]` is + /// the latter. + pub fn get_finer_ordering( + &self, + lhs: LexOrderingRef, + rhs: LexOrderingRef, + ) -> Option { + // Convert the given sort expressions to sort requirements: + let lhs = PhysicalSortRequirement::from_sort_exprs(lhs); + let rhs = PhysicalSortRequirement::from_sort_exprs(rhs); + let finer = self.get_finer_requirement(&lhs, &rhs); + // Convert the chosen sort requirements back to sort expressions: + finer.map(PhysicalSortRequirement::to_sort_exprs) + } + + /// Returns the finer ordering among the requirements `lhs` and `rhs`, + /// breaking any ties by choosing `lhs`. + /// + /// The finer requirements are the ones that satisfy both of the given + /// requirements. If the requirements are incomparable, returns `None`. + /// + /// For example, the finer requirements among `[a ASC]` and `[a ASC, b ASC]` + /// is the latter. + pub fn get_finer_requirement( + &self, + req1: LexRequirementRef, + req2: LexRequirementRef, + ) -> Option { + let mut lhs = self.normalize_sort_requirements(req1); + let mut rhs = self.normalize_sort_requirements(req2); + lhs.iter_mut() + .zip(rhs.iter_mut()) + .all(|(lhs, rhs)| { + lhs.expr.eq(&rhs.expr) + && match (lhs.options, rhs.options) { + (Some(lhs_opt), Some(rhs_opt)) => lhs_opt == rhs_opt, + (Some(options), None) => { + rhs.options = Some(options); + true + } + (None, Some(options)) => { + lhs.options = Some(options); + true + } + (None, None) => true, + } + }) + .then_some(if lhs.len() >= rhs.len() { lhs } else { rhs }) + } + + /// Calculates the "meet" of the given orderings (`lhs` and `rhs`). + /// The meet of a set of orderings is the finest ordering that is satisfied + /// by all the orderings in that set. For details, see: + /// + /// + /// + /// If there is no ordering that satisfies both `lhs` and `rhs`, returns + /// `None`. As an example, the meet of orderings `[a ASC]` and `[a ASC, b ASC]` + /// is `[a ASC]`. + pub fn get_meet_ordering( + &self, + lhs: LexOrderingRef, + rhs: LexOrderingRef, + ) -> Option { + let lhs = self.normalize_sort_exprs(lhs); + let rhs = self.normalize_sort_exprs(rhs); + let mut meet = vec![]; + for (lhs, rhs) in lhs.into_iter().zip(rhs.into_iter()) { + if lhs.eq(&rhs) { + meet.push(lhs); + } else { + break; + } + } + (!meet.is_empty()).then_some(meet) + } + + /// Projects argument `expr` according to `projection_mapping`, taking + /// equivalences into account. + /// + /// For example, assume that columns `a` and `c` are always equal, and that + /// `projection_mapping` encodes following mapping: + /// + /// ```text + /// a -> a1 + /// b -> b1 + /// ``` + /// + /// Then, this function projects `a + b` to `Some(a1 + b1)`, `c + b` to + /// `Some(a1 + b1)` and `d` to `None`, meaning that it cannot be projected. + pub fn project_expr( + &self, + expr: &Arc, + projection_mapping: &ProjectionMapping, + ) -> Option> { + self.eq_group.project_expr(projection_mapping, expr) + } + + /// Constructs a dependency map based on existing orderings referred to in + /// the projection. + /// + /// This function analyzes the orderings in the normalized order-equivalence + /// class and builds a dependency map. The dependency map captures relationships + /// between expressions within the orderings, helping to identify dependencies + /// and construct valid projected orderings during projection operations. + /// + /// # Parameters + /// + /// - `mapping`: A reference to the `ProjectionMapping` that defines the + /// relationship between source and target expressions. + /// + /// # Returns + /// + /// A [`DependencyMap`] representing the dependency map, where each + /// [`DependencyNode`] contains dependencies for the key [`PhysicalSortExpr`]. + /// + /// # Example + /// + /// Assume we have two equivalent orderings: `[a ASC, b ASC]` and `[a ASC, c ASC]`, + /// and the projection mapping is `[a -> a_new, b -> b_new, b + c -> b + c]`. + /// Then, the dependency map will be: + /// + /// ```text + /// a ASC: Node {Some(a_new ASC), HashSet{}} + /// b ASC: Node {Some(b_new ASC), HashSet{a ASC}} + /// c ASC: Node {None, HashSet{a ASC}} + /// ``` + fn construct_dependency_map(&self, mapping: &ProjectionMapping) -> DependencyMap { + let mut dependency_map = HashMap::new(); + for ordering in self.normalized_oeq_class().iter() { + for (idx, sort_expr) in ordering.iter().enumerate() { + let target_sort_expr = + self.project_expr(&sort_expr.expr, mapping).map(|expr| { + PhysicalSortExpr { + expr, + options: sort_expr.options, + } + }); + let is_projected = target_sort_expr.is_some(); + if is_projected + || mapping + .iter() + .any(|(source, _)| expr_refers(source, &sort_expr.expr)) + { + // Previous ordering is a dependency. Note that there is no, + // dependency for a leading ordering (i.e. the first sort + // expression). + let dependency = idx.checked_sub(1).map(|a| &ordering[a]); + // Add sort expressions that can be projected or referred to + // by any of the projection expressions to the dependency map: + dependency_map + .entry(sort_expr.clone()) + .or_insert_with(|| DependencyNode { + target_sort_expr: target_sort_expr.clone(), + dependencies: HashSet::new(), + }) + .insert_dependency(dependency); + } + if !is_projected { + // If we can not project, stop constructing the dependency + // map as remaining dependencies will be invalid after projection. + break; + } + } + } + dependency_map + } + + /// Returns a new `ProjectionMapping` where source expressions are normalized. + /// + /// This normalization ensures that source expressions are transformed into a + /// consistent representation. This is beneficial for algorithms that rely on + /// exact equalities, as it allows for more precise and reliable comparisons. + /// + /// # Parameters + /// + /// - `mapping`: A reference to the original `ProjectionMapping` to be normalized. + /// + /// # Returns + /// + /// A new `ProjectionMapping` with normalized source expressions. + fn normalized_mapping(&self, mapping: &ProjectionMapping) -> ProjectionMapping { + // Construct the mapping where source expressions are normalized. In this way + // In the algorithms below we can work on exact equalities + ProjectionMapping { + map: mapping + .iter() + .map(|(source, target)| { + let normalized_source = self.eq_group.normalize_expr(source.clone()); + (normalized_source, target.clone()) + }) + .collect(), + } + } + + /// Computes projected orderings based on a given projection mapping. + /// + /// This function takes a `ProjectionMapping` and computes the possible + /// orderings for the projected expressions. It considers dependencies + /// between expressions and generates valid orderings according to the + /// specified sort properties. + /// + /// # Parameters + /// + /// - `mapping`: A reference to the `ProjectionMapping` that defines the + /// relationship between source and target expressions. + /// + /// # Returns + /// + /// A vector of `LexOrdering` containing all valid orderings after projection. + fn projected_orderings(&self, mapping: &ProjectionMapping) -> Vec { + let mapping = self.normalized_mapping(mapping); + + // Get dependency map for existing orderings: + let dependency_map = self.construct_dependency_map(&mapping); + + let orderings = mapping.iter().flat_map(|(source, target)| { + referred_dependencies(&dependency_map, source) + .into_iter() + .filter_map(|relevant_deps| { + if let SortProperties::Ordered(options) = + get_expr_ordering(source, &relevant_deps) + { + Some((options, relevant_deps)) + } else { + // Do not consider unordered cases + None + } + }) + .flat_map(|(options, relevant_deps)| { + let sort_expr = PhysicalSortExpr { + expr: target.clone(), + options, + }; + // Generate dependent orderings (i.e. prefixes for `sort_expr`): + let mut dependency_orderings = + generate_dependency_orderings(&relevant_deps, &dependency_map); + // Append `sort_expr` to the dependent orderings: + for ordering in dependency_orderings.iter_mut() { + ordering.push(sort_expr.clone()); + } + dependency_orderings + }) + }); + + // Add valid projected orderings. For example, if existing ordering is + // `a + b` and projection is `[a -> a_new, b -> b_new]`, we need to + // preserve `a_new + b_new` as ordered. Please note that `a_new` and + // `b_new` themselves need not be ordered. Such dependencies cannot be + // deduced via the pass above. + let projected_orderings = dependency_map.iter().flat_map(|(sort_expr, node)| { + let mut prefixes = construct_prefix_orderings(sort_expr, &dependency_map); + if prefixes.is_empty() { + // If prefix is empty, there is no dependency. Insert + // empty ordering: + prefixes = vec![vec![]]; + } + // Append current ordering on top its dependencies: + for ordering in prefixes.iter_mut() { + if let Some(target) = &node.target_sort_expr { + ordering.push(target.clone()) + } + } + prefixes + }); + + // Simplify each ordering by removing redundant sections: + orderings + .chain(projected_orderings) + .map(collapse_lex_ordering) + .collect() + } + + /// Projects constants based on the provided `ProjectionMapping`. + /// + /// This function takes a `ProjectionMapping` and identifies/projects + /// constants based on the existing constants and the mapping. It ensures + /// that constants are appropriately propagated through the projection. + /// + /// # Arguments + /// + /// - `mapping`: A reference to a `ProjectionMapping` representing the + /// mapping of source expressions to target expressions in the projection. + /// + /// # Returns + /// + /// Returns a `Vec>` containing the projected constants. + fn projected_constants( + &self, + mapping: &ProjectionMapping, + ) -> Vec> { + // First, project existing constants. For example, assume that `a + b` + // is known to be constant. If the projection were `a as a_new`, `b as b_new`, + // then we would project constant `a + b` as `a_new + b_new`. + let mut projected_constants = self + .constants + .iter() + .flat_map(|expr| self.eq_group.project_expr(mapping, expr)) + .collect::>(); + // Add projection expressions that are known to be constant: + for (source, target) in mapping.iter() { + if self.is_expr_constant(source) + && !physical_exprs_contains(&projected_constants, target) + { + projected_constants.push(target.clone()); + } + } + projected_constants + } + + /// Projects the equivalences within according to `projection_mapping` + /// and `output_schema`. + pub fn project( + &self, + projection_mapping: &ProjectionMapping, + output_schema: SchemaRef, + ) -> Self { + let projected_constants = self.projected_constants(projection_mapping); + let projected_eq_group = self.eq_group.project(projection_mapping); + let projected_orderings = self.projected_orderings(projection_mapping); + Self { + eq_group: projected_eq_group, + oeq_class: OrderingEquivalenceClass::new(projected_orderings), + constants: projected_constants, + schema: output_schema, + } + } + + /// Returns the longest (potentially partial) permutation satisfying the + /// existing ordering. For example, if we have the equivalent orderings + /// `[a ASC, b ASC]` and `[c DESC]`, with `exprs` containing `[c, b, a, d]`, + /// then this function returns `([a ASC, b ASC, c DESC], [2, 1, 0])`. + /// This means that the specification `[a ASC, b ASC, c DESC]` is satisfied + /// by the existing ordering, and `[a, b, c]` resides at indices: `2, 1, 0` + /// inside the argument `exprs` (respectively). For the mathematical + /// definition of "partial permutation", see: + /// + /// + pub fn find_longest_permutation( + &self, + exprs: &[Arc], + ) -> (LexOrdering, Vec) { + let mut eq_properties = self.clone(); + let mut result = vec![]; + // The algorithm is as follows: + // - Iterate over all the expressions and insert ordered expressions + // into the result. + // - Treat inserted expressions as constants (i.e. add them as constants + // to the state). + // - Continue the above procedure until no expression is inserted; i.e. + // the algorithm reaches a fixed point. + // This algorithm should reach a fixed point in at most `exprs.len()` + // iterations. + let mut search_indices = (0..exprs.len()).collect::>(); + for _idx in 0..exprs.len() { + // Get ordered expressions with their indices. + let ordered_exprs = search_indices + .iter() + .flat_map(|&idx| { + let ExprOrdering { expr, state, .. } = + eq_properties.get_expr_ordering(exprs[idx].clone()); + if let SortProperties::Ordered(options) = state { + Some((PhysicalSortExpr { expr, options }, idx)) + } else { + None + } + }) + .collect::>(); + // We reached a fixed point, exit. + if ordered_exprs.is_empty() { + break; + } + // Remove indices that have an ordering from `search_indices`, and + // treat ordered expressions as constants in subsequent iterations. + // We can do this because the "next" key only matters in a lexicographical + // ordering when the keys to its left have the same values. + // + // Note that these expressions are not properly "constants". This is just + // an implementation strategy confined to this function. + for (PhysicalSortExpr { expr, .. }, idx) in &ordered_exprs { + eq_properties = + eq_properties.add_constants(std::iter::once(expr.clone())); + search_indices.remove(idx); + } + // Add new ordered section to the state. + result.extend(ordered_exprs); + } + result.into_iter().unzip() + } + + /// This function determines whether the provided expression is constant + /// based on the known constants. + /// + /// # Arguments + /// + /// - `expr`: A reference to a `Arc` representing the + /// expression to be checked. + /// + /// # Returns + /// + /// Returns `true` if the expression is constant according to equivalence + /// group, `false` otherwise. + fn is_expr_constant(&self, expr: &Arc) -> bool { + // As an example, assume that we know columns `a` and `b` are constant. + // Then, `a`, `b` and `a + b` will all return `true` whereas `c` will + // return `false`. + let normalized_constants = self.eq_group.normalize_exprs(self.constants.to_vec()); + let normalized_expr = self.eq_group.normalize_expr(expr.clone()); + is_constant_recurse(&normalized_constants, &normalized_expr) + } + + /// Retrieves the ordering information for a given physical expression. + /// + /// This function constructs an `ExprOrdering` object for the provided + /// expression, which encapsulates information about the expression's + /// ordering, including its [`SortProperties`]. + /// + /// # Arguments + /// + /// - `expr`: An `Arc` representing the physical expression + /// for which ordering information is sought. + /// + /// # Returns + /// + /// Returns an `ExprOrdering` object containing the ordering information for + /// the given expression. + pub fn get_expr_ordering(&self, expr: Arc) -> ExprOrdering { + ExprOrdering::new(expr.clone()) + .transform_up(&|expr| Ok(update_ordering(expr, self))) + // Guaranteed to always return `Ok`. + .unwrap() + } +} + +/// Calculates the [`SortProperties`] of a given [`ExprOrdering`] node. +/// The node can either be a leaf node, or an intermediate node: +/// - If it is a leaf node, we directly find the order of the node by looking +/// at the given sort expression and equivalence properties if it is a `Column` +/// leaf, or we mark it as unordered. In the case of a `Literal` leaf, we mark +/// it as singleton so that it can cooperate with all ordered columns. +/// - If it is an intermediate node, the children states matter. Each `PhysicalExpr` +/// and operator has its own rules on how to propagate the children orderings. +/// However, before we engage in recursion, we check whether this intermediate +/// node directly matches with the sort expression. If there is a match, the +/// sort expression emerges at that node immediately, discarding the recursive +/// result coming from its children. +fn update_ordering( + mut node: ExprOrdering, + eq_properties: &EquivalenceProperties, +) -> Transformed { + // We have a Column, which is one of the two possible leaf node types: + let normalized_expr = eq_properties.eq_group.normalize_expr(node.expr.clone()); + if eq_properties.is_expr_constant(&normalized_expr) { + node.state = SortProperties::Singleton; + } else if let Some(options) = eq_properties + .normalized_oeq_class() + .get_options(&normalized_expr) + { + node.state = SortProperties::Ordered(options); + } else if !node.expr.children().is_empty() { + // We have an intermediate (non-leaf) node, account for its children: + node.state = node.expr.get_ordering(&node.children_state()); + } else if node.expr.as_any().is::() { + // We have a Literal, which is the other possible leaf node type: + node.state = node.expr.get_ordering(&[]); + } else { + return Transformed::No(node); + } + Transformed::Yes(node) +} + +/// This function determines whether the provided expression is constant +/// based on the known constants. +/// +/// # Arguments +/// +/// - `constants`: A `&[Arc]` containing expressions known to +/// be a constant. +/// - `expr`: A reference to a `Arc` representing the expression +/// to check. +/// +/// # Returns +/// +/// Returns `true` if the expression is constant according to equivalence +/// group, `false` otherwise. +fn is_constant_recurse( + constants: &[Arc], + expr: &Arc, +) -> bool { + if physical_exprs_contains(constants, expr) { + return true; + } + let children = expr.children(); + !children.is_empty() && children.iter().all(|c| is_constant_recurse(constants, c)) +} + +/// This function examines whether a referring expression directly refers to a +/// given referred expression or if any of its children in the expression tree +/// refer to the specified expression. +/// +/// # Parameters +/// +/// - `referring_expr`: A reference to the referring expression (`Arc`). +/// - `referred_expr`: A reference to the referred expression (`Arc`) +/// +/// # Returns +/// +/// A boolean value indicating whether `referring_expr` refers (needs it to evaluate its result) +/// `referred_expr` or not. +fn expr_refers( + referring_expr: &Arc, + referred_expr: &Arc, +) -> bool { + referring_expr.eq(referred_expr) + || referring_expr + .children() + .iter() + .any(|child| expr_refers(child, referred_expr)) +} + +/// This function analyzes the dependency map to collect referred dependencies for +/// a given source expression. +/// +/// # Parameters +/// +/// - `dependency_map`: A reference to the `DependencyMap` where each +/// `PhysicalSortExpr` is associated with a `DependencyNode`. +/// - `source`: A reference to the source expression (`Arc`) +/// for which relevant dependencies need to be identified. +/// +/// # Returns +/// +/// A `Vec` containing the dependencies for the given source +/// expression. These dependencies are expressions that are referred to by +/// the source expression based on the provided dependency map. +fn referred_dependencies( + dependency_map: &DependencyMap, + source: &Arc, +) -> Vec { + // Associate `PhysicalExpr`s with `PhysicalSortExpr`s that contain them: + let mut expr_to_sort_exprs = HashMap::::new(); + for sort_expr in dependency_map + .keys() + .filter(|sort_expr| expr_refers(source, &sort_expr.expr)) + { + let key = ExprWrapper(sort_expr.expr.clone()); + expr_to_sort_exprs + .entry(key) + .or_default() + .insert(sort_expr.clone()); + } + + // Generate all valid dependencies for the source. For example, if the source + // is `a + b` and the map is `[a -> (a ASC, a DESC), b -> (b ASC)]`, we get + // `vec![HashSet(a ASC, b ASC), HashSet(a DESC, b ASC)]`. + expr_to_sort_exprs + .values() + .multi_cartesian_product() + .map(|referred_deps| referred_deps.into_iter().cloned().collect()) + .collect() +} + +/// This function retrieves the dependencies of the given relevant sort expression +/// from the given dependency map. It then constructs prefix orderings by recursively +/// analyzing the dependencies and include them in the orderings. +/// +/// # Parameters +/// +/// - `relevant_sort_expr`: A reference to the relevant sort expression +/// (`PhysicalSortExpr`) for which prefix orderings are to be constructed. +/// - `dependency_map`: A reference to the `DependencyMap` containing dependencies. +/// +/// # Returns +/// +/// A vector of prefix orderings (`Vec`) based on the given relevant +/// sort expression and its dependencies. +fn construct_prefix_orderings( + relevant_sort_expr: &PhysicalSortExpr, + dependency_map: &DependencyMap, +) -> Vec { + dependency_map[relevant_sort_expr] + .dependencies + .iter() + .flat_map(|dep| construct_orderings(dep, dependency_map)) + .collect() +} + +/// Given a set of relevant dependencies (`relevant_deps`) and a map of dependencies +/// (`dependency_map`), this function generates all possible prefix orderings +/// based on the given dependencies. +/// +/// # Parameters +/// +/// * `dependencies` - A reference to the dependencies. +/// * `dependency_map` - A reference to the map of dependencies for expressions. +/// +/// # Returns +/// +/// A vector of lexical orderings (`Vec`) representing all valid orderings +/// based on the given dependencies. +fn generate_dependency_orderings( + dependencies: &Dependencies, + dependency_map: &DependencyMap, +) -> Vec { + // Construct all the valid prefix orderings for each expression appearing + // in the projection: + let relevant_prefixes = dependencies + .iter() + .flat_map(|dep| { + let prefixes = construct_prefix_orderings(dep, dependency_map); + (!prefixes.is_empty()).then_some(prefixes) + }) + .collect::>(); + + // No dependency, dependent is a leading ordering. + if relevant_prefixes.is_empty() { + // Return an empty ordering: + return vec![vec![]]; + } + + // Generate all possible orderings where dependencies are satisfied for the + // current projection expression. For example, if expression is `a + b ASC`, + // and the dependency for `a ASC` is `[c ASC]`, the dependency for `b ASC` + // is `[d DESC]`, then we generate `[c ASC, d DESC, a + b ASC]` and + // `[d DESC, c ASC, a + b ASC]`. + relevant_prefixes + .into_iter() + .multi_cartesian_product() + .flat_map(|prefix_orderings| { + prefix_orderings + .iter() + .permutations(prefix_orderings.len()) + .map(|prefixes| prefixes.into_iter().flatten().cloned().collect()) + .collect::>() + }) + .collect() +} + +/// This function examines the given expression and the sort expressions it +/// refers to determine the ordering properties of the expression. +/// +/// # Parameters +/// +/// - `expr`: A reference to the source expression (`Arc`) for +/// which ordering properties need to be determined. +/// - `dependencies`: A reference to `Dependencies`, containing sort expressions +/// referred to by `expr`. +/// +/// # Returns +/// +/// A `SortProperties` indicating the ordering information of the given expression. +fn get_expr_ordering( + expr: &Arc, + dependencies: &Dependencies, +) -> SortProperties { + if let Some(column_order) = dependencies.iter().find(|&order| expr.eq(&order.expr)) { + // If exact match is found, return its ordering. + SortProperties::Ordered(column_order.options) + } else { + // Find orderings of its children + let child_states = expr + .children() + .iter() + .map(|child| get_expr_ordering(child, dependencies)) + .collect::>(); + // Calculate expression ordering using ordering of its children. + expr.get_ordering(&child_states) + } +} + +/// Represents a node in the dependency map used to construct projected orderings. +/// +/// A `DependencyNode` contains information about a particular sort expression, +/// including its target sort expression and a set of dependencies on other sort +/// expressions. +/// +/// # Fields +/// +/// - `target_sort_expr`: An optional `PhysicalSortExpr` representing the target +/// sort expression associated with the node. It is `None` if the sort expression +/// cannot be projected. +/// - `dependencies`: A [`Dependencies`] containing dependencies on other sort +/// expressions that are referred to by the target sort expression. +#[derive(Debug, Clone, PartialEq, Eq)] +struct DependencyNode { + target_sort_expr: Option, + dependencies: Dependencies, +} + +impl DependencyNode { + // Insert dependency to the state (if exists). + fn insert_dependency(&mut self, dependency: Option<&PhysicalSortExpr>) { + if let Some(dep) = dependency { + self.dependencies.insert(dep.clone()); + } + } +} + +type DependencyMap = HashMap; +type Dependencies = HashSet; + +/// This function recursively analyzes the dependencies of the given sort +/// expression within the given dependency map to construct lexicographical +/// orderings that include the sort expression and its dependencies. +/// +/// # Parameters +/// +/// - `referred_sort_expr`: A reference to the sort expression (`PhysicalSortExpr`) +/// for which lexicographical orderings satisfying its dependencies are to be +/// constructed. +/// - `dependency_map`: A reference to the `DependencyMap` that contains +/// dependencies for different `PhysicalSortExpr`s. +/// +/// # Returns +/// +/// A vector of lexicographical orderings (`Vec`) based on the given +/// sort expression and its dependencies. +fn construct_orderings( + referred_sort_expr: &PhysicalSortExpr, + dependency_map: &DependencyMap, +) -> Vec { + // We are sure that `referred_sort_expr` is inside `dependency_map`. + let node = &dependency_map[referred_sort_expr]; + // Since we work on intermediate nodes, we are sure `val.target_sort_expr` + // exists. + let target_sort_expr = node.target_sort_expr.clone().unwrap(); + if node.dependencies.is_empty() { + vec![vec![target_sort_expr]] + } else { + node.dependencies + .iter() + .flat_map(|dep| { + let mut orderings = construct_orderings(dep, dependency_map); + for ordering in orderings.iter_mut() { + ordering.push(target_sort_expr.clone()) + } + orderings + }) + .collect() + } +} + +/// Calculate ordering equivalence properties for the given join operation. +pub fn join_equivalence_properties( + left: EquivalenceProperties, + right: EquivalenceProperties, + join_type: &JoinType, + join_schema: SchemaRef, + maintains_input_order: &[bool], + probe_side: Option, + on: &[(Column, Column)], +) -> EquivalenceProperties { + let left_size = left.schema.fields.len(); + let mut result = EquivalenceProperties::new(join_schema); + result.add_equivalence_group(left.eq_group().join( + right.eq_group(), + join_type, + left_size, + on, + )); + + let left_oeq_class = left.oeq_class; + let mut right_oeq_class = right.oeq_class; + match maintains_input_order { + [true, false] => { + // In this special case, right side ordering can be prefixed with + // the left side ordering. + if let (Some(JoinSide::Left), JoinType::Inner) = (probe_side, join_type) { + updated_right_ordering_equivalence_class( + &mut right_oeq_class, + join_type, + left_size, + ); + + // Right side ordering equivalence properties should be prepended + // with those of the left side while constructing output ordering + // equivalence properties since stream side is the left side. + // + // For example, if the right side ordering equivalences contain + // `b ASC`, and the left side ordering equivalences contain `a ASC`, + // then we should add `a ASC, b ASC` to the ordering equivalences + // of the join output. + let out_oeq_class = left_oeq_class.join_suffix(&right_oeq_class); + result.add_ordering_equivalence_class(out_oeq_class); + } else { + result.add_ordering_equivalence_class(left_oeq_class); + } + } + [false, true] => { + updated_right_ordering_equivalence_class( + &mut right_oeq_class, + join_type, + left_size, + ); + // In this special case, left side ordering can be prefixed with + // the right side ordering. + if let (Some(JoinSide::Right), JoinType::Inner) = (probe_side, join_type) { + // Left side ordering equivalence properties should be prepended + // with those of the right side while constructing output ordering + // equivalence properties since stream side is the right side. + // + // For example, if the left side ordering equivalences contain + // `a ASC`, and the right side ordering equivalences contain `b ASC`, + // then we should add `b ASC, a ASC` to the ordering equivalences + // of the join output. + let out_oeq_class = right_oeq_class.join_suffix(&left_oeq_class); + result.add_ordering_equivalence_class(out_oeq_class); + } else { + result.add_ordering_equivalence_class(right_oeq_class); + } + } + [false, false] => {} + [true, true] => unreachable!("Cannot maintain ordering of both sides"), + _ => unreachable!("Join operators can not have more than two children"), + } + result +} + +/// In the context of a join, update the right side `OrderingEquivalenceClass` +/// so that they point to valid indices in the join output schema. +/// +/// To do so, we increment column indices by the size of the left table when +/// join schema consists of a combination of the left and right schemas. This +/// is the case for `Inner`, `Left`, `Full` and `Right` joins. For other cases, +/// indices do not change. +fn updated_right_ordering_equivalence_class( + right_oeq_class: &mut OrderingEquivalenceClass, + join_type: &JoinType, + left_size: usize, +) { + if matches!( + join_type, + JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right + ) { + right_oeq_class.add_offset(left_size); + } +} + +/// Wrapper struct for `Arc` to use them as keys in a hash map. +#[derive(Debug, Clone)] +struct ExprWrapper(Arc); + +impl PartialEq for ExprWrapper { + fn eq(&self, other: &Self) -> bool { + self.0.eq(&other.0) + } +} + +impl Eq for ExprWrapper {} + +impl Hash for ExprWrapper { + fn hash(&self, state: &mut H) { + self.0.hash(state); + } +} + +#[cfg(test)] +mod tests { + use std::ops::Not; + use std::sync::Arc; + + use super::*; + use crate::equivalence::add_offset_to_expr; + use crate::equivalence::tests::{ + convert_to_orderings, convert_to_sort_exprs, convert_to_sort_reqs, + create_random_schema, create_test_params, create_test_schema, + generate_table_for_eq_properties, is_table_same_after_sort, output_schema, + }; + use crate::execution_props::ExecutionProps; + use crate::expressions::{col, BinaryExpr, Column}; + use crate::functions::create_physical_expr; + use crate::PhysicalSortExpr; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_schema::{Fields, SortOptions, TimeUnit}; + use datafusion_common::Result; + use datafusion_expr::{BuiltinScalarFunction, Operator}; + use itertools::Itertools; + + #[test] + fn project_equivalence_properties_test() -> Result<()> { + let input_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Int64, true), + Field::new("c", DataType::Int64, true), + ])); + + let input_properties = EquivalenceProperties::new(input_schema.clone()); + let col_a = col("a", &input_schema)?; + + // a as a1, a as a2, a as a3, a as a3 + let proj_exprs = vec![ + (col_a.clone(), "a1".to_string()), + (col_a.clone(), "a2".to_string()), + (col_a.clone(), "a3".to_string()), + (col_a.clone(), "a4".to_string()), + ]; + let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; + + let out_schema = output_schema(&projection_mapping, &input_schema)?; + // a as a1, a as a2, a as a3, a as a3 + let proj_exprs = vec![ + (col_a.clone(), "a1".to_string()), + (col_a.clone(), "a2".to_string()), + (col_a.clone(), "a3".to_string()), + (col_a.clone(), "a4".to_string()), + ]; + let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; + + // a as a1, a as a2, a as a3, a as a3 + let col_a1 = &col("a1", &out_schema)?; + let col_a2 = &col("a2", &out_schema)?; + let col_a3 = &col("a3", &out_schema)?; + let col_a4 = &col("a4", &out_schema)?; + let out_properties = input_properties.project(&projection_mapping, out_schema); + + // At the output a1=a2=a3=a4 + assert_eq!(out_properties.eq_group().len(), 1); + let eq_class = &out_properties.eq_group().classes[0]; + assert_eq!(eq_class.len(), 4); + assert!(eq_class.contains(col_a1)); + assert!(eq_class.contains(col_a2)); + assert!(eq_class.contains(col_a3)); + assert!(eq_class.contains(col_a4)); + + Ok(()) + } + + #[test] + fn test_join_equivalence_properties() -> Result<()> { + let schema = create_test_schema()?; + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let offset = schema.fields.len(); + let col_a2 = &add_offset_to_expr(col_a.clone(), offset); + let col_b2 = &add_offset_to_expr(col_b.clone(), offset); + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let test_cases = vec![ + // ------- TEST CASE 1 -------- + // [a ASC], [b ASC] + ( + // [a ASC], [b ASC] + vec![vec![(col_a, option_asc)], vec![(col_b, option_asc)]], + // [a ASC], [b ASC] + vec![vec![(col_a, option_asc)], vec![(col_b, option_asc)]], + // expected [a ASC, a2 ASC], [a ASC, b2 ASC], [b ASC, a2 ASC], [b ASC, b2 ASC] + vec![ + vec![(col_a, option_asc), (col_a2, option_asc)], + vec![(col_a, option_asc), (col_b2, option_asc)], + vec![(col_b, option_asc), (col_a2, option_asc)], + vec![(col_b, option_asc), (col_b2, option_asc)], + ], + ), + // ------- TEST CASE 2 -------- + // [a ASC], [b ASC] + ( + // [a ASC], [b ASC], [c ASC] + vec![ + vec![(col_a, option_asc)], + vec![(col_b, option_asc)], + vec![(col_c, option_asc)], + ], + // [a ASC], [b ASC] + vec![vec![(col_a, option_asc)], vec![(col_b, option_asc)]], + // expected [a ASC, a2 ASC], [a ASC, b2 ASC], [b ASC, a2 ASC], [b ASC, b2 ASC], [c ASC, a2 ASC], [c ASC, b2 ASC] + vec![ + vec![(col_a, option_asc), (col_a2, option_asc)], + vec![(col_a, option_asc), (col_b2, option_asc)], + vec![(col_b, option_asc), (col_a2, option_asc)], + vec![(col_b, option_asc), (col_b2, option_asc)], + vec![(col_c, option_asc), (col_a2, option_asc)], + vec![(col_c, option_asc), (col_b2, option_asc)], + ], + ), + ]; + for (left_orderings, right_orderings, expected) in test_cases { + let mut left_eq_properties = EquivalenceProperties::new(schema.clone()); + let mut right_eq_properties = EquivalenceProperties::new(schema.clone()); + let left_orderings = convert_to_orderings(&left_orderings); + let right_orderings = convert_to_orderings(&right_orderings); + let expected = convert_to_orderings(&expected); + left_eq_properties.add_new_orderings(left_orderings); + right_eq_properties.add_new_orderings(right_orderings); + let join_eq = join_equivalence_properties( + left_eq_properties, + right_eq_properties, + &JoinType::Inner, + Arc::new(Schema::empty()), + &[true, false], + Some(JoinSide::Left), + &[], + ); + let orderings = &join_eq.oeq_class.orderings; + let err_msg = format!("expected: {:?}, actual:{:?}", expected, orderings); + assert_eq!( + join_eq.oeq_class.orderings.len(), + expected.len(), + "{}", + err_msg + ); + for ordering in orderings { + assert!( + expected.contains(ordering), + "{}, ordering: {:?}", + err_msg, + ordering + ); + } + } + Ok(()) + } + + #[test] + fn test_expr_consists_of_constants() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true), + ])); + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let col_d = col("d", &schema)?; + let b_plus_d = Arc::new(BinaryExpr::new( + col_b.clone(), + Operator::Plus, + col_d.clone(), + )) as Arc; + + let constants = vec![col_a.clone(), col_b.clone()]; + let expr = b_plus_d.clone(); + assert!(!is_constant_recurse(&constants, &expr)); + + let constants = vec![col_a.clone(), col_b.clone(), col_d.clone()]; + let expr = b_plus_d.clone(); + assert!(is_constant_recurse(&constants, &expr)); + Ok(()) + } + + #[test] + fn test_get_updated_right_ordering_equivalence_properties() -> Result<()> { + let join_type = JoinType::Inner; + // Join right child schema + let child_fields: Fields = ["x", "y", "z", "w"] + .into_iter() + .map(|name| Field::new(name, DataType::Int32, true)) + .collect(); + let child_schema = Schema::new(child_fields); + let col_x = &col("x", &child_schema)?; + let col_y = &col("y", &child_schema)?; + let col_z = &col("z", &child_schema)?; + let col_w = &col("w", &child_schema)?; + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + // [x ASC, y ASC], [z ASC, w ASC] + let orderings = vec![ + vec![(col_x, option_asc), (col_y, option_asc)], + vec![(col_z, option_asc), (col_w, option_asc)], + ]; + let orderings = convert_to_orderings(&orderings); + // Right child ordering equivalences + let mut right_oeq_class = OrderingEquivalenceClass::new(orderings); + + let left_columns_len = 4; + + let fields: Fields = ["a", "b", "c", "d", "x", "y", "z", "w"] + .into_iter() + .map(|name| Field::new(name, DataType::Int32, true)) + .collect(); + + // Join Schema + let schema = Schema::new(fields); + let col_a = &col("a", &schema)?; + let col_d = &col("d", &schema)?; + let col_x = &col("x", &schema)?; + let col_y = &col("y", &schema)?; + let col_z = &col("z", &schema)?; + let col_w = &col("w", &schema)?; + + let mut join_eq_properties = EquivalenceProperties::new(Arc::new(schema)); + // a=x and d=w + join_eq_properties.add_equal_conditions(col_a, col_x); + join_eq_properties.add_equal_conditions(col_d, col_w); + + updated_right_ordering_equivalence_class( + &mut right_oeq_class, + &join_type, + left_columns_len, + ); + join_eq_properties.add_ordering_equivalence_class(right_oeq_class); + let result = join_eq_properties.oeq_class().clone(); + + // [x ASC, y ASC], [z ASC, w ASC] + let orderings = vec![ + vec![(col_x, option_asc), (col_y, option_asc)], + vec![(col_z, option_asc), (col_w, option_asc)], + ]; + let orderings = convert_to_orderings(&orderings); + let expected = OrderingEquivalenceClass::new(orderings); + + assert_eq!(result, expected); + + Ok(()) + } + + #[test] + fn test_normalize_ordering_equivalence_classes() -> Result<()> { + let sort_options = SortOptions::default(); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + ]); + let col_a_expr = col("a", &schema)?; + let col_b_expr = col("b", &schema)?; + let col_c_expr = col("c", &schema)?; + let mut eq_properties = EquivalenceProperties::new(Arc::new(schema.clone())); + + eq_properties.add_equal_conditions(&col_a_expr, &col_c_expr); + let others = vec![ + vec![PhysicalSortExpr { + expr: col_b_expr.clone(), + options: sort_options, + }], + vec![PhysicalSortExpr { + expr: col_c_expr.clone(), + options: sort_options, + }], + ]; + eq_properties.add_new_orderings(others); + + let mut expected_eqs = EquivalenceProperties::new(Arc::new(schema)); + expected_eqs.add_new_orderings([ + vec![PhysicalSortExpr { + expr: col_b_expr.clone(), + options: sort_options, + }], + vec![PhysicalSortExpr { + expr: col_c_expr.clone(), + options: sort_options, + }], + ]); + + let oeq_class = eq_properties.oeq_class().clone(); + let expected = expected_eqs.oeq_class(); + assert!(oeq_class.eq(expected)); + + Ok(()) + } + + #[test] + fn test_get_indices_of_matching_sort_exprs_with_order_eq() -> Result<()> { + let sort_options = SortOptions::default(); + let sort_options_not = SortOptions::default().not(); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ]); + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let required_columns = [col_b.clone(), col_a.clone()]; + let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); + eq_properties.add_new_orderings([vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("b", 1)), + options: sort_options_not, + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: sort_options, + }, + ]]); + let (result, idxs) = eq_properties.find_longest_permutation(&required_columns); + assert_eq!(idxs, vec![0, 1]); + assert_eq!( + result, + vec![ + PhysicalSortExpr { + expr: col_b.clone(), + options: sort_options_not + }, + PhysicalSortExpr { + expr: col_a.clone(), + options: sort_options + } + ] + ); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + ]); + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let required_columns = [col_b.clone(), col_a.clone()]; + let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); + eq_properties.add_new_orderings([ + vec![PhysicalSortExpr { + expr: Arc::new(Column::new("c", 2)), + options: sort_options, + }], + vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("b", 1)), + options: sort_options_not, + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: sort_options, + }, + ], + ]); + let (result, idxs) = eq_properties.find_longest_permutation(&required_columns); + assert_eq!(idxs, vec![0, 1]); + assert_eq!( + result, + vec![ + PhysicalSortExpr { + expr: col_b.clone(), + options: sort_options_not + }, + PhysicalSortExpr { + expr: col_a.clone(), + options: sort_options + } + ] + ); + + let required_columns = [ + Arc::new(Column::new("b", 1)) as _, + Arc::new(Column::new("a", 0)) as _, + ]; + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + ]); + let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); + + // not satisfied orders + eq_properties.add_new_orderings([vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("b", 1)), + options: sort_options_not, + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("c", 2)), + options: sort_options, + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: sort_options, + }, + ]]); + let (_, idxs) = eq_properties.find_longest_permutation(&required_columns); + assert_eq!(idxs, vec![0]); + + Ok(()) + } + + #[test] + fn test_update_ordering() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + ]); + + let mut eq_properties = EquivalenceProperties::new(Arc::new(schema.clone())); + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let col_d = &col("d", &schema)?; + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + // b=a (e.g they are aliases) + eq_properties.add_equal_conditions(col_b, col_a); + // [b ASC], [d ASC] + eq_properties.add_new_orderings(vec![ + vec![PhysicalSortExpr { + expr: col_b.clone(), + options: option_asc, + }], + vec![PhysicalSortExpr { + expr: col_d.clone(), + options: option_asc, + }], + ]); + + let test_cases = vec![ + // d + b + ( + Arc::new(BinaryExpr::new( + col_d.clone(), + Operator::Plus, + col_b.clone(), + )) as Arc, + SortProperties::Ordered(option_asc), + ), + // b + (col_b.clone(), SortProperties::Ordered(option_asc)), + // a + (col_a.clone(), SortProperties::Ordered(option_asc)), + // a + c + ( + Arc::new(BinaryExpr::new( + col_a.clone(), + Operator::Plus, + col_c.clone(), + )), + SortProperties::Unordered, + ), + ]; + for (expr, expected) in test_cases { + let leading_orderings = eq_properties + .oeq_class() + .iter() + .flat_map(|ordering| ordering.first().cloned()) + .collect::>(); + let expr_ordering = eq_properties.get_expr_ordering(expr.clone()); + let err_msg = format!( + "expr:{:?}, expected: {:?}, actual: {:?}, leading_orderings: {leading_orderings:?}", + expr, expected, expr_ordering.state + ); + assert_eq!(expr_ordering.state, expected, "{}", err_msg); + } + + Ok(()) + } + + #[test] + fn test_find_longest_permutation_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 100; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + + let floor_a = create_physical_expr( + &BuiltinScalarFunction::Floor, + &[col("a", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc; + let exprs = vec![ + col("a", &test_schema)?, + col("b", &test_schema)?, + col("c", &test_schema)?, + col("d", &test_schema)?, + col("e", &test_schema)?, + col("f", &test_schema)?, + floor_a, + a_plus_b, + ]; + + for n_req in 0..=exprs.len() { + for exprs in exprs.iter().combinations(n_req) { + let exprs = exprs.into_iter().cloned().collect::>(); + let (ordering, indices) = + eq_properties.find_longest_permutation(&exprs); + // Make sure that find_longest_permutation return values are consistent + let ordering2 = indices + .iter() + .zip(ordering.iter()) + .map(|(&idx, sort_expr)| PhysicalSortExpr { + expr: exprs[idx].clone(), + options: sort_expr.options, + }) + .collect::>(); + assert_eq!( + ordering, ordering2, + "indices and lexicographical ordering do not match" + ); + + let err_msg = format!( + "Error in test case ordering:{:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", + ordering, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants + ); + assert_eq!(ordering.len(), indices.len(), "{}", err_msg); + // Since ordered section satisfies schema, we expect + // that result will be same after sort (e.g sort was unnecessary). + assert!( + is_table_same_after_sort( + ordering.clone(), + table_data_with_properties.clone(), + )?, + "{}", + err_msg + ); + } + } + } + + Ok(()) + } + #[test] + fn test_find_longest_permutation() -> Result<()> { + // Schema satisfies following orderings: + // [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] + // and + // Column [a=c] (e.g they are aliases). + // At below we add [d ASC, h DESC] also, for test purposes + let (test_schema, mut eq_properties) = create_test_params()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_h = &col("h", &test_schema)?; + // a + d + let a_plus_d = Arc::new(BinaryExpr::new( + col_a.clone(), + Operator::Plus, + col_d.clone(), + )) as Arc; + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + // [d ASC, h ASC] also satisfies schema. + eq_properties.add_new_orderings([vec![ + PhysicalSortExpr { + expr: col_d.clone(), + options: option_asc, + }, + PhysicalSortExpr { + expr: col_h.clone(), + options: option_desc, + }, + ]]); + let test_cases = vec![ + // TEST CASE 1 + (vec![col_a], vec![(col_a, option_asc)]), + // TEST CASE 2 + (vec![col_c], vec![(col_c, option_asc)]), + // TEST CASE 3 + ( + vec![col_d, col_e, col_b], + vec![ + (col_d, option_asc), + (col_e, option_desc), + (col_b, option_asc), + ], + ), + // TEST CASE 4 + (vec![col_b], vec![]), + // TEST CASE 5 + (vec![col_d], vec![(col_d, option_asc)]), + // TEST CASE 5 + (vec![&a_plus_d], vec![(&a_plus_d, option_asc)]), + // TEST CASE 6 + ( + vec![col_b, col_d], + vec![(col_d, option_asc), (col_b, option_asc)], + ), + // TEST CASE 6 + ( + vec![col_c, col_e], + vec![(col_c, option_asc), (col_e, option_desc)], + ), + ]; + for (exprs, expected) in test_cases { + let exprs = exprs.into_iter().cloned().collect::>(); + let expected = convert_to_sort_exprs(&expected); + let (actual, _) = eq_properties.find_longest_permutation(&exprs); + assert_eq!(actual, expected); + } + + Ok(()) + } + #[test] + fn test_get_meet_ordering() -> Result<()> { + let schema = create_test_schema()?; + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let eq_properties = EquivalenceProperties::new(schema); + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + let tests_cases = vec![ + // Get meet ordering between [a ASC] and [a ASC, b ASC] + // result should be [a ASC] + ( + vec![(col_a, option_asc)], + vec![(col_a, option_asc), (col_b, option_asc)], + Some(vec![(col_a, option_asc)]), + ), + // Get meet ordering between [a ASC] and [a DESC] + // result should be None. + (vec![(col_a, option_asc)], vec![(col_a, option_desc)], None), + // Get meet ordering between [a ASC, b ASC] and [a ASC, b DESC] + // result should be [a ASC]. + ( + vec![(col_a, option_asc), (col_b, option_asc)], + vec![(col_a, option_asc), (col_b, option_desc)], + Some(vec![(col_a, option_asc)]), + ), + ]; + for (lhs, rhs, expected) in tests_cases { + let lhs = convert_to_sort_exprs(&lhs); + let rhs = convert_to_sort_exprs(&rhs); + let expected = expected.map(|expected| convert_to_sort_exprs(&expected)); + let finer = eq_properties.get_meet_ordering(&lhs, &rhs); + assert_eq!(finer, expected) + } + + Ok(()) + } + + #[test] + fn test_get_finer() -> Result<()> { + let schema = create_test_schema()?; + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let eq_properties = EquivalenceProperties::new(schema); + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + // First entry, and second entry are the physical sort requirement that are argument for get_finer_requirement. + // Third entry is the expected result. + let tests_cases = vec![ + // Get finer requirement between [a Some(ASC)] and [a None, b Some(ASC)] + // result should be [a Some(ASC), b Some(ASC)] + ( + vec![(col_a, Some(option_asc))], + vec![(col_a, None), (col_b, Some(option_asc))], + Some(vec![(col_a, Some(option_asc)), (col_b, Some(option_asc))]), + ), + // Get finer requirement between [a Some(ASC), b Some(ASC), c Some(ASC)] and [a Some(ASC), b Some(ASC)] + // result should be [a Some(ASC), b Some(ASC), c Some(ASC)] + ( + vec![ + (col_a, Some(option_asc)), + (col_b, Some(option_asc)), + (col_c, Some(option_asc)), + ], + vec![(col_a, Some(option_asc)), (col_b, Some(option_asc))], + Some(vec![ + (col_a, Some(option_asc)), + (col_b, Some(option_asc)), + (col_c, Some(option_asc)), + ]), + ), + // Get finer requirement between [a Some(ASC), b Some(ASC)] and [a Some(ASC), b Some(DESC)] + // result should be None + ( + vec![(col_a, Some(option_asc)), (col_b, Some(option_asc))], + vec![(col_a, Some(option_asc)), (col_b, Some(option_desc))], + None, + ), + ]; + for (lhs, rhs, expected) in tests_cases { + let lhs = convert_to_sort_reqs(&lhs); + let rhs = convert_to_sort_reqs(&rhs); + let expected = expected.map(|expected| convert_to_sort_reqs(&expected)); + let finer = eq_properties.get_finer_requirement(&lhs, &rhs); + assert_eq!(finer, expected) + } + + Ok(()) + } + + #[test] + fn test_normalize_sort_reqs() -> Result<()> { + // Schema satisfies following properties + // a=c + // and following orderings are valid + // [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] + let (test_schema, eq_properties) = create_test_params()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + // First element in the tuple stores vector of requirement, second element is the expected return value for ordering_satisfy function + let requirements = vec![ + ( + vec![(col_a, Some(option_asc))], + vec![(col_a, Some(option_asc))], + ), + ( + vec![(col_a, Some(option_desc))], + vec![(col_a, Some(option_desc))], + ), + (vec![(col_a, None)], vec![(col_a, None)]), + // Test whether equivalence works as expected + ( + vec![(col_c, Some(option_asc))], + vec![(col_a, Some(option_asc))], + ), + (vec![(col_c, None)], vec![(col_a, None)]), + // Test whether ordering equivalence works as expected + ( + vec![(col_d, Some(option_asc)), (col_b, Some(option_asc))], + vec![(col_d, Some(option_asc)), (col_b, Some(option_asc))], + ), + ( + vec![(col_d, None), (col_b, None)], + vec![(col_d, None), (col_b, None)], + ), + ( + vec![(col_e, Some(option_desc)), (col_f, Some(option_asc))], + vec![(col_e, Some(option_desc)), (col_f, Some(option_asc))], + ), + // We should be able to normalize in compatible requirements also (not exactly equal) + ( + vec![(col_e, Some(option_desc)), (col_f, None)], + vec![(col_e, Some(option_desc)), (col_f, None)], + ), + ( + vec![(col_e, None), (col_f, None)], + vec![(col_e, None), (col_f, None)], + ), + ]; + + for (reqs, expected_normalized) in requirements.into_iter() { + let req = convert_to_sort_reqs(&reqs); + let expected_normalized = convert_to_sort_reqs(&expected_normalized); + + assert_eq!( + eq_properties.normalize_sort_requirements(&req), + expected_normalized + ); + } + + Ok(()) + } + + #[test] + fn test_schema_normalize_sort_requirement_with_equivalence() -> Result<()> { + let option1 = SortOptions { + descending: false, + nulls_first: false, + }; + // Assume that column a and c are aliases. + let (test_schema, eq_properties) = create_test_params()?; + let col_a = &col("a", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + + // Test cases for equivalence normalization + // First entry in the tuple is PhysicalSortRequirement, second entry in the tuple is + // expected PhysicalSortRequirement after normalization. + let test_cases = vec![ + (vec![(col_a, Some(option1))], vec![(col_a, Some(option1))]), + // In the normalized version column c should be replace with column a + (vec![(col_c, Some(option1))], vec![(col_a, Some(option1))]), + (vec![(col_c, None)], vec![(col_a, None)]), + (vec![(col_d, Some(option1))], vec![(col_d, Some(option1))]), + ]; + for (reqs, expected) in test_cases.into_iter() { + let reqs = convert_to_sort_reqs(&reqs); + let expected = convert_to_sort_reqs(&expected); + + let normalized = eq_properties.normalize_sort_requirements(&reqs); + assert!( + expected.eq(&normalized), + "error in test: reqs: {reqs:?}, expected: {expected:?}, normalized: {normalized:?}" + ); + } + + Ok(()) + } +} diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index f75c2f951f56..8c4078dbce8c 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -20,11 +20,9 @@ mod kernels; use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; -use crate::array_expressions::{ - array_append, array_concat, array_has_all, array_prepend, -}; +use crate::array_expressions::array_has_all; +use crate::expressions::datum::{apply, apply_cmp}; use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison}; -use crate::intervals::{apply_operator, Interval}; use crate::physical_expr::down_cast_any_ref; use crate::sort_properties::SortProperties; use crate::PhysicalExpr; @@ -38,12 +36,13 @@ use arrow::compute::kernels::comparison::regexp_is_match_utf8_scalar; use arrow::compute::kernels::concat_elements::concat_elements_utf8; use arrow::datatypes::*; use arrow::record_batch::RecordBatch; + use datafusion_common::cast::as_boolean_array; use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::interval_arithmetic::{apply_operator, Interval}; use datafusion_expr::type_coercion::binary::get_result_type; use datafusion_expr::{ColumnarValue, Operator}; -use crate::expressions::datum::{apply, apply_cmp}; use kernels::{ bitwise_and_dyn, bitwise_and_dyn_scalar, bitwise_or_dyn, bitwise_or_dyn_scalar, bitwise_shift_left_dyn, bitwise_shift_left_dyn_scalar, bitwise_shift_right_dyn, @@ -299,16 +298,16 @@ impl PhysicalExpr for BinaryExpr { }; if let Some(result) = scalar_result { - return result.map(|a| ColumnarValue::Array(a)); + return result.map(ColumnarValue::Array); } // if both arrays or both literals - extract arrays and continue execution let (left, right) = ( - lhs.into_array(batch.num_rows()), - rhs.into_array(batch.num_rows()), + lhs.into_array(batch.num_rows())?, + rhs.into_array(batch.num_rows())?, ); self.evaluate_with_resolved_args(left, &left_data_type, right, &right_data_type) - .map(|a| ColumnarValue::Array(a)) + .map(ColumnarValue::Array) } fn children(&self) -> Vec> { @@ -338,32 +337,102 @@ impl PhysicalExpr for BinaryExpr { &self, interval: &Interval, children: &[&Interval], - ) -> Result>> { + ) -> Result>> { // Get children intervals. let left_interval = children[0]; let right_interval = children[1]; - let (left, right) = if self.op.is_logic_operator() { - // TODO: Currently, this implementation only supports the AND operator - // and does not require any further propagation. In the future, - // upon adding support for additional logical operators, this - // method will require modification to support propagating the - // changes accordingly. - return Ok(vec![]); - } else if self.op.is_comparison_operator() { - if interval == &Interval::CERTAINLY_FALSE { - // TODO: We will handle strictly false clauses by negating - // the comparison operator (e.g. GT to LE, LT to GE) - // once open/closed intervals are supported. - return Ok(vec![]); + if self.op.eq(&Operator::And) { + if interval.eq(&Interval::CERTAINLY_TRUE) { + // A certainly true logical conjunction can only derive from possibly + // true operands. Otherwise, we prove infeasability. + Ok((!left_interval.eq(&Interval::CERTAINLY_FALSE) + && !right_interval.eq(&Interval::CERTAINLY_FALSE)) + .then(|| vec![Interval::CERTAINLY_TRUE, Interval::CERTAINLY_TRUE])) + } else if interval.eq(&Interval::CERTAINLY_FALSE) { + // If the logical conjunction is certainly false, one of the + // operands must be false. However, it's not always possible to + // determine which operand is false, leading to different scenarios. + + // If one operand is certainly true and the other one is uncertain, + // then the latter must be certainly false. + if left_interval.eq(&Interval::CERTAINLY_TRUE) + && right_interval.eq(&Interval::UNCERTAIN) + { + Ok(Some(vec![ + Interval::CERTAINLY_TRUE, + Interval::CERTAINLY_FALSE, + ])) + } else if right_interval.eq(&Interval::CERTAINLY_TRUE) + && left_interval.eq(&Interval::UNCERTAIN) + { + Ok(Some(vec![ + Interval::CERTAINLY_FALSE, + Interval::CERTAINLY_TRUE, + ])) + } + // If both children are uncertain, or if one is certainly false, + // we cannot conclusively refine their intervals. In this case, + // propagation does not result in any interval changes. + else { + Ok(Some(vec![])) + } + } else { + // An uncertain logical conjunction result can not shrink the + // end-points of its children. + Ok(Some(vec![])) } - // Propagate the comparison operator. - propagate_comparison(&self.op, left_interval, right_interval)? + } else if self.op.eq(&Operator::Or) { + if interval.eq(&Interval::CERTAINLY_FALSE) { + // A certainly false logical conjunction can only derive from certainly + // false operands. Otherwise, we prove infeasability. + Ok((!left_interval.eq(&Interval::CERTAINLY_TRUE) + && !right_interval.eq(&Interval::CERTAINLY_TRUE)) + .then(|| vec![Interval::CERTAINLY_FALSE, Interval::CERTAINLY_FALSE])) + } else if interval.eq(&Interval::CERTAINLY_TRUE) { + // If the logical disjunction is certainly true, one of the + // operands must be true. However, it's not always possible to + // determine which operand is true, leading to different scenarios. + + // If one operand is certainly false and the other one is uncertain, + // then the latter must be certainly true. + if left_interval.eq(&Interval::CERTAINLY_FALSE) + && right_interval.eq(&Interval::UNCERTAIN) + { + Ok(Some(vec![ + Interval::CERTAINLY_FALSE, + Interval::CERTAINLY_TRUE, + ])) + } else if right_interval.eq(&Interval::CERTAINLY_FALSE) + && left_interval.eq(&Interval::UNCERTAIN) + { + Ok(Some(vec![ + Interval::CERTAINLY_TRUE, + Interval::CERTAINLY_FALSE, + ])) + } + // If both children are uncertain, or if one is certainly true, + // we cannot conclusively refine their intervals. In this case, + // propagation does not result in any interval changes. + else { + Ok(Some(vec![])) + } + } else { + // An uncertain logical disjunction result can not shrink the + // end-points of its children. + Ok(Some(vec![])) + } + } else if self.op.is_comparison_operator() { + Ok( + propagate_comparison(&self.op, interval, left_interval, right_interval)? + .map(|(left, right)| vec![left, right]), + ) } else { - // Propagate the arithmetic operator. - propagate_arithmetic(&self.op, interval, left_interval, right_interval)? - }; - Ok(vec![left, right]) + Ok( + propagate_arithmetic(&self.op, interval, left_interval, right_interval)? + .map(|(left, right)| vec![left, right]), + ) + } } fn dyn_hash(&self, state: &mut dyn Hasher) { @@ -380,7 +449,7 @@ impl PhysicalExpr for BinaryExpr { Operator::Minus => left_child.sub(right_child), Operator::Gt | Operator::GtEq => left_child.gt_or_gteq(right_child), Operator::Lt | Operator::LtEq => right_child.gt_or_gteq(left_child), - Operator::And => left_child.and(right_child), + Operator::And | Operator::Or => left_child.and_or(right_child), _ => SortProperties::Unordered, } } @@ -527,12 +596,7 @@ impl BinaryExpr { BitwiseXor => bitwise_xor_dyn(left, right), BitwiseShiftRight => bitwise_shift_right_dyn(left, right), BitwiseShiftLeft => bitwise_shift_left_dyn(left, right), - StringConcat => match (left_data_type, right_data_type) { - (DataType::List(_), DataType::List(_)) => array_concat(&[left, right]), - (DataType::List(_), _) => array_append(&[left, right]), - (_, DataType::List(_)) => array_prepend(&[left, right]), - _ => binary_string_array_op!(left, right, concat_elements), - }, + StringConcat => binary_string_array_op!(left, right, concat_elements), AtArrow => array_has_all(&[left, right]), ArrowAt => array_has_all(&[right, left]), } @@ -558,8 +622,7 @@ mod tests { use arrow::datatypes::{ ArrowNumericType, Decimal128Type, Field, Int32Type, SchemaRef, }; - use arrow_schema::ArrowError; - use datafusion_common::Result; + use datafusion_common::{plan_datafusion_err, Result}; use datafusion_expr::type_coercion::binary::get_input_types; /// Performs a binary operation, applying any type coercion necessary @@ -597,7 +660,10 @@ mod tests { let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])?; - let result = lt.evaluate(&batch)?.into_array(batch.num_rows()); + let result = lt + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); assert_eq!(result.len(), 5); let expected = [false, false, true, true, true]; @@ -641,7 +707,10 @@ mod tests { assert_eq!("a@0 < b@1 OR a@0 = b@1", format!("{expr}")); - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); assert_eq!(result.len(), 5); let expected = [true, true, false, true, false]; @@ -685,7 +754,7 @@ mod tests { assert_eq!(expression.data_type(&schema)?, $C_TYPE); // compute - let result = expression.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expression.evaluate(&batch)?.into_array(batch.num_rows()).expect("Failed to convert to array"); // verify that the array's data_type is correct assert_eq!(*result.data_type(), $C_TYPE); @@ -2138,7 +2207,10 @@ mod tests { let arithmetic_op = binary_op(col("a", &schema)?, op, col("b", &schema)?, &schema)?; let batch = RecordBatch::try_new(schema, data)?; - let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows()); + let result = arithmetic_op + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); assert_eq!(result.as_ref(), &expected); Ok(()) @@ -2154,7 +2226,10 @@ mod tests { let lit = Arc::new(Literal::new(literal)); let arithmetic_op = binary_op(col("a", &schema)?, op, lit, &schema)?; let batch = RecordBatch::try_new(schema, data)?; - let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows()); + let result = arithmetic_op + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); assert_eq!(&result, &expected); Ok(()) @@ -2170,7 +2245,10 @@ mod tests { let op = binary_op(col("a", schema)?, op, col("b", schema)?, schema)?; let data: Vec = vec![left.clone(), right.clone()]; let batch = RecordBatch::try_new(schema.clone(), data)?; - let result = op.evaluate(&batch)?.into_array(batch.num_rows()); + let result = op + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); assert_eq!(result.as_ref(), &expected); Ok(()) @@ -2187,7 +2265,10 @@ mod tests { let scalar = lit(scalar.clone()); let op = binary_op(scalar, op, col("a", schema)?, schema)?; let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?; - let result = op.evaluate(&batch)?.into_array(batch.num_rows()); + let result = op + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); assert_eq!(result.as_ref(), expected); Ok(()) @@ -2204,7 +2285,10 @@ mod tests { let scalar = lit(scalar.clone()); let op = binary_op(col("a", schema)?, op, scalar, schema)?; let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?; - let result = op.evaluate(&batch)?.into_array(batch.num_rows()); + let result = op + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); assert_eq!(result.as_ref(), expected); Ok(()) @@ -2776,7 +2860,8 @@ mod tests { let result = expr .evaluate(&batch) .expect("evaluation") - .into_array(batch.num_rows()); + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let expected: Int32Array = input .into_iter() @@ -3255,7 +3340,10 @@ mod tests { let arithmetic_op = binary_op(col("a", schema)?, op, col("b", schema)?, schema)?; let data: Vec = vec![left.clone(), right.clone()]; let batch = RecordBatch::try_new(schema.clone(), data)?; - let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows()); + let result = arithmetic_op + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); assert_eq!(result.as_ref(), expected.as_ref()); Ok(()) @@ -3512,10 +3600,9 @@ mod tests { ) .unwrap_err(); - assert!( - matches!(err, DataFusionError::ArrowError(ArrowError::DivideByZero)), - "{err}" - ); + let _expected = plan_datafusion_err!("Divide by zero"); + + assert!(matches!(err, ref _expected), "{err}"); // decimal let schema = Arc::new(Schema::new(vec![ @@ -3537,10 +3624,7 @@ mod tests { ) .unwrap_err(); - assert!( - matches!(err, DataFusionError::ArrowError(ArrowError::DivideByZero)), - "{err}" - ); + assert!(matches!(err, ref _expected), "{err}"); Ok(()) } diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index a2395c4a0ca2..52fb85657f4e 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -126,7 +126,7 @@ impl CaseExpr { let return_type = self.data_type(&batch.schema())?; let expr = self.expr.as_ref().unwrap(); let base_value = expr.evaluate(batch)?; - let base_value = base_value.into_array(batch.num_rows()); + let base_value = base_value.into_array(batch.num_rows())?; let base_nulls = is_null(base_value.as_ref())?; // start with nulls as default output @@ -137,7 +137,7 @@ impl CaseExpr { let when_value = self.when_then_expr[i] .0 .evaluate_selection(batch, &remainder)?; - let when_value = when_value.into_array(batch.num_rows()); + let when_value = when_value.into_array(batch.num_rows())?; // build boolean array representing which rows match the "when" value let when_match = eq(&when_value, &base_value)?; // Treat nulls as false @@ -145,6 +145,8 @@ impl CaseExpr { 0 => Cow::Borrowed(&when_match), _ => Cow::Owned(prep_null_mask_filter(&when_match)), }; + // Make sure we only consider rows that have not been matched yet + let when_match = and(&when_match, &remainder)?; let then_value = self.when_then_expr[i] .1 @@ -153,7 +155,7 @@ impl CaseExpr { ColumnarValue::Scalar(value) if value.is_null() => { new_null_array(&return_type, batch.num_rows()) } - _ => then_value.into_array(batch.num_rows()), + _ => then_value.into_array(batch.num_rows())?, }; current_value = @@ -170,7 +172,7 @@ impl CaseExpr { remainder = or(&base_nulls, &remainder)?; let else_ = expr .evaluate_selection(batch, &remainder)? - .into_array(batch.num_rows()); + .into_array(batch.num_rows())?; current_value = zip(&remainder, else_.as_ref(), current_value.as_ref())?; } @@ -194,7 +196,7 @@ impl CaseExpr { let when_value = self.when_then_expr[i] .0 .evaluate_selection(batch, &remainder)?; - let when_value = when_value.into_array(batch.num_rows()); + let when_value = when_value.into_array(batch.num_rows())?; let when_value = as_boolean_array(&when_value).map_err(|e| { DataFusionError::Context( "WHEN expression did not return a BooleanArray".to_string(), @@ -206,6 +208,8 @@ impl CaseExpr { 0 => Cow::Borrowed(when_value), _ => Cow::Owned(prep_null_mask_filter(when_value)), }; + // Make sure we only consider rows that have not been matched yet + let when_value = and(&when_value, &remainder)?; let then_value = self.when_then_expr[i] .1 @@ -214,7 +218,7 @@ impl CaseExpr { ColumnarValue::Scalar(value) if value.is_null() => { new_null_array(&return_type, batch.num_rows()) } - _ => then_value.into_array(batch.num_rows()), + _ => then_value.into_array(batch.num_rows())?, }; current_value = @@ -231,7 +235,7 @@ impl CaseExpr { .unwrap_or_else(|_| e.clone()); let else_ = expr .evaluate_selection(batch, &remainder)? - .into_array(batch.num_rows()); + .into_array(batch.num_rows())?; current_value = zip(&remainder, else_.as_ref(), current_value.as_ref())?; } @@ -425,7 +429,10 @@ mod tests { None, schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_int32_array(&result)?; let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]); @@ -453,7 +460,10 @@ mod tests { Some(else_value), schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_int32_array(&result)?; let expected = @@ -485,7 +495,10 @@ mod tests { Some(else_value), schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_float64_array(&result).expect("failed to downcast to Float64Array"); @@ -523,7 +536,10 @@ mod tests { None, schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_int32_array(&result)?; let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]); @@ -551,7 +567,10 @@ mod tests { Some(else_value), schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_int32_array(&result)?; let expected = @@ -583,7 +602,10 @@ mod tests { Some(x), schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_float64_array(&result).expect("failed to downcast to Float64Array"); @@ -629,7 +651,10 @@ mod tests { Some(else_value), schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_int32_array(&result)?; let expected = @@ -661,7 +686,10 @@ mod tests { Some(else_value), schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_float64_array(&result).expect("failed to downcast to Float64Array"); @@ -693,7 +721,10 @@ mod tests { None, schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_float64_array(&result).expect("failed to downcast to Float64Array"); @@ -721,7 +752,10 @@ mod tests { None, schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_float64_array(&result).expect("failed to downcast to Float64Array"); diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index 9390089063a0..0c4ed3c12549 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -20,18 +20,16 @@ use std::fmt; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::intervals::Interval; use crate::physical_expr::down_cast_any_ref; use crate::sort_properties::SortProperties; use crate::PhysicalExpr; -use arrow::compute; -use arrow::compute::{kernels, CastOptions}; +use arrow::compute::{can_cast_types, kernels, CastOptions}; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; -use compute::can_cast_types; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::ColumnarValue; const DEFAULT_CAST_OPTIONS: CastOptions<'static> = CastOptions { @@ -73,6 +71,11 @@ impl CastExpr { pub fn cast_type(&self) -> &DataType { &self.cast_type } + + /// The cast options + pub fn cast_options(&self) -> &CastOptions<'static> { + &self.cast_options + } } impl fmt::Display for CastExpr { @@ -124,21 +127,20 @@ impl PhysicalExpr for CastExpr { &self, interval: &Interval, children: &[&Interval], - ) -> Result>> { + ) -> Result>> { let child_interval = children[0]; // Get child's datatype: - let cast_type = child_interval.get_datatype()?; - Ok(vec![Some( - interval.cast_to(&cast_type, &self.cast_options)?, - )]) + let cast_type = child_interval.data_type(); + Ok(Some( + vec![interval.cast_to(&cast_type, &self.cast_options)?], + )) } fn dyn_hash(&self, state: &mut dyn Hasher) { let mut s = state; self.expr.hash(&mut s); self.cast_type.hash(&mut s); - // Add `self.cast_options` when hash is available - // https://github.com/apache/arrow-rs/pull/4395 + self.cast_options.hash(&mut s); } /// A [`CastExpr`] preserves the ordering of its child. @@ -154,8 +156,7 @@ impl PartialEq for CastExpr { .map(|x| { self.expr.eq(&x.expr) && self.cast_type == x.cast_type - // TODO: Use https://github.com/apache/arrow-rs/issues/2966 when available - && self.cast_options.safe == x.cast_options.safe + && self.cast_options == x.cast_options }) .unwrap_or(false) } @@ -173,7 +174,20 @@ pub fn cast_column( kernels::cast::cast_with_options(array, cast_type, &cast_options)?, )), ColumnarValue::Scalar(scalar) => { - let scalar_array = scalar.to_array(); + let scalar_array = if cast_type + == &DataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, None) + { + if let ScalarValue::Float64(Some(float_ts)) = scalar { + ScalarValue::Int64( + Some((float_ts * 1_000_000_000_f64).trunc() as i64), + ) + .to_array()? + } else { + scalar.to_array()? + } + } else { + scalar.to_array()? + }; let cast_array = kernels::cast::cast_with_options( &scalar_array, cast_type, @@ -198,7 +212,10 @@ pub fn cast_with_options( let expr_type = expr.data_type(input_schema)?; if expr_type == cast_type { Ok(expr.clone()) - } else if can_cast_types(&expr_type, &cast_type) { + } else if can_cast_types(&expr_type, &cast_type) + || (expr_type == DataType::Float64 + && cast_type == DataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, None)) + { Ok(Arc::new(CastExpr::new(expr, cast_type, cast_options))) } else { not_impl_err!("Unsupported CAST from {expr_type:?} to {cast_type:?}") @@ -221,6 +238,7 @@ pub fn cast( mod tests { use super::*; use crate::expressions::col; + use arrow::{ array::{ Array, Decimal128Array, Float32Array, Float64Array, Int16Array, Int32Array, @@ -229,6 +247,7 @@ mod tests { }, datatypes::*, }; + use datafusion_common::Result; // runs an end-to-end test of physical type cast @@ -258,7 +277,10 @@ mod tests { assert_eq!(expression.data_type(&schema)?, $TYPE); // compute - let result = expression.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expression + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); // verify that the array's data_type is correct assert_eq!(*result.data_type(), $TYPE); @@ -307,7 +329,10 @@ mod tests { assert_eq!(expression.data_type(&schema)?, $TYPE); // compute - let result = expression.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expression + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); // verify that the array's data_type is correct assert_eq!(*result.data_type(), $TYPE); @@ -669,7 +694,11 @@ mod tests { // Ensure a useful error happens at plan time if invalid casts are used let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let result = cast(col("a", &schema).unwrap(), &schema, DataType::LargeBinary); + let result = cast( + col("a", &schema).unwrap(), + &schema, + DataType::Interval(IntervalUnit::MonthDayNano), + ); result.expect_err("expected Invalid CAST"); } diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index b7b5895db6d3..62da8ff9ed44 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -28,7 +28,6 @@ use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; -use datafusion_common::plan_err; use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_expr::ColumnarValue; @@ -176,7 +175,7 @@ impl PhysicalExpr for UnKnownColumn { /// Evaluate the expression fn evaluate(&self, _batch: &RecordBatch) -> Result { - plan_err!("UnKnownColumn::evaluate() should not be called") + internal_err!("UnKnownColumn::evaluate() should not be called") } fn children(&self) -> Vec> { diff --git a/datafusion/physical-expr/src/expressions/datum.rs b/datafusion/physical-expr/src/expressions/datum.rs index f57cbbd4ffa3..2bb79922cfec 100644 --- a/datafusion/physical-expr/src/expressions/datum.rs +++ b/datafusion/physical-expr/src/expressions/datum.rs @@ -34,14 +34,14 @@ pub(crate) fn apply( (ColumnarValue::Array(left), ColumnarValue::Array(right)) => { Ok(ColumnarValue::Array(f(&left.as_ref(), &right.as_ref())?)) } - (ColumnarValue::Scalar(left), ColumnarValue::Array(right)) => { - Ok(ColumnarValue::Array(f(&left.to_scalar(), &right.as_ref())?)) - } - (ColumnarValue::Array(left), ColumnarValue::Scalar(right)) => { - Ok(ColumnarValue::Array(f(&left.as_ref(), &right.to_scalar())?)) - } + (ColumnarValue::Scalar(left), ColumnarValue::Array(right)) => Ok( + ColumnarValue::Array(f(&left.to_scalar()?, &right.as_ref())?), + ), + (ColumnarValue::Array(left), ColumnarValue::Scalar(right)) => Ok( + ColumnarValue::Array(f(&left.as_ref(), &right.to_scalar()?)?), + ), (ColumnarValue::Scalar(left), ColumnarValue::Scalar(right)) => { - let array = f(&left.to_scalar(), &right.to_scalar())?; + let array = f(&left.to_scalar()?, &right.to_scalar()?)?; let scalar = ScalarValue::try_from_array(array.as_ref(), 0)?; Ok(ColumnarValue::Scalar(scalar)) } diff --git a/datafusion/physical-expr/src/expressions/get_indexed_field.rs b/datafusion/physical-expr/src/expressions/get_indexed_field.rs index ab15356dc212..43fd5a812a16 100644 --- a/datafusion/physical-expr/src/expressions/get_indexed_field.rs +++ b/datafusion/physical-expr/src/expressions/get_indexed_field.rs @@ -18,16 +18,19 @@ //! get field of a `ListArray` use crate::PhysicalExpr; -use arrow::array::Array; use datafusion_common::exec_err; use crate::array_expressions::{array_element, array_slice}; use crate::physical_expr::down_cast_any_ref; use arrow::{ + array::{Array, Scalar, StringArray}, datatypes::{DataType, Schema}, record_batch::RecordBatch, }; -use datafusion_common::{cast::as_struct_array, DataFusionError, Result, ScalarValue}; +use datafusion_common::{ + cast::{as_map_array, as_struct_array}, + DataFusionError, Result, ScalarValue, +}; use datafusion_expr::{field_util::GetFieldAccessSchema, ColumnarValue}; use std::fmt::Debug; use std::hash::{Hash, Hasher}; @@ -107,7 +110,7 @@ impl GetIndexedFieldExpr { Self::new( arg, GetFieldAccessExpr::NamedStructField { - name: ScalarValue::Utf8(Some(name.into())), + name: ScalarValue::from(name.into()), }, ) } @@ -180,9 +183,17 @@ impl PhysicalExpr for GetIndexedFieldExpr { } fn evaluate(&self, batch: &RecordBatch) -> Result { - let array = self.arg.evaluate(batch)?.into_array(batch.num_rows()); + let array = self.arg.evaluate(batch)?.into_array(batch.num_rows())?; match &self.field { GetFieldAccessExpr::NamedStructField{name} => match (array.data_type(), name) { + (DataType::Map(_, _), ScalarValue::Utf8(Some(k))) => { + let map_array = as_map_array(array.as_ref())?; + let key_scalar = Scalar::new(StringArray::from(vec![k.clone()])); + let keys = arrow::compute::kernels::cmp::eq(&key_scalar, map_array.keys())?; + let entries = arrow::compute::filter(map_array.entries(), &keys)?; + let entries_struct_array = as_struct_array(entries.as_ref())?; + Ok(ColumnarValue::Array(entries_struct_array.column(1).clone())) + } (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => { let as_struct_array = as_struct_array(&array)?; match as_struct_array.column_by_name(k) { @@ -199,7 +210,7 @@ impl PhysicalExpr for GetIndexedFieldExpr { with utf8 indexes. Tried {dt:?} with {name:?} index"), }, GetFieldAccessExpr::ListIndex{key} => { - let key = key.evaluate(batch)?.into_array(batch.num_rows()); + let key = key.evaluate(batch)?.into_array(batch.num_rows())?; match (array.data_type(), key.data_type()) { (DataType::List(_), DataType::Int64) => Ok(ColumnarValue::Array(array_element(&[ array, key @@ -213,8 +224,8 @@ impl PhysicalExpr for GetIndexedFieldExpr { } }, GetFieldAccessExpr::ListRange{start, stop} => { - let start = start.evaluate(batch)?.into_array(batch.num_rows()); - let stop = stop.evaluate(batch)?.into_array(batch.num_rows()); + let start = start.evaluate(batch)?.into_array(batch.num_rows())?; + let stop = stop.evaluate(batch)?.into_array(batch.num_rows())?; match (array.data_type(), start.data_type(), stop.data_type()) { (DataType::List(_), DataType::Int64, DataType::Int64) => Ok(ColumnarValue::Array(array_slice(&[ array, start, stop @@ -315,7 +326,10 @@ mod tests { // only one row should be processed let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(struct_array)])?; let expr = Arc::new(GetIndexedFieldExpr::new_field(expr, "a")); - let result = expr.evaluate(&batch)?.into_array(1); + let result = expr + .evaluate(&batch)? + .into_array(1) + .expect("Failed to convert to array"); let result = as_boolean_array(&result).expect("failed to downcast to BooleanArray"); assert_eq!(boolean, result.clone()); @@ -372,7 +386,10 @@ mod tests { vec![Arc::new(list_col), Arc::new(key_col)], )?; let expr = Arc::new(GetIndexedFieldExpr::new_index(expr, key)); - let result = expr.evaluate(&batch)?.into_array(1); + let result = expr + .evaluate(&batch)? + .into_array(1) + .expect("Failed to convert to array"); let result = as_string_array(&result).expect("failed to downcast to ListArray"); let expected = StringArray::from(expected_list); assert_eq!(expected, result.clone()); @@ -408,7 +425,10 @@ mod tests { vec![Arc::new(list_col), Arc::new(start_col), Arc::new(stop_col)], )?; let expr = Arc::new(GetIndexedFieldExpr::new_range(expr, start, stop)); - let result = expr.evaluate(&batch)?.into_array(1); + let result = expr + .evaluate(&batch)? + .into_array(1) + .expect("Failed to convert to array"); let result = as_list_array(&result).expect("failed to downcast to ListArray"); let (expected, _, _) = build_list_arguments(expected_list, vec![None], vec![None]); @@ -429,8 +449,11 @@ mod tests { vec![Arc::new(list_builder.finish()), key_array], )?; let expr = Arc::new(GetIndexedFieldExpr::new_index(expr, key)); - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); - assert!(result.is_null(0)); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + assert!(result.is_empty()); Ok(()) } @@ -450,7 +473,10 @@ mod tests { vec![Arc::new(list_builder.finish()), Arc::new(key_array)], )?; let expr = Arc::new(GetIndexedFieldExpr::new_index(expr, key)); - let result = expr.evaluate(&batch)?.into_array(1); + let result = expr + .evaluate(&batch)? + .into_array(1) + .expect("Failed to convert to array"); assert!(result.is_null(0)); Ok(()) } diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index c92bbbb74f16..1a1634081c38 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -17,17 +17,14 @@ //! Implementation of `InList` expressions: [`InListExpr`] -use ahash::RandomState; -use datafusion_common::exec_err; use std::any::Any; use std::fmt::Debug; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::hash_utils::HashValue; -use crate::physical_expr::down_cast_any_ref; -use crate::utils::expr_list_eq_any_order; +use crate::physical_expr::{down_cast_any_ref, physical_exprs_bag_equal}; use crate::PhysicalExpr; + use arrow::array::*; use arrow::buffer::BooleanBuffer; use arrow::compute::kernels::boolean::{not, or_kleene}; @@ -37,11 +34,16 @@ use arrow::datatypes::*; use arrow::record_batch::RecordBatch; use arrow::util::bit_iterator::BitIndexIterator; use arrow::{downcast_dictionary_array, downcast_primitive_array}; +use datafusion_common::cast::{ + as_boolean_array, as_generic_binary_array, as_string_array, +}; +use datafusion_common::hash_utils::HashValue; use datafusion_common::{ - cast::{as_boolean_array, as_generic_binary_array, as_string_array}, - internal_err, not_impl_err, DataFusionError, Result, ScalarValue, + exec_err, internal_err, not_impl_err, DataFusionError, Result, ScalarValue, }; use datafusion_expr::ColumnarValue; + +use ahash::RandomState; use hashbrown::hash_map::RawEntryMut; use hashbrown::HashMap; @@ -347,17 +349,18 @@ impl PhysicalExpr for InListExpr { } fn evaluate(&self, batch: &RecordBatch) -> Result { + let num_rows = batch.num_rows(); let value = self.expr.evaluate(batch)?; let r = match &self.static_filter { - Some(f) => f.contains(value.into_array(1).as_ref(), self.negated)?, + Some(f) => f.contains(value.into_array(num_rows)?.as_ref(), self.negated)?, None => { - let value = value.into_array(batch.num_rows()); + let value = value.into_array(num_rows)?; let found = self.list.iter().map(|expr| expr.evaluate(batch)).try_fold( - BooleanArray::new(BooleanBuffer::new_unset(batch.num_rows()), None), + BooleanArray::new(BooleanBuffer::new_unset(num_rows), None), |result, expr| -> Result { Ok(or_kleene( &result, - &eq(&value, &expr?.into_array(batch.num_rows()))?, + &eq(&value, &expr?.into_array(num_rows)?)?, )?) }, )?; @@ -407,7 +410,7 @@ impl PartialEq for InListExpr { .downcast_ref::() .map(|x| { self.expr.eq(&x.expr) - && expr_list_eq_any_order(&self.list, &x.list) + && physical_exprs_bag_equal(&self.list, &x.list) && self.negated == x.negated }) .unwrap_or(false) @@ -499,7 +502,10 @@ mod tests { ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr, $COL:expr, $SCHEMA:expr) => {{ let (cast_expr, cast_list_exprs) = in_list_cast($COL, $LIST, $SCHEMA)?; let expr = in_list(cast_expr, cast_list_exprs, $NEGATED, $SCHEMA).unwrap(); - let result = expr.evaluate(&$BATCH)?.into_array($BATCH.num_rows()); + let result = expr + .evaluate(&$BATCH)? + .into_array($BATCH.num_rows()) + .expect("Failed to convert to array"); let result = as_boolean_array(&result).expect("failed to downcast to BooleanArray"); let expected = &BooleanArray::from($EXPECTED); @@ -1262,4 +1268,52 @@ mod tests { Ok(()) } + + #[test] + fn in_list_no_cols() -> Result<()> { + // test logic when the in_list expression doesn't have any columns + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let a = Int32Array::from(vec![Some(1), Some(2), None]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + let list = vec![lit(ScalarValue::from(1i32)), lit(ScalarValue::from(6i32))]; + + // 1 IN (1, 6) + let expr = lit(ScalarValue::Int32(Some(1))); + in_list!( + batch, + list.clone(), + &false, + // should have three outputs, as the input batch has three rows + vec![Some(true), Some(true), Some(true)], + expr, + &schema + ); + + // 2 IN (1, 6) + let expr = lit(ScalarValue::Int32(Some(2))); + in_list!( + batch, + list.clone(), + &false, + // should have three outputs, as the input batch has three rows + vec![Some(false), Some(false), Some(false)], + expr, + &schema + ); + + // NULL IN (1, 6) + let expr = lit(ScalarValue::Int32(None)); + in_list!( + batch, + list.clone(), + &false, + // should have three outputs, as the input batch has three rows + vec![None, None, None], + expr, + &schema + ); + + Ok(()) + } } diff --git a/datafusion/physical-expr/src/expressions/is_not_null.rs b/datafusion/physical-expr/src/expressions/is_not_null.rs index da717a517fb3..2e6a2bec9cab 100644 --- a/datafusion/physical-expr/src/expressions/is_not_null.rs +++ b/datafusion/physical-expr/src/expressions/is_not_null.rs @@ -132,7 +132,10 @@ mod tests { let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; // expression: "a is not null" - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_boolean_array(&result).expect("failed to downcast to BooleanArray"); diff --git a/datafusion/physical-expr/src/expressions/is_null.rs b/datafusion/physical-expr/src/expressions/is_null.rs index ee7897edd4de..3ad4058dd649 100644 --- a/datafusion/physical-expr/src/expressions/is_null.rs +++ b/datafusion/physical-expr/src/expressions/is_null.rs @@ -134,7 +134,10 @@ mod tests { let expr = is_null(col("a", &schema)?).unwrap(); let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_boolean_array(&result).expect("failed to downcast to BooleanArray"); diff --git a/datafusion/physical-expr/src/expressions/like.rs b/datafusion/physical-expr/src/expressions/like.rs index e833eabbfff2..37452e278484 100644 --- a/datafusion/physical-expr/src/expressions/like.rs +++ b/datafusion/physical-expr/src/expressions/like.rs @@ -201,7 +201,10 @@ mod test { )?; // compute - let result = expression.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expression + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_boolean_array(&result).expect("failed to downcast to BooleanArray"); let expected = &BooleanArray::from($VEC); diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs index 91cb23d5864e..cd3b51f09105 100644 --- a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr/src/expressions/literal.rs @@ -131,7 +131,10 @@ mod tests { let literal_expr = lit(42i32); assert_eq!("42", format!("{literal_expr}")); - let literal_array = literal_expr.evaluate(&batch)?.into_array(batch.num_rows()); + let literal_array = literal_expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let literal_array = as_int32_array(&literal_array)?; // note that the contents of the literal array are unrelated to the batch contents except for the length of the array diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index c44b3cf01d36..b6d0ad5b9104 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -63,6 +63,7 @@ pub use crate::aggregate::min_max::{MaxAccumulator, MinAccumulator}; pub use crate::aggregate::regr::{Regr, RegrType}; pub use crate::aggregate::stats::StatsType; pub use crate::aggregate::stddev::{Stddev, StddevPop}; +pub use crate::aggregate::string_agg::StringAgg; pub use crate::aggregate::sum::Sum; pub use crate::aggregate::sum_distinct::DistinctSum; pub use crate::aggregate::variance::{Variance, VariancePop}; @@ -247,8 +248,10 @@ pub(crate) mod tests { let expr = agg.expressions(); let values = expr .iter() - .map(|e| e.evaluate(batch)) - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .map(|e| { + e.evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) .collect::>>()?; accum.update_batch(&values)?; accum.evaluate() @@ -262,8 +265,10 @@ pub(crate) mod tests { let expr = agg.expressions(); let values = expr .iter() - .map(|e| e.evaluate(batch)) - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .map(|e| { + e.evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) .collect::>>()?; let indices = vec![0; batch.num_rows()]; accum.update_batch(&values, &indices, None, 1)?; diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs index 90430cb2bbda..b64b4a0c86de 100644 --- a/datafusion/physical-expr/src/expressions/negative.rs +++ b/datafusion/physical-expr/src/expressions/negative.rs @@ -30,10 +30,10 @@ use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; - use datafusion_common::{internal_err, DataFusionError, Result}; +use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::{ - type_coercion::{is_interval, is_null, is_signed_numeric}, + type_coercion::{is_interval, is_null, is_signed_numeric, is_timestamp}, ColumnarValue, }; @@ -105,6 +105,34 @@ impl PhysicalExpr for NegativeExpr { self.hash(&mut s); } + /// Given the child interval of a NegativeExpr, it calculates the NegativeExpr's interval. + /// It replaces the upper and lower bounds after multiplying them with -1. + /// Ex: `(a, b]` => `[-b, -a)` + fn evaluate_bounds(&self, children: &[&Interval]) -> Result { + Interval::try_new( + children[0].upper().arithmetic_negate()?, + children[0].lower().arithmetic_negate()?, + ) + } + + /// Returns a new [`Interval`] of a NegativeExpr that has the existing `interval` given that + /// given the input interval is known to be `children`. + fn propagate_constraints( + &self, + interval: &Interval, + children: &[&Interval], + ) -> Result>> { + let child_interval = children[0]; + let negated_interval = Interval::try_new( + interval.upper().arithmetic_negate()?, + interval.lower().arithmetic_negate()?, + )?; + + Ok(child_interval + .intersect(negated_interval)? + .map(|result| vec![result])) + } + /// The ordering of a [`NegativeExpr`] is simply the reverse of its child. fn get_ordering(&self, children: &[SortProperties]) -> SortProperties { -children[0] @@ -132,7 +160,10 @@ pub fn negative( let data_type = arg.data_type(input_schema)?; if is_null(&data_type) { Ok(arg) - } else if !is_signed_numeric(&data_type) && !is_interval(&data_type) { + } else if !is_signed_numeric(&data_type) + && !is_interval(&data_type) + && !is_timestamp(&data_type) + { internal_err!( "Can't create negative physical expr for (- '{arg:?}'), the type of child expr is {data_type}, not signed numeric" ) @@ -144,12 +175,14 @@ pub fn negative( #[cfg(test)] mod tests { use super::*; - use crate::expressions::col; - #[allow(unused_imports)] + use crate::expressions::{col, Column}; + use arrow::array::*; use arrow::datatypes::*; use arrow_schema::DataType::{Float32, Float64, Int16, Int32, Int64, Int8}; - use datafusion_common::{cast::as_primitive_array, Result}; + use datafusion_common::cast::as_primitive_array; + use datafusion_common::Result; + use paste::paste; macro_rules! test_array_negative_op { @@ -170,7 +203,7 @@ mod tests { let expected = &paste!{[<$DATA_TY Array>]::from(arr_expected)}; let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(input)])?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr.evaluate(&batch)?.into_array(batch.num_rows()).expect("Failed to convert to array"); let result = as_primitive_array(&result).expect(format!("failed to downcast to {:?}Array", $DATA_TY).as_str()); assert_eq!(result, expected); @@ -187,4 +220,36 @@ mod tests { test_array_negative_op!(Float64, 23456.0f64, 12345.0f64); Ok(()) } + + #[test] + fn test_evaluate_bounds() -> Result<()> { + let negative_expr = NegativeExpr { + arg: Arc::new(Column::new("a", 0)), + }; + let child_interval = Interval::make(Some(-2), Some(1))?; + let negative_expr_interval = Interval::make(Some(-1), Some(2))?; + assert_eq!( + negative_expr.evaluate_bounds(&[&child_interval])?, + negative_expr_interval + ); + Ok(()) + } + + #[test] + fn test_propagate_constraints() -> Result<()> { + let negative_expr = NegativeExpr { + arg: Arc::new(Column::new("a", 0)), + }; + let original_child_interval = Interval::make(Some(-2), Some(3))?; + let negative_expr_interval = Interval::make(Some(0), Some(4))?; + let after_propagation = Some(vec![Interval::make(Some(-2), Some(0))?]); + assert_eq!( + negative_expr.propagate_constraints( + &negative_expr_interval, + &[&original_child_interval] + )?, + after_propagation + ); + Ok(()) + } } diff --git a/datafusion/physical-expr/src/expressions/no_op.rs b/datafusion/physical-expr/src/expressions/no_op.rs index 497fb42fe4df..95e6879a6c2d 100644 --- a/datafusion/physical-expr/src/expressions/no_op.rs +++ b/datafusion/physical-expr/src/expressions/no_op.rs @@ -28,7 +28,7 @@ use arrow::{ use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; -use datafusion_common::{plan_err, DataFusionError, Result}; +use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_expr::ColumnarValue; /// A place holder expression, can not be evaluated. @@ -65,7 +65,7 @@ impl PhysicalExpr for NoOp { } fn evaluate(&self, _batch: &RecordBatch) -> Result { - plan_err!("NoOp::evaluate() should not be called") + internal_err!("NoOp::evaluate() should not be called") } fn children(&self) -> Vec> { diff --git a/datafusion/physical-expr/src/expressions/not.rs b/datafusion/physical-expr/src/expressions/not.rs index c154fad10037..4ceccc6932fe 100644 --- a/datafusion/physical-expr/src/expressions/not.rs +++ b/datafusion/physical-expr/src/expressions/not.rs @@ -150,7 +150,10 @@ mod tests { let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(input)])?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_boolean_array(&result).expect("failed to downcast to BooleanArray"); assert_eq!(result, expected); diff --git a/datafusion/physical-expr/src/expressions/nullif.rs b/datafusion/physical-expr/src/expressions/nullif.rs index 7bbe9d73d435..dcd883f92965 100644 --- a/datafusion/physical-expr/src/expressions/nullif.rs +++ b/datafusion/physical-expr/src/expressions/nullif.rs @@ -37,7 +37,7 @@ pub fn nullif_func(args: &[ColumnarValue]) -> Result { match (lhs, rhs) { (ColumnarValue::Array(lhs), ColumnarValue::Scalar(rhs)) => { - let rhs = rhs.to_scalar(); + let rhs = rhs.to_scalar()?; let array = nullif(lhs, &eq(&lhs, &rhs)?)?; Ok(ColumnarValue::Array(array)) @@ -47,7 +47,7 @@ pub fn nullif_func(args: &[ColumnarValue]) -> Result { Ok(ColumnarValue::Array(array)) } (ColumnarValue::Scalar(lhs), ColumnarValue::Array(rhs)) => { - let lhs = lhs.to_array_of_size(rhs.len()); + let lhs = lhs.to_array_of_size(rhs.len())?; let array = nullif(&lhs, &eq(&lhs, &rhs)?)?; Ok(ColumnarValue::Array(array)) } @@ -89,7 +89,7 @@ mod tests { let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); let result = nullif_func(&[a, lit_array])?; - let result = result.into_array(0); + let result = result.into_array(0).expect("Failed to convert to array"); let expected = Arc::new(Int32Array::from(vec![ Some(1), @@ -115,7 +115,7 @@ mod tests { let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(1i32))); let result = nullif_func(&[a, lit_array])?; - let result = result.into_array(0); + let result = result.into_array(0).expect("Failed to convert to array"); let expected = Arc::new(Int32Array::from(vec![ None, @@ -140,7 +140,7 @@ mod tests { let lit_array = ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))); let result = nullif_func(&[a, lit_array])?; - let result = result.into_array(0); + let result = result.into_array(0).expect("Failed to convert to array"); let expected = Arc::new(BooleanArray::from(vec![Some(true), None, None])) as ArrayRef; @@ -154,10 +154,10 @@ mod tests { let a = StringArray::from(vec![Some("foo"), Some("bar"), None, Some("baz")]); let a = ColumnarValue::Array(Arc::new(a)); - let lit_array = ColumnarValue::Scalar(ScalarValue::Utf8(Some("bar".to_string()))); + let lit_array = ColumnarValue::Scalar(ScalarValue::from("bar")); let result = nullif_func(&[a, lit_array])?; - let result = result.into_array(0); + let result = result.into_array(0).expect("Failed to convert to array"); let expected = Arc::new(StringArray::from(vec![ Some("foo"), @@ -178,7 +178,7 @@ mod tests { let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); let result = nullif_func(&[lit_array, a])?; - let result = result.into_array(0); + let result = result.into_array(0).expect("Failed to convert to array"); let expected = Arc::new(Int32Array::from(vec![ Some(2), @@ -198,7 +198,7 @@ mod tests { let b_eq = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); let result_eq = nullif_func(&[a_eq, b_eq])?; - let result_eq = result_eq.into_array(1); + let result_eq = result_eq.into_array(1).expect("Failed to convert to array"); let expected_eq = Arc::new(Int32Array::from(vec![None])) as ArrayRef; @@ -208,7 +208,9 @@ mod tests { let b_neq = ColumnarValue::Scalar(ScalarValue::Int32(Some(1i32))); let result_neq = nullif_func(&[a_neq, b_neq])?; - let result_neq = result_neq.into_array(1); + let result_neq = result_neq + .into_array(1) + .expect("Failed to convert to array"); let expected_neq = Arc::new(Int32Array::from(vec![Some(2i32)])) as ArrayRef; assert_eq!(expected_neq.as_ref(), result_neq.as_ref()); diff --git a/datafusion/physical-expr/src/expressions/try_cast.rs b/datafusion/physical-expr/src/expressions/try_cast.rs index cba026c56513..0f7909097a10 100644 --- a/datafusion/physical-expr/src/expressions/try_cast.rs +++ b/datafusion/physical-expr/src/expressions/try_cast.rs @@ -89,7 +89,7 @@ impl PhysicalExpr for TryCastExpr { Ok(ColumnarValue::Array(cast)) } ColumnarValue::Scalar(scalar) => { - let array = scalar.to_array(); + let array = scalar.to_array()?; let cast_array = cast_with_options(&array, &self.cast_type, &options)?; let cast_scalar = ScalarValue::try_from_array(&cast_array, 0)?; Ok(ColumnarValue::Scalar(cast_scalar)) @@ -187,7 +187,10 @@ mod tests { assert_eq!(expression.data_type(&schema)?, $TYPE); // compute - let result = expression.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expression + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); // verify that the array's data_type is correct assert_eq!(*result.data_type(), $TYPE); @@ -235,7 +238,10 @@ mod tests { assert_eq!(expression.data_type(&schema)?, $TYPE); // compute - let result = expression.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expression + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); // verify that the array's data_type is correct assert_eq!(*result.data_type(), $TYPE); @@ -549,7 +555,11 @@ mod tests { // Ensure a useful error happens at plan time if invalid casts are used let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let result = try_cast(col("a", &schema).unwrap(), &schema, DataType::LargeBinary); + let result = try_cast( + col("a", &schema).unwrap(), + &schema, + DataType::Interval(IntervalUnit::MonthDayNano), + ); result.expect_err("expected Invalid TRY_CAST"); } diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 5de0dc366b85..53de85843919 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -34,19 +34,19 @@ use crate::execution_props::ExecutionProps; use crate::sort_properties::SortProperties; use crate::{ array_expressions, conditional_expressions, datetime_expressions, - expressions::{cast_column, nullif_func}, - math_expressions, string_expressions, struct_expressions, PhysicalExpr, - ScalarFunctionExpr, + expressions::nullif_func, math_expressions, string_expressions, struct_expressions, + PhysicalExpr, ScalarFunctionExpr, }; use arrow::{ array::ArrayRef, compute::kernels::length::{bit_length, length}, - datatypes::TimeUnit, datatypes::{DataType, Int32Type, Int64Type, Schema}, }; use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; +pub use datafusion_expr::FuncMonotonicity; use datafusion_expr::{ - BuiltinScalarFunction, ColumnarValue, ScalarFunctionImplementation, + type_coercion::functions::data_types, BuiltinScalarFunction, ColumnarValue, + ScalarFunctionImplementation, }; use std::ops::Neg; use std::sync::Arc; @@ -64,129 +64,21 @@ pub fn create_physical_expr( .map(|e| e.data_type(input_schema)) .collect::>>()?; + // verify that input data types is consistent with function's `TypeSignature` + data_types(&input_expr_types, &fun.signature())?; + let data_type = fun.return_type(&input_expr_types)?; - let fun_expr: ScalarFunctionImplementation = match fun { - // These functions need args and input schema to pick an implementation - // Unlike the string functions, which actually figure out the function to use with each array, - // here we return either a cast fn or string timestamp translation based on the expression data type - // so we don't have to pay a per-array/batch cost. - BuiltinScalarFunction::ToTimestamp => { - Arc::new(match input_phy_exprs[0].data_type(input_schema) { - Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => { - |col_values: &[ColumnarValue]| { - cast_column( - &col_values[0], - &DataType::Timestamp(TimeUnit::Nanosecond, None), - None, - ) - } - } - Ok(DataType::Utf8) => datetime_expressions::to_timestamp, - other => { - return internal_err!( - "Unsupported data type {other:?} for function to_timestamp" - ); - } - }) - } - BuiltinScalarFunction::ToTimestampMillis => { - Arc::new(match input_phy_exprs[0].data_type(input_schema) { - Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => { - |col_values: &[ColumnarValue]| { - cast_column( - &col_values[0], - &DataType::Timestamp(TimeUnit::Millisecond, None), - None, - ) - } - } - Ok(DataType::Utf8) => datetime_expressions::to_timestamp_millis, - other => { - return internal_err!( - "Unsupported data type {other:?} for function to_timestamp_millis" - ); - } - }) - } - BuiltinScalarFunction::ToTimestampMicros => { - Arc::new(match input_phy_exprs[0].data_type(input_schema) { - Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => { - |col_values: &[ColumnarValue]| { - cast_column( - &col_values[0], - &DataType::Timestamp(TimeUnit::Microsecond, None), - None, - ) - } - } - Ok(DataType::Utf8) => datetime_expressions::to_timestamp_micros, - other => { - return internal_err!( - "Unsupported data type {other:?} for function to_timestamp_micros" - ); - } - }) - } - BuiltinScalarFunction::ToTimestampSeconds => Arc::new({ - match input_phy_exprs[0].data_type(input_schema) { - Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => { - |col_values: &[ColumnarValue]| { - cast_column( - &col_values[0], - &DataType::Timestamp(TimeUnit::Second, None), - None, - ) - } - } - Ok(DataType::Utf8) => datetime_expressions::to_timestamp_seconds, - other => { - return internal_err!( - "Unsupported data type {other:?} for function to_timestamp_seconds" - ); - } - } - }), - BuiltinScalarFunction::FromUnixtime => Arc::new({ - match input_phy_exprs[0].data_type(input_schema) { - Ok(DataType::Int64) => |col_values: &[ColumnarValue]| { - cast_column( - &col_values[0], - &DataType::Timestamp(TimeUnit::Second, None), - None, - ) - }, - other => { - return internal_err!( - "Unsupported data type {other:?} for function from_unixtime" - ); - } - } - }), - BuiltinScalarFunction::ArrowTypeof => { - let input_data_type = input_phy_exprs[0].data_type(input_schema)?; - Arc::new(move |_| { - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(format!( - "{input_data_type}" - ))))) - }) - } - BuiltinScalarFunction::Abs => { - let input_data_type = input_phy_exprs[0].data_type(input_schema)?; - let abs_fun = math_expressions::create_abs_function(&input_data_type)?; - Arc::new(move |args| make_scalar_function(abs_fun)(args)) - } - // These don't need args and input schema - _ => create_physical_fun(fun, execution_props)?, - }; + let fun_expr: ScalarFunctionImplementation = + create_physical_fun(fun, execution_props)?; - let monotonicity = get_func_monotonicity(fun); + let monotonicity = fun.monotonicity(); Ok(Arc::new(ScalarFunctionExpr::new( &format!("{fun}"), fun_expr, input_phy_exprs.to_vec(), - &data_type, + data_type, monotonicity, ))) } @@ -332,6 +224,8 @@ where ColumnarValue::Array(a) => Some(a.len()), }); + let is_scalar = len.is_none(); + let inferred_length = len.unwrap_or(1); let args = args .iter() @@ -345,15 +239,16 @@ where }; arg.clone().into_array(expansion_len) }) - .collect::>(); + .collect::>>()?; let result = (inner)(&args); - // maybe back to scalar - if len.is_some() { - result.map(ColumnarValue::Array) + if is_scalar { + // If all inputs are scalar, keeps output as scalar + let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); + result.map(ColumnarValue::Scalar) } else { - ScalarValue::try_from_array(&result?, 0).map(ColumnarValue::Scalar) + result.map(ColumnarValue::Array) } }) } @@ -365,6 +260,9 @@ pub fn create_physical_fun( ) -> Result { Ok(match fun { // math functions + BuiltinScalarFunction::Abs => { + Arc::new(|args| make_scalar_function(math_expressions::abs_invoke)(args)) + } BuiltinScalarFunction::Acos => Arc::new(math_expressions::acos), BuiltinScalarFunction::Asin => Arc::new(math_expressions::asin), BuiltinScalarFunction::Atan => Arc::new(math_expressions::atan), @@ -431,6 +329,9 @@ pub fn create_physical_fun( BuiltinScalarFunction::ArrayAppend => { Arc::new(|args| make_scalar_function(array_expressions::array_append)(args)) } + BuiltinScalarFunction::ArraySort => { + Arc::new(|args| make_scalar_function(array_expressions::array_sort)(args)) + } BuiltinScalarFunction::ArrayConcat => { Arc::new(|args| make_scalar_function(array_expressions::array_concat)(args)) } @@ -449,9 +350,15 @@ pub fn create_physical_fun( BuiltinScalarFunction::ArrayDims => { Arc::new(|args| make_scalar_function(array_expressions::array_dims)(args)) } + BuiltinScalarFunction::ArrayDistinct => { + Arc::new(|args| make_scalar_function(array_expressions::array_distinct)(args)) + } BuiltinScalarFunction::ArrayElement => { Arc::new(|args| make_scalar_function(array_expressions::array_element)(args)) } + BuiltinScalarFunction::ArrayExcept => { + Arc::new(|args| make_scalar_function(array_expressions::array_except)(args)) + } BuiltinScalarFunction::ArrayLength => { Arc::new(|args| make_scalar_function(array_expressions::array_length)(args)) } @@ -461,6 +368,9 @@ pub fn create_physical_fun( BuiltinScalarFunction::ArrayNdims => { Arc::new(|args| make_scalar_function(array_expressions::array_ndims)(args)) } + BuiltinScalarFunction::ArrayPopFront => Arc::new(|args| { + make_scalar_function(array_expressions::array_pop_front)(args) + }), BuiltinScalarFunction::ArrayPopBack => { Arc::new(|args| make_scalar_function(array_expressions::array_pop_back)(args)) } @@ -500,13 +410,21 @@ pub fn create_physical_fun( BuiltinScalarFunction::ArrayToString => Arc::new(|args| { make_scalar_function(array_expressions::array_to_string)(args) }), + BuiltinScalarFunction::ArrayIntersect => Arc::new(|args| { + make_scalar_function(array_expressions::array_intersect)(args) + }), + BuiltinScalarFunction::Range => { + Arc::new(|args| make_scalar_function(array_expressions::gen_range)(args)) + } BuiltinScalarFunction::Cardinality => { Arc::new(|args| make_scalar_function(array_expressions::cardinality)(args)) } BuiltinScalarFunction::MakeArray => { Arc::new(|args| make_scalar_function(array_expressions::make_array)(args)) } - + BuiltinScalarFunction::ArrayUnion => { + Arc::new(|args| make_scalar_function(array_expressions::array_union)(args)) + } // struct functions BuiltinScalarFunction::Struct => Arc::new(struct_expressions::struct_expr), @@ -593,6 +511,24 @@ pub fn create_physical_fun( execution_props.query_execution_start_time, )) } + BuiltinScalarFunction::ToTimestamp => { + Arc::new(datetime_expressions::to_timestamp_invoke) + } + BuiltinScalarFunction::ToTimestampMillis => { + Arc::new(datetime_expressions::to_timestamp_millis_invoke) + } + BuiltinScalarFunction::ToTimestampMicros => { + Arc::new(datetime_expressions::to_timestamp_micros_invoke) + } + BuiltinScalarFunction::ToTimestampNanos => { + Arc::new(datetime_expressions::to_timestamp_nanos_invoke) + } + BuiltinScalarFunction::ToTimestampSeconds => { + Arc::new(datetime_expressions::to_timestamp_seconds_invoke) + } + BuiltinScalarFunction::FromUnixtime => { + Arc::new(datetime_expressions::from_unixtime_invoke) + } BuiltinScalarFunction::InitCap => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function(string_expressions::initcap::)(args) @@ -895,21 +831,97 @@ pub fn create_physical_fun( }), BuiltinScalarFunction::Upper => Arc::new(string_expressions::upper), BuiltinScalarFunction::Uuid => Arc::new(string_expressions::uuid), - _ => { - return internal_err!( - "create_physical_fun: Unsupported scalar function {fun:?}" - ); + BuiltinScalarFunction::ArrowTypeof => Arc::new(move |args| { + if args.len() != 1 { + return internal_err!( + "arrow_typeof function requires 1 arguments, got {}", + args.len() + ); + } + + let input_data_type = args[0].data_type(); + Ok(ColumnarValue::Scalar(ScalarValue::from(format!( + "{input_data_type}" + )))) + }), + BuiltinScalarFunction::OverLay => Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::overlay::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::overlay::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {other:?} for function overlay", + ))), + }), + BuiltinScalarFunction::Levenshtein => { + Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::levenshtein::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::levenshtein::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {other:?} for function levenshtein", + ))), + }) } + BuiltinScalarFunction::SubstrIndex => { + Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + substr_index, + i32, + "substr_index" + ); + make_scalar_function(func)(args) + } + DataType::LargeUtf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + substr_index, + i64, + "substr_index" + ); + make_scalar_function(func)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {other:?} for function substr_index", + ))), + }) + } + BuiltinScalarFunction::FindInSet => Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + find_in_set, + Int32Type, + "find_in_set" + ); + make_scalar_function(func)(args) + } + DataType::LargeUtf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + find_in_set, + Int64Type, + "find_in_set" + ); + make_scalar_function(func)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {other:?} for function find_in_set", + ))), + }), }) } -/// Monotonicity of the `ScalarFunctionExpr` with respect to its arguments. -/// Each element of this vector corresponds to an argument and indicates whether -/// the function's behavior is monotonic, or non-monotonic/unknown for that argument, namely: -/// - `None` signifies unknown monotonicity or non-monotonicity. -/// - `Some(true)` indicates that the function is monotonically increasing w.r.t. the argument in question. -/// - Some(false) indicates that the function is monotonically decreasing w.r.t. the argument in question. -pub type FuncMonotonicity = Vec>; +#[deprecated( + since = "32.0.0", + note = "Moved to `expr` crate. Please use `BuiltinScalarFunction::monotonicity()` instead" +)] +pub fn get_func_monotonicity(fun: &BuiltinScalarFunction) -> Option { + fun.monotonicity() +} /// Determines a [`ScalarFunctionExpr`]'s monotonicity for the given arguments /// and the function's behavior depending on its arguments. @@ -964,47 +976,6 @@ fn func_order_in_one_dimension( } } -/// This function specifies monotonicity behaviors for built-in scalar functions. -/// The list can be extended, only mathematical and datetime functions are -/// considered for the initial implementation of this feature. -pub fn get_func_monotonicity(fun: &BuiltinScalarFunction) -> Option { - if matches!( - fun, - BuiltinScalarFunction::Atan - | BuiltinScalarFunction::Acosh - | BuiltinScalarFunction::Asinh - | BuiltinScalarFunction::Atanh - | BuiltinScalarFunction::Ceil - | BuiltinScalarFunction::Degrees - | BuiltinScalarFunction::Exp - | BuiltinScalarFunction::Factorial - | BuiltinScalarFunction::Floor - | BuiltinScalarFunction::Ln - | BuiltinScalarFunction::Log10 - | BuiltinScalarFunction::Log2 - | BuiltinScalarFunction::Radians - | BuiltinScalarFunction::Round - | BuiltinScalarFunction::Signum - | BuiltinScalarFunction::Sinh - | BuiltinScalarFunction::Sqrt - | BuiltinScalarFunction::Cbrt - | BuiltinScalarFunction::Tanh - | BuiltinScalarFunction::Trunc - | BuiltinScalarFunction::Pi - ) { - Some(vec![Some(true)]) - } else if matches!( - fun, - BuiltinScalarFunction::DateTrunc | BuiltinScalarFunction::DateBin - ) { - Some(vec![None, Some(true)]) - } else if *fun == BuiltinScalarFunction::Log { - Some(vec![Some(true), Some(false)]) - } else { - None - } -} - #[cfg(test)] mod tests { use super::*; @@ -1051,7 +1022,7 @@ mod tests { match expected { Ok(expected) => { let result = expr.evaluate(&batch)?; - let result = result.into_array(batch.num_rows()); + let result = result.into_array(batch.num_rows()).expect("Failed to convert to array"); let result = result.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); // value is correct @@ -2965,13 +2936,8 @@ mod tests { "Builtin scalar function {fun} does not support empty arguments" ); } - Err(DataFusionError::Plan(err)) => { - if !err - .contains("No function matches the given name and argument types") - { - return plan_err!( - "Builtin scalar function {fun} didn't got the right error message with empty arguments"); - } + Err(DataFusionError::Plan(_)) => { + // Continue the loop } Err(..) => { return internal_err!( @@ -3025,7 +2991,10 @@ mod tests { // evaluate works let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); // downcast works let result = as_list_array(&result)?; @@ -3064,7 +3033,10 @@ mod tests { // evaluate works let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); // downcast works let result = as_list_array(&result)?; @@ -3136,8 +3108,11 @@ mod tests { let adapter_func = make_scalar_function(dummy_function); let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); - let array_arg = - ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5)); + let array_arg = ColumnarValue::Array( + ScalarValue::Int64(Some(1)) + .to_array_of_size(5) + .expect("Failed to convert to array of size"), + ); let result = unpack_uint64_array(adapter_func(&[array_arg, scalar_arg]))?; assert_eq!(result, vec![5, 5]); @@ -3149,8 +3124,11 @@ mod tests { let adapter_func = make_scalar_function_with_hints(dummy_function, vec![]); let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); - let array_arg = - ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5)); + let array_arg = ColumnarValue::Array( + ScalarValue::Int64(Some(1)) + .to_array_of_size(5) + .expect("Failed to convert to array of size"), + ); let result = unpack_uint64_array(adapter_func(&[array_arg, scalar_arg]))?; assert_eq!(result, vec![5, 5]); @@ -3165,8 +3143,11 @@ mod tests { ); let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); - let array_arg = - ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5)); + let array_arg = ColumnarValue::Array( + ScalarValue::Int64(Some(1)) + .to_array_of_size(5) + .expect("Failed to convert to array of size"), + ); let result = unpack_uint64_array(adapter_func(&[array_arg, scalar_arg]))?; assert_eq!(result, vec![5, 1]); @@ -3175,8 +3156,11 @@ mod tests { #[test] fn test_make_scalar_function_with_hints_on_arrays() -> Result<()> { - let array_arg = - ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5)); + let array_arg = ColumnarValue::Array( + ScalarValue::Int64(Some(1)) + .to_array_of_size(5) + .expect("Failed to convert to array of size"), + ); let adapter_func = make_scalar_function_with_hints( dummy_function, vec![Hint::Pad, Hint::AcceptsSingular], @@ -3196,8 +3180,11 @@ mod tests { ); let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); - let array_arg = - ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5)); + let array_arg = ColumnarValue::Array( + ScalarValue::Int64(Some(1)) + .to_array_of_size(5) + .expect("Failed to convert to array of size"), + ); let result = unpack_uint64_array(adapter_func(&[ array_arg, scalar_arg.clone(), @@ -3216,8 +3203,11 @@ mod tests { ); let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); - let array_arg = - ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5)); + let array_arg = ColumnarValue::Array( + ScalarValue::Int64(Some(1)) + .to_array_of_size(5) + .expect("Failed to convert to array of size"), + ); let result = unpack_uint64_array(adapter_func(&[ array_arg.clone(), scalar_arg.clone(), @@ -3244,8 +3234,11 @@ mod tests { ); let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); - let array_arg = - ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5)); + let array_arg = ColumnarValue::Array( + ScalarValue::Int64(Some(1)) + .to_array_of_size(5) + .expect("Failed to convert to array of size"), + ); let result = unpack_uint64_array(adapter_func(&[array_arg, scalar_arg]))?; assert_eq!(result, vec![5, 1]); diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index 0a090636dc4b..5064ad8d5c48 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -24,15 +24,13 @@ use std::sync::Arc; use super::utils::{ convert_duration_type_to_interval, convert_interval_type_to_duration, get_inverse_op, }; -use super::IntervalBound; use crate::expressions::Literal; -use crate::intervals::interval_aritmetic::{apply_operator, Interval}; use crate::utils::{build_dag, ExprTreeNode}; use crate::PhysicalExpr; -use arrow_schema::DataType; -use datafusion_common::{DataFusionError, Result, ScalarValue}; -use datafusion_expr::type_coercion::binary::get_result_type; +use arrow_schema::{DataType, Schema}; +use datafusion_common::{internal_err, DataFusionError, Result}; +use datafusion_expr::interval_arithmetic::{apply_operator, satisfy_greater, Interval}; use datafusion_expr::Operator; use petgraph::graph::NodeIndex; @@ -148,7 +146,7 @@ pub enum PropagationResult { } /// This is a node in the DAEG; it encapsulates a reference to the actual -/// [PhysicalExpr] as well as an interval containing expression bounds. +/// [`PhysicalExpr`] as well as an interval containing expression bounds. #[derive(Clone, Debug)] pub struct ExprIntervalGraphNode { expr: Arc, @@ -163,11 +161,9 @@ impl Display for ExprIntervalGraphNode { impl ExprIntervalGraphNode { /// Constructs a new DAEG node with an [-∞, ∞] range. - pub fn new(expr: Arc) -> Self { - ExprIntervalGraphNode { - expr, - interval: Interval::default(), - } + pub fn new_unbounded(expr: Arc, dt: &DataType) -> Result { + Interval::make_unbounded(dt) + .map(|interval| ExprIntervalGraphNode { expr, interval }) } /// Constructs a new DAEG node with the given range. @@ -180,26 +176,24 @@ impl ExprIntervalGraphNode { &self.interval } - /// This function creates a DAEG node from Datafusion's [ExprTreeNode] + /// This function creates a DAEG node from Datafusion's [`ExprTreeNode`] /// object. Literals are created with definite, singleton intervals while /// any other expression starts with an indefinite interval ([-∞, ∞]). - pub fn make_node(node: &ExprTreeNode) -> ExprIntervalGraphNode { + pub fn make_node(node: &ExprTreeNode, schema: &Schema) -> Result { let expr = node.expression().clone(); if let Some(literal) = expr.as_any().downcast_ref::() { let value = literal.value(); - let interval = Interval::new( - IntervalBound::new_closed(value.clone()), - IntervalBound::new_closed(value.clone()), - ); - ExprIntervalGraphNode::new_with_interval(expr, interval) + Interval::try_new(value.clone(), value.clone()) + .map(|interval| Self::new_with_interval(expr, interval)) } else { - ExprIntervalGraphNode::new(expr) + expr.data_type(schema) + .and_then(|dt| Self::new_unbounded(expr, &dt)) } } } impl PartialEq for ExprIntervalGraphNode { - fn eq(&self, other: &ExprIntervalGraphNode) -> bool { + fn eq(&self, other: &Self) -> bool { self.expr.eq(&other.expr) } } @@ -216,16 +210,23 @@ impl PartialEq for ExprIntervalGraphNode { /// - For minus operation, specifically, we would first do /// - [xL, xU] <- ([yL, yU] + [pL, pU]) ∩ [xL, xU], and then /// - [yL, yU] <- ([xL, xU] - [pL, pU]) ∩ [yL, yU]. +/// - For multiplication operation, specifically, we would first do +/// - [xL, xU] <- ([pL, pU] / [yL, yU]) ∩ [xL, xU], and then +/// - [yL, yU] <- ([pL, pU] / [xL, xU]) ∩ [yL, yU]. +/// - For division operation, specifically, we would first do +/// - [xL, xU] <- ([yL, yU] * [pL, pU]) ∩ [xL, xU], and then +/// - [yL, yU] <- ([xL, xU] / [pL, pU]) ∩ [yL, yU]. pub fn propagate_arithmetic( op: &Operator, parent: &Interval, left_child: &Interval, right_child: &Interval, -) -> Result<(Option, Option)> { - let inverse_op = get_inverse_op(*op); - match (left_child.get_datatype()?, right_child.get_datatype()?) { - // If we have a child whose type is a time interval (i.e. DataType::Interval), we need special handling - // since timestamp differencing results in a Duration type. +) -> Result> { + let inverse_op = get_inverse_op(*op)?; + match (left_child.data_type(), right_child.data_type()) { + // If we have a child whose type is a time interval (i.e. DataType::Interval), + // we need special handling since timestamp differencing results in a + // Duration type. (DataType::Timestamp(..), DataType::Interval(_)) => { propagate_time_interval_at_right( left_child, @@ -250,87 +251,109 @@ pub fn propagate_arithmetic( .intersect(left_child)? { // Left is feasible: - Some(value) => { + Some(value) => Ok( // Propagate to the right using the new left. - let right = - propagate_right(&value, parent, right_child, op, &inverse_op)?; - - // Return intervals for both children: - Ok((Some(value), right)) - } + propagate_right(&value, parent, right_child, op, &inverse_op)? + .map(|right| (value, right)), + ), // If the left child is infeasible, short-circuit. - None => Ok((None, None)), + None => Ok(None), } } } } -/// This function provides a target parent interval for comparison operators. -/// If we have expression > 0, expression must have the range (0, ∞). -/// If we have expression >= 0, expression must have the range [0, ∞). -/// If we have expression < 0, expression must have the range (-∞, 0). -/// If we have expression <= 0, expression must have the range (-∞, 0]. -fn comparison_operator_target( - left_datatype: &DataType, - op: &Operator, - right_datatype: &DataType, -) -> Result { - let datatype = get_result_type(left_datatype, &Operator::Minus, right_datatype)?; - let unbounded = IntervalBound::make_unbounded(&datatype)?; - let zero = ScalarValue::new_zero(&datatype)?; - Ok(match *op { - Operator::GtEq => Interval::new(IntervalBound::new_closed(zero), unbounded), - Operator::Gt => Interval::new(IntervalBound::new_open(zero), unbounded), - Operator::LtEq => Interval::new(unbounded, IntervalBound::new_closed(zero)), - Operator::Lt => Interval::new(unbounded, IntervalBound::new_open(zero)), - Operator::Eq => Interval::new( - IntervalBound::new_closed(zero.clone()), - IntervalBound::new_closed(zero), - ), - _ => unreachable!(), - }) -} - -/// This function propagates constraints arising from comparison operators. -/// The main idea is that we can analyze an inequality like x > y through the -/// equivalent inequality x - y > 0. Assuming that x and y has ranges [xL, xU] -/// and [yL, yU], we simply apply constraint propagation across [xL, xU], -/// [yL, yH] and [0, ∞]. Specifically, we would first do -/// - [xL, xU] <- ([yL, yU] + [0, ∞]) ∩ [xL, xU], and then -/// - [yL, yU] <- ([xL, xU] - [0, ∞]) ∩ [yL, yU]. +/// This function refines intervals `left_child` and `right_child` by applying +/// comparison propagation through `parent` via operation. The main idea is +/// that we can shrink ranges of variables x and y using parent interval p. +/// Two intervals can be ordered in 6 ways for a Gt `>` operator: +/// ```text +/// (1): Infeasible, short-circuit +/// left: | ================ | +/// right: | ======================== | +/// +/// (2): Update both interval +/// left: | ====================== | +/// right: | ====================== | +/// | +/// V +/// left: | ======= | +/// right: | ======= | +/// +/// (3): Update left interval +/// left: | ============================== | +/// right: | ========== | +/// | +/// V +/// left: | ===================== | +/// right: | ========== | +/// +/// (4): Update right interval +/// left: | ========== | +/// right: | =========================== | +/// | +/// V +/// left: | ========== | +/// right | ================== | +/// +/// (5): No change +/// left: | ============================ | +/// right: | =================== | +/// +/// (6): No change +/// left: | ==================== | +/// right: | =============== | +/// +/// -inf --------------------------------------------------------------- +inf +/// ``` pub fn propagate_comparison( op: &Operator, + parent: &Interval, left_child: &Interval, right_child: &Interval, -) -> Result<(Option, Option)> { - let left_type = left_child.get_datatype()?; - let right_type = right_child.get_datatype()?; - let parent = comparison_operator_target(&left_type, op, &right_type)?; - match (&left_type, &right_type) { - // We can not compare a Duration type with a time interval type - // without a reference timestamp unless the latter has a zero month field. - (DataType::Interval(_), DataType::Duration(_)) => { - propagate_comparison_to_time_interval_at_left( - left_child, - &parent, - right_child, - ) +) -> Result> { + if parent == &Interval::CERTAINLY_TRUE { + match op { + Operator::Eq => left_child.intersect(right_child).map(|result| { + result.map(|intersection| (intersection.clone(), intersection)) + }), + Operator::Gt => satisfy_greater(left_child, right_child, true), + Operator::GtEq => satisfy_greater(left_child, right_child, false), + Operator::Lt => satisfy_greater(right_child, left_child, true) + .map(|t| t.map(reverse_tuple)), + Operator::LtEq => satisfy_greater(right_child, left_child, false) + .map(|t| t.map(reverse_tuple)), + _ => internal_err!( + "The operator must be a comparison operator to propagate intervals" + ), } - (DataType::Duration(_), DataType::Interval(_)) => { - propagate_comparison_to_time_interval_at_left( - left_child, - &parent, - right_child, - ) + } else if parent == &Interval::CERTAINLY_FALSE { + match op { + Operator::Eq => { + // TODO: Propagation is not possible until we support interval sets. + Ok(None) + } + Operator::Gt => satisfy_greater(right_child, left_child, false), + Operator::GtEq => satisfy_greater(right_child, left_child, true), + Operator::Lt => satisfy_greater(left_child, right_child, false) + .map(|t| t.map(reverse_tuple)), + Operator::LtEq => satisfy_greater(left_child, right_child, true) + .map(|t| t.map(reverse_tuple)), + _ => internal_err!( + "The operator must be a comparison operator to propagate intervals" + ), } - _ => propagate_arithmetic(&Operator::Minus, &parent, left_child, right_child), + } else { + // Uncertainty cannot change any end-point of the intervals. + Ok(None) } } impl ExprIntervalGraph { - pub fn try_new(expr: Arc) -> Result { + pub fn try_new(expr: Arc, schema: &Schema) -> Result { // Build the full graph: - let (root, graph) = build_dag(expr, &ExprIntervalGraphNode::make_node)?; + let (root, graph) = + build_dag(expr, &|node| ExprIntervalGraphNode::make_node(node, schema))?; Ok(Self { graph, root }) } @@ -383,7 +406,7 @@ impl ExprIntervalGraph { // // ``` - /// This function associates stable node indices with [PhysicalExpr]s so + /// This function associates stable node indices with [`PhysicalExpr`]s so /// that we can match `Arc` and NodeIndex objects during /// membership tests. pub fn gather_node_indices( @@ -437,6 +460,33 @@ impl ExprIntervalGraph { nodes } + /// Updates intervals for all expressions in the DAEG by successive + /// bottom-up and top-down traversals. + pub fn update_ranges( + &mut self, + leaf_bounds: &mut [(usize, Interval)], + given_range: Interval, + ) -> Result { + self.assign_intervals(leaf_bounds); + let bounds = self.evaluate_bounds()?; + // There are three possible cases to consider: + // (1) given_range ⊇ bounds => Nothing to propagate + // (2) ∅ ⊂ (given_range ∩ bounds) ⊂ bounds => Can propagate + // (3) Disjoint sets => Infeasible + if given_range.contains(bounds)? == Interval::CERTAINLY_TRUE { + // First case: + Ok(PropagationResult::CannotPropagate) + } else if bounds.contains(&given_range)? != Interval::CERTAINLY_FALSE { + // Second case: + let result = self.propagate_constraints(given_range); + self.update_intervals(leaf_bounds); + result + } else { + // Third case: + Ok(PropagationResult::Infeasible) + } + } + /// This function assigns given ranges to expressions in the DAEG. /// The argument `assignments` associates indices of sought expressions /// with their corresponding new ranges. @@ -466,34 +516,43 @@ impl ExprIntervalGraph { /// # Examples /// /// ``` - /// use std::sync::Arc; - /// use datafusion_common::ScalarValue; - /// use datafusion_expr::Operator; - /// use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal}; - /// use datafusion_physical_expr::intervals::{Interval, IntervalBound, ExprIntervalGraph}; - /// use datafusion_physical_expr::PhysicalExpr; - /// let expr = Arc::new(BinaryExpr::new( - /// Arc::new(Column::new("gnz", 0)), - /// Operator::Plus, - /// Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), - /// )); - /// let mut graph = ExprIntervalGraph::try_new(expr).unwrap(); - /// // Do it once, while constructing. - /// let node_indices = graph + /// use arrow::datatypes::DataType; + /// use arrow::datatypes::Field; + /// use arrow::datatypes::Schema; + /// use datafusion_common::ScalarValue; + /// use datafusion_expr::interval_arithmetic::Interval; + /// use datafusion_expr::Operator; + /// use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal}; + /// use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph; + /// use datafusion_physical_expr::PhysicalExpr; + /// use std::sync::Arc; + /// + /// let expr = Arc::new(BinaryExpr::new( + /// Arc::new(Column::new("gnz", 0)), + /// Operator::Plus, + /// Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + /// )); + /// + /// let schema = Schema::new(vec![Field::new("gnz".to_string(), DataType::Int32, true)]); + /// + /// let mut graph = ExprIntervalGraph::try_new(expr, &schema).unwrap(); + /// // Do it once, while constructing. + /// let node_indices = graph /// .gather_node_indices(&[Arc::new(Column::new("gnz", 0))]); - /// let left_index = node_indices.get(0).unwrap().1; - /// // Provide intervals for leaf variables (here, there is only one). - /// let intervals = vec![( + /// let left_index = node_indices.get(0).unwrap().1; + /// + /// // Provide intervals for leaf variables (here, there is only one). + /// let intervals = vec![( /// left_index, - /// Interval::make(Some(10), Some(20), (true, true)), - /// )]; - /// // Evaluate bounds for the composite expression: - /// graph.assign_intervals(&intervals); - /// assert_eq!( - /// graph.evaluate_bounds().unwrap(), - /// &Interval::make(Some(20), Some(30), (true, true)), - /// ) + /// Interval::make(Some(10), Some(20)).unwrap(), + /// )]; /// + /// // Evaluate bounds for the composite expression: + /// graph.assign_intervals(&intervals); + /// assert_eq!( + /// graph.evaluate_bounds().unwrap(), + /// &Interval::make(Some(20), Some(30)).unwrap(), + /// ) /// ``` pub fn evaluate_bounds(&mut self) -> Result<&Interval> { let mut dfs = DfsPostOrder::new(&self.graph, self.root); @@ -505,7 +564,7 @@ impl ExprIntervalGraph { // If the current expression is a leaf, its interval should already // be set externally, just continue with the evaluation procedure: if !children_intervals.is_empty() { - // Reverse to align with [PhysicalExpr]'s children: + // Reverse to align with `PhysicalExpr`'s children: children_intervals.reverse(); self.graph[node].interval = self.graph[node].expr.evaluate_bounds(&children_intervals)?; @@ -516,8 +575,19 @@ impl ExprIntervalGraph { /// Updates/shrinks bounds for leaf expressions using interval arithmetic /// via a top-down traversal. - fn propagate_constraints(&mut self) -> Result { + fn propagate_constraints( + &mut self, + given_range: Interval, + ) -> Result { let mut bfs = Bfs::new(&self.graph, self.root); + + // Adjust the root node with the given range: + if let Some(interval) = self.graph[self.root].interval.intersect(given_range)? { + self.graph[self.root].interval = interval; + } else { + return Ok(PropagationResult::Infeasible); + } + while let Some(node) = bfs.next(&self.graph) { let neighbors = self.graph.neighbors_directed(node, Outgoing); let mut children = neighbors.collect::>(); @@ -526,7 +596,7 @@ impl ExprIntervalGraph { if children.is_empty() { continue; } - // Reverse to align with [PhysicalExpr]'s children: + // Reverse to align with `PhysicalExpr`'s children: children.reverse(); let children_intervals = children .iter() @@ -536,163 +606,132 @@ impl ExprIntervalGraph { let propagated_intervals = self.graph[node] .expr .propagate_constraints(node_interval, &children_intervals)?; - for (child, interval) in children.into_iter().zip(propagated_intervals) { - if let Some(interval) = interval { + if let Some(propagated_intervals) = propagated_intervals { + for (child, interval) in children.into_iter().zip(propagated_intervals) { self.graph[child].interval = interval; - } else { - // The constraint is infeasible, report: - return Ok(PropagationResult::Infeasible); } + } else { + // The constraint is infeasible, report: + return Ok(PropagationResult::Infeasible); } } Ok(PropagationResult::Success) } - /// Updates intervals for all expressions in the DAEG by successive - /// bottom-up and top-down traversals. - pub fn update_ranges( - &mut self, - leaf_bounds: &mut [(usize, Interval)], - ) -> Result { - self.assign_intervals(leaf_bounds); - let bounds = self.evaluate_bounds()?; - if bounds == &Interval::CERTAINLY_FALSE { - Ok(PropagationResult::Infeasible) - } else if bounds == &Interval::UNCERTAIN { - let result = self.propagate_constraints(); - self.update_intervals(leaf_bounds); - result - } else { - Ok(PropagationResult::CannotPropagate) - } - } - /// Returns the interval associated with the node at the given `index`. pub fn get_interval(&self, index: usize) -> Interval { self.graph[NodeIndex::new(index)].interval.clone() } } -/// During the propagation of [`Interval`] values on an [`ExprIntervalGraph`], if there exists a `timestamp - timestamp` -/// operation, the result would be of type `Duration`. However, we may encounter a situation where a time interval -/// is involved in an arithmetic operation with a `Duration` type. This function offers special handling for such cases, -/// where the time interval resides on the left side of the operation. +/// This is a subfunction of the `propagate_arithmetic` function that propagates to the right child. +fn propagate_right( + left: &Interval, + parent: &Interval, + right: &Interval, + op: &Operator, + inverse_op: &Operator, +) -> Result> { + match op { + Operator::Minus => apply_operator(op, left, parent), + Operator::Plus => apply_operator(inverse_op, parent, left), + Operator::Divide => apply_operator(op, left, parent), + Operator::Multiply => apply_operator(inverse_op, parent, left), + _ => internal_err!("Interval arithmetic does not support the operator {}", op), + }? + .intersect(right) +} + +/// During the propagation of [`Interval`] values on an [`ExprIntervalGraph`], +/// if there exists a `timestamp - timestamp` operation, the result would be +/// of type `Duration`. However, we may encounter a situation where a time interval +/// is involved in an arithmetic operation with a `Duration` type. This function +/// offers special handling for such cases, where the time interval resides on +/// the left side of the operation. fn propagate_time_interval_at_left( left_child: &Interval, right_child: &Interval, parent: &Interval, op: &Operator, inverse_op: &Operator, -) -> Result<(Option, Option)> { +) -> Result> { // We check if the child's time interval(s) has a non-zero month or day field(s). // If so, we return it as is without propagating. Otherwise, we first convert - // the time intervals to the Duration type, then propagate, and then convert the bounds to time intervals again. - if let Some(duration) = convert_interval_type_to_duration(left_child) { + // the time intervals to the `Duration` type, then propagate, and then convert + // the bounds to time intervals again. + let result = if let Some(duration) = convert_interval_type_to_duration(left_child) { match apply_operator(inverse_op, parent, right_child)?.intersect(duration)? { Some(value) => { + let left = convert_duration_type_to_interval(&value); let right = propagate_right(&value, parent, right_child, op, inverse_op)?; - let new_interval = convert_duration_type_to_interval(&value); - Ok((new_interval, right)) + match (left, right) { + (Some(left), Some(right)) => Some((left, right)), + _ => None, + } } - None => Ok((None, None)), + None => None, } } else { - let right = propagate_right(left_child, parent, right_child, op, inverse_op)?; - Ok((Some(left_child.clone()), right)) - } + propagate_right(left_child, parent, right_child, op, inverse_op)? + .map(|right| (left_child.clone(), right)) + }; + Ok(result) } -/// During the propagation of [`Interval`] values on an [`ExprIntervalGraph`], if there exists a `timestamp - timestamp` -/// operation, the result would be of type `Duration`. However, we may encounter a situation where a time interval -/// is involved in an arithmetic operation with a `Duration` type. This function offers special handling for such cases, -/// where the time interval resides on the right side of the operation. +/// During the propagation of [`Interval`] values on an [`ExprIntervalGraph`], +/// if there exists a `timestamp - timestamp` operation, the result would be +/// of type `Duration`. However, we may encounter a situation where a time interval +/// is involved in an arithmetic operation with a `Duration` type. This function +/// offers special handling for such cases, where the time interval resides on +/// the right side of the operation. fn propagate_time_interval_at_right( left_child: &Interval, right_child: &Interval, parent: &Interval, op: &Operator, inverse_op: &Operator, -) -> Result<(Option, Option)> { +) -> Result> { // We check if the child's time interval(s) has a non-zero month or day field(s). // If so, we return it as is without propagating. Otherwise, we first convert - // the time intervals to the Duration type, then propagate, and then convert the bounds to time intervals again. - if let Some(duration) = convert_interval_type_to_duration(right_child) { + // the time intervals to the `Duration` type, then propagate, and then convert + // the bounds to time intervals again. + let result = if let Some(duration) = convert_interval_type_to_duration(right_child) { match apply_operator(inverse_op, parent, &duration)?.intersect(left_child)? { Some(value) => { - let right = - propagate_right(left_child, parent, &duration, op, inverse_op)?; - let right = - right.and_then(|right| convert_duration_type_to_interval(&right)); - Ok((Some(value), right)) + propagate_right(left_child, parent, &duration, op, inverse_op)? + .and_then(|right| convert_duration_type_to_interval(&right)) + .map(|right| (value, right)) } - None => Ok((None, None)), + None => None, } } else { - match apply_operator(inverse_op, parent, right_child)?.intersect(left_child)? { - Some(value) => Ok((Some(value), Some(right_child.clone()))), - None => Ok((None, None)), - } - } -} - -/// This is a subfunction of the `propagate_arithmetic` function that propagates to the right child. -fn propagate_right( - left: &Interval, - parent: &Interval, - right: &Interval, - op: &Operator, - inverse_op: &Operator, -) -> Result> { - match op { - Operator::Minus => apply_operator(op, left, parent), - Operator::Plus => apply_operator(inverse_op, parent, left), - _ => unreachable!(), - }? - .intersect(right) -} - -/// Converts the `time interval` (as the left child) to duration, then performs the propagation rule for comparison operators. -pub fn propagate_comparison_to_time_interval_at_left( - left_child: &Interval, - parent: &Interval, - right_child: &Interval, -) -> Result<(Option, Option)> { - if let Some(converted) = convert_interval_type_to_duration(left_child) { - propagate_arithmetic(&Operator::Minus, parent, &converted, right_child) - } else { - Err(DataFusionError::Internal( - "Interval type has a non-zero month field, cannot compare with a Duration type".to_string(), - )) - } + apply_operator(inverse_op, parent, right_child)? + .intersect(left_child)? + .map(|value| (value, right_child.clone())) + }; + Ok(result) } -/// Converts the `time interval` (as the right child) to duration, then performs the propagation rule for comparison operators. -pub fn propagate_comparison_to_time_interval_at_right( - left_child: &Interval, - parent: &Interval, - right_child: &Interval, -) -> Result<(Option, Option)> { - if let Some(converted) = convert_interval_type_to_duration(right_child) { - propagate_arithmetic(&Operator::Minus, parent, left_child, &converted) - } else { - Err(DataFusionError::Internal( - "Interval type has a non-zero month field, cannot compare with a Duration type".to_string(), - )) - } +fn reverse_tuple((first, second): (T, U)) -> (U, T) { + (second, first) } #[cfg(test)] mod tests { use super::*; - use itertools::Itertools; - use crate::expressions::{BinaryExpr, Column}; use crate::intervals::test_utils::gen_conjunctive_numerical_expr; + + use arrow::datatypes::TimeUnit; + use arrow_schema::{DataType, Field}; use datafusion_common::ScalarValue; + + use itertools::Itertools; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use rstest::*; + #[allow(clippy::too_many_arguments)] fn experiment( expr: Arc, exprs_with_interval: (Arc, Arc), @@ -701,6 +740,7 @@ mod tests { left_expected: Interval, right_expected: Interval, result: PropagationResult, + schema: &Schema, ) -> Result<()> { let col_stats = vec![ (exprs_with_interval.0.clone(), left_interval), @@ -710,7 +750,7 @@ mod tests { (exprs_with_interval.0.clone(), left_expected), (exprs_with_interval.1.clone(), right_expected), ]; - let mut graph = ExprIntervalGraph::try_new(expr)?; + let mut graph = ExprIntervalGraph::try_new(expr, schema)?; let expr_indexes = graph .gather_node_indices(&col_stats.iter().map(|(e, _)| e.clone()).collect_vec()); @@ -725,14 +765,37 @@ mod tests { .map(|((_, interval), (_, index))| (*index, interval.clone())) .collect_vec(); - let exp_result = graph.update_ranges(&mut col_stat_nodes[..])?; + let exp_result = + graph.update_ranges(&mut col_stat_nodes[..], Interval::CERTAINLY_TRUE)?; assert_eq!(exp_result, result); col_stat_nodes.iter().zip(expected_nodes.iter()).for_each( |((_, calculated_interval_node), (_, expected))| { // NOTE: These randomized tests only check for conservative containment, // not openness/closedness of endpoints. - assert!(calculated_interval_node.lower.value <= expected.lower.value); - assert!(calculated_interval_node.upper.value >= expected.upper.value); + + // Calculated bounds are relaxed by 1 to cover all strict and + // and non-strict comparison cases since we have only closed bounds. + let one = ScalarValue::new_one(&expected.data_type()).unwrap(); + assert!( + calculated_interval_node.lower() + <= &expected.lower().add(&one).unwrap(), + "{}", + format!( + "Calculated {} must be less than or equal {}", + calculated_interval_node.lower(), + expected.lower() + ) + ); + assert!( + calculated_interval_node.upper() + >= &expected.upper().sub(&one).unwrap(), + "{}", + format!( + "Calculated {} must be greater than or equal {}", + calculated_interval_node.upper(), + expected.upper() + ) + ); }, ); Ok(()) @@ -772,12 +835,24 @@ mod tests { experiment( expr, - (left_col, right_col), - Interval::make(left_given.0, left_given.1, (true, true)), - Interval::make(right_given.0, right_given.1, (true, true)), - Interval::make(left_expected.0, left_expected.1, (true, true)), - Interval::make(right_expected.0, right_expected.1, (true, true)), + (left_col.clone(), right_col.clone()), + Interval::make(left_given.0, left_given.1).unwrap(), + Interval::make(right_given.0, right_given.1).unwrap(), + Interval::make(left_expected.0, left_expected.1).unwrap(), + Interval::make(right_expected.0, right_expected.1).unwrap(), PropagationResult::Success, + &Schema::new(vec![ + Field::new( + left_col.as_any().downcast_ref::().unwrap().name(), + DataType::$SCALAR, + true, + ), + Field::new( + right_col.as_any().downcast_ref::().unwrap().name(), + DataType::$SCALAR, + true, + ), + ]), ) } }; @@ -801,12 +876,24 @@ mod tests { let expr = Arc::new(BinaryExpr::new(left_and_1, Operator::Gt, right_col.clone())); experiment( expr, - (left_col, right_col), - Interval::make(Some(10), Some(20), (true, true)), - Interval::make(Some(100), None, (true, true)), - Interval::make(Some(10), Some(20), (true, true)), - Interval::make(Some(100), None, (true, true)), + (left_col.clone(), right_col.clone()), + Interval::make(Some(10_i32), Some(20_i32))?, + Interval::make(Some(100), None)?, + Interval::make(Some(10), Some(20))?, + Interval::make(Some(100), None)?, PropagationResult::Infeasible, + &Schema::new(vec![ + Field::new( + left_col.as_any().downcast_ref::().unwrap().name(), + DataType::Int32, + true, + ), + Field::new( + right_col.as_any().downcast_ref::().unwrap().name(), + DataType::Int32, + true, + ), + ]), ) } @@ -1111,7 +1198,14 @@ mod tests { Arc::new(Column::new("b", 1)), )); let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr)); - let mut graph = ExprIntervalGraph::try_new(expr).unwrap(); + let mut graph = ExprIntervalGraph::try_new( + expr, + &Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ]), + ) + .unwrap(); // Define a test leaf node. let leaf_node = Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 0)), @@ -1150,7 +1244,16 @@ mod tests { Arc::new(Column::new("z", 1)), )); let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr)); - let mut graph = ExprIntervalGraph::try_new(expr).unwrap(); + let mut graph = ExprIntervalGraph::try_new( + expr, + &Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("y", DataType::Int32, true), + Field::new("z", DataType::Int32, true), + ]), + ) + .unwrap(); // Define a test leaf node. let leaf_node = Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 0)), @@ -1189,7 +1292,15 @@ mod tests { Arc::new(Column::new("z", 1)), )); let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr)); - let mut graph = ExprIntervalGraph::try_new(expr).unwrap(); + let mut graph = ExprIntervalGraph::try_new( + expr, + &Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("z", DataType::Int32, true), + ]), + ) + .unwrap(); // Define a test leaf node. let leaf_node = Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 0)), @@ -1212,9 +1323,9 @@ mod tests { fn test_gather_node_indices_cannot_provide() -> Result<()> { // Expression: a@0 + 1 + b@1 > y@0 - z@1 -> provide a@0 + b@1 // TODO: We expect nodes a@0 and b@1 to be pruned, and intervals to be provided from the a@0 + b@1 node. - // However, we do not have an exact node for a@0 + b@1 due to the binary tree structure of the expressions. - // Pruning and interval providing for BinaryExpr expressions are more challenging without exact matches. - // Currently, we only support exact matches for BinaryExprs, but we plan to extend support beyond exact matches in the future. + // However, we do not have an exact node for a@0 + b@1 due to the binary tree structure of the expressions. + // Pruning and interval providing for BinaryExpr expressions are more challenging without exact matches. + // Currently, we only support exact matches for BinaryExprs, but we plan to extend support beyond exact matches in the future. let left_expr = Arc::new(BinaryExpr::new( Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 0)), @@ -1231,7 +1342,16 @@ mod tests { Arc::new(Column::new("z", 1)), )); let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr)); - let mut graph = ExprIntervalGraph::try_new(expr).unwrap(); + let mut graph = ExprIntervalGraph::try_new( + expr, + &Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("y", DataType::Int32, true), + Field::new("z", DataType::Int32, true), + ]), + ) + .unwrap(); // Define a test leaf node. let leaf_node = Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 0)), @@ -1256,80 +1376,51 @@ mod tests { Operator::Plus, Arc::new(Literal::new(ScalarValue::new_interval_mdn(0, 1, 321))), ); - let parent = Interval::new( - IntervalBound::new( - // 15.10.2020 - 10:11:12.000_000_321 AM - ScalarValue::TimestampNanosecond(Some(1_602_756_672_000_000_321), None), - false, - ), - IntervalBound::new( - // 16.10.2020 - 10:11:12.000_000_321 AM - ScalarValue::TimestampNanosecond(Some(1_602_843_072_000_000_321), None), - false, - ), - ); - let left_child = Interval::new( - IntervalBound::new( - // 10.10.2020 - 10:11:12 AM - ScalarValue::TimestampNanosecond(Some(1_602_324_672_000_000_000), None), - false, - ), - IntervalBound::new( - // 20.10.2020 - 10:11:12 AM - ScalarValue::TimestampNanosecond(Some(1_603_188_672_000_000_000), None), - false, - ), - ); - let right_child = Interval::new( - IntervalBound::new( - // 1 day 321 ns - ScalarValue::IntervalMonthDayNano(Some(0x1_0000_0000_0000_0141)), - false, - ), - IntervalBound::new( - // 1 day 321 ns - ScalarValue::IntervalMonthDayNano(Some(0x1_0000_0000_0000_0141)), - false, - ), - ); + let parent = Interval::try_new( + // 15.10.2020 - 10:11:12.000_000_321 AM + ScalarValue::TimestampNanosecond(Some(1_602_756_672_000_000_321), None), + // 16.10.2020 - 10:11:12.000_000_321 AM + ScalarValue::TimestampNanosecond(Some(1_602_843_072_000_000_321), None), + )?; + let left_child = Interval::try_new( + // 10.10.2020 - 10:11:12 AM + ScalarValue::TimestampNanosecond(Some(1_602_324_672_000_000_000), None), + // 20.10.2020 - 10:11:12 AM + ScalarValue::TimestampNanosecond(Some(1_603_188_672_000_000_000), None), + )?; + let right_child = Interval::try_new( + // 1 day 321 ns + ScalarValue::IntervalMonthDayNano(Some(0x1_0000_0000_0000_0141)), + // 1 day 321 ns + ScalarValue::IntervalMonthDayNano(Some(0x1_0000_0000_0000_0141)), + )?; let children = vec![&left_child, &right_child]; - let result = expression.propagate_constraints(&parent, &children)?; + let result = expression + .propagate_constraints(&parent, &children)? + .unwrap(); assert_eq!( - Some(Interval::new( - // 14.10.2020 - 10:11:12 AM - IntervalBound::new( + vec![ + Interval::try_new( + // 14.10.2020 - 10:11:12 AM ScalarValue::TimestampNanosecond( Some(1_602_670_272_000_000_000), None ), - false, - ), - // 15.10.2020 - 10:11:12 AM - IntervalBound::new( + // 15.10.2020 - 10:11:12 AM ScalarValue::TimestampNanosecond( Some(1_602_756_672_000_000_000), None ), - false, - ), - )), - result[0] - ); - assert_eq!( - Some(Interval::new( - // 1 day 321 ns in Duration type - IntervalBound::new( + )?, + Interval::try_new( + // 1 day 321 ns in Duration type ScalarValue::IntervalMonthDayNano(Some(0x1_0000_0000_0000_0141)), - false, - ), - // 1 day 321 ns in Duration type - IntervalBound::new( + // 1 day 321 ns in Duration type ScalarValue::IntervalMonthDayNano(Some(0x1_0000_0000_0000_0141)), - false, - ), - )), - result[1] + )? + ], + result ); Ok(()) @@ -1342,76 +1433,216 @@ mod tests { Operator::Plus, Arc::new(Column::new("ts_column", 0)), ); - let parent = Interval::new( - IntervalBound::new( - // 15.10.2020 - 10:11:12 AM - ScalarValue::TimestampMillisecond(Some(1_602_756_672_000), None), - false, - ), - IntervalBound::new( - // 16.10.2020 - 10:11:12 AM - ScalarValue::TimestampMillisecond(Some(1_602_843_072_000), None), - false, - ), - ); - let right_child = Interval::new( - IntervalBound::new( - // 10.10.2020 - 10:11:12 AM - ScalarValue::TimestampMillisecond(Some(1_602_324_672_000), None), - false, - ), - IntervalBound::new( - // 20.10.2020 - 10:11:12 AM - ScalarValue::TimestampMillisecond(Some(1_603_188_672_000), None), - false, - ), - ); - let left_child = Interval::new( - IntervalBound::new( - // 2 days - ScalarValue::IntervalDayTime(Some(172_800_000)), - false, - ), - IntervalBound::new( - // 10 days - ScalarValue::IntervalDayTime(Some(864_000_000)), - false, - ), - ); + let parent = Interval::try_new( + // 15.10.2020 - 10:11:12 AM + ScalarValue::TimestampMillisecond(Some(1_602_756_672_000), None), + // 16.10.2020 - 10:11:12 AM + ScalarValue::TimestampMillisecond(Some(1_602_843_072_000), None), + )?; + let right_child = Interval::try_new( + // 10.10.2020 - 10:11:12 AM + ScalarValue::TimestampMillisecond(Some(1_602_324_672_000), None), + // 20.10.2020 - 10:11:12 AM + ScalarValue::TimestampMillisecond(Some(1_603_188_672_000), None), + )?; + let left_child = Interval::try_new( + // 2 days + ScalarValue::IntervalDayTime(Some(172_800_000)), + // 10 days + ScalarValue::IntervalDayTime(Some(864_000_000)), + )?; let children = vec![&left_child, &right_child]; - let result = expression.propagate_constraints(&parent, &children)?; + let result = expression + .propagate_constraints(&parent, &children)? + .unwrap(); assert_eq!( - Some(Interval::new( - // 10.10.2020 - 10:11:12 AM - IntervalBound::new( - ScalarValue::TimestampMillisecond(Some(1_602_324_672_000), None), - false, - ), - // 14.10.2020 - 10:11:12 AM - IntervalBound::new( - ScalarValue::TimestampMillisecond(Some(1_602_670_272_000), None), - false, - ) - )), - result[1] - ); - assert_eq!( - Some(Interval::new( - IntervalBound::new( + vec![ + Interval::try_new( // 2 days ScalarValue::IntervalDayTime(Some(172_800_000)), - false, - ), - IntervalBound::new( // 6 days ScalarValue::IntervalDayTime(Some(518_400_000)), - false, - ), - )), - result[0] + )?, + Interval::try_new( + // 10.10.2020 - 10:11:12 AM + ScalarValue::TimestampMillisecond(Some(1_602_324_672_000), None), + // 14.10.2020 - 10:11:12 AM + ScalarValue::TimestampMillisecond(Some(1_602_670_272_000), None), + )? + ], + result + ); + + Ok(()) + } + + #[test] + fn test_propagate_comparison() -> Result<()> { + // In the examples below: + // `left` is unbounded: [?, ?], + // `right` is known to be [1000,1000] + // so `left` < `right` results in no new knowledge of `right` but knowing that `left` is now < 1000:` [?, 999] + let left = Interval::make_unbounded(&DataType::Int64)?; + let right = Interval::make(Some(1000_i64), Some(1000_i64))?; + assert_eq!( + (Some(( + Interval::make(None, Some(999_i64))?, + Interval::make(Some(1000_i64), Some(1000_i64))?, + ))), + propagate_comparison( + &Operator::Lt, + &Interval::CERTAINLY_TRUE, + &left, + &right + )? ); + let left = + Interval::make_unbounded(&DataType::Timestamp(TimeUnit::Nanosecond, None))?; + let right = Interval::try_new( + ScalarValue::TimestampNanosecond(Some(1000), None), + ScalarValue::TimestampNanosecond(Some(1000), None), + )?; + assert_eq!( + (Some(( + Interval::try_new( + ScalarValue::try_from(&DataType::Timestamp( + TimeUnit::Nanosecond, + None + )) + .unwrap(), + ScalarValue::TimestampNanosecond(Some(999), None), + )?, + Interval::try_new( + ScalarValue::TimestampNanosecond(Some(1000), None), + ScalarValue::TimestampNanosecond(Some(1000), None), + )? + ))), + propagate_comparison( + &Operator::Lt, + &Interval::CERTAINLY_TRUE, + &left, + &right + )? + ); + + let left = Interval::make_unbounded(&DataType::Timestamp( + TimeUnit::Nanosecond, + Some("+05:00".into()), + ))?; + let right = Interval::try_new( + ScalarValue::TimestampNanosecond(Some(1000), Some("+05:00".into())), + ScalarValue::TimestampNanosecond(Some(1000), Some("+05:00".into())), + )?; + assert_eq!( + (Some(( + Interval::try_new( + ScalarValue::try_from(&DataType::Timestamp( + TimeUnit::Nanosecond, + Some("+05:00".into()), + )) + .unwrap(), + ScalarValue::TimestampNanosecond(Some(999), Some("+05:00".into())), + )?, + Interval::try_new( + ScalarValue::TimestampNanosecond(Some(1000), Some("+05:00".into())), + ScalarValue::TimestampNanosecond(Some(1000), Some("+05:00".into())), + )? + ))), + propagate_comparison( + &Operator::Lt, + &Interval::CERTAINLY_TRUE, + &left, + &right + )? + ); + + Ok(()) + } + + #[test] + fn test_propagate_or() -> Result<()> { + let expr = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Or, + Arc::new(Column::new("b", 1)), + )); + let parent = Interval::CERTAINLY_FALSE; + let children_set = vec![ + vec![&Interval::CERTAINLY_FALSE, &Interval::UNCERTAIN], + vec![&Interval::UNCERTAIN, &Interval::CERTAINLY_FALSE], + vec![&Interval::CERTAINLY_FALSE, &Interval::CERTAINLY_FALSE], + vec![&Interval::UNCERTAIN, &Interval::UNCERTAIN], + ]; + for children in children_set { + assert_eq!( + expr.propagate_constraints(&parent, &children)?.unwrap(), + vec![Interval::CERTAINLY_FALSE, Interval::CERTAINLY_FALSE], + ); + } + + let parent = Interval::CERTAINLY_FALSE; + let children_set = vec![ + vec![&Interval::CERTAINLY_TRUE, &Interval::UNCERTAIN], + vec![&Interval::UNCERTAIN, &Interval::CERTAINLY_TRUE], + ]; + for children in children_set { + assert_eq!(expr.propagate_constraints(&parent, &children)?, None,); + } + + let parent = Interval::CERTAINLY_TRUE; + let children = vec![&Interval::CERTAINLY_FALSE, &Interval::UNCERTAIN]; + assert_eq!( + expr.propagate_constraints(&parent, &children)?.unwrap(), + vec![Interval::CERTAINLY_FALSE, Interval::CERTAINLY_TRUE] + ); + + let parent = Interval::CERTAINLY_TRUE; + let children = vec![&Interval::UNCERTAIN, &Interval::UNCERTAIN]; + assert_eq!( + expr.propagate_constraints(&parent, &children)?.unwrap(), + // Empty means unchanged intervals. + vec![] + ); + + Ok(()) + } + + #[test] + fn test_propagate_certainly_false_and() -> Result<()> { + let expr = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::And, + Arc::new(Column::new("b", 1)), + )); + let parent = Interval::CERTAINLY_FALSE; + let children_and_results_set = vec![ + ( + vec![&Interval::CERTAINLY_TRUE, &Interval::UNCERTAIN], + vec![Interval::CERTAINLY_TRUE, Interval::CERTAINLY_FALSE], + ), + ( + vec![&Interval::UNCERTAIN, &Interval::CERTAINLY_TRUE], + vec![Interval::CERTAINLY_FALSE, Interval::CERTAINLY_TRUE], + ), + ( + vec![&Interval::UNCERTAIN, &Interval::UNCERTAIN], + // Empty means unchanged intervals. + vec![], + ), + ( + vec![&Interval::CERTAINLY_FALSE, &Interval::UNCERTAIN], + vec![], + ), + ]; + for (children, result) in children_and_results_set { + assert_eq!( + expr.propagate_constraints(&parent, &children)?.unwrap(), + result + ); + } + Ok(()) } } diff --git a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs deleted file mode 100644 index 3ed228517fd2..000000000000 --- a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs +++ /dev/null @@ -1,1822 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -//http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Interval arithmetic library - -use std::borrow::Borrow; -use std::fmt; -use std::fmt::{Display, Formatter}; -use std::ops::{AddAssign, SubAssign}; - -use crate::aggregate::min_max::{max, min}; -use crate::intervals::rounding::{alter_fp_rounding_mode, next_down, next_up}; - -use arrow::compute::{cast_with_options, CastOptions}; -use arrow::datatypes::DataType; -use arrow_array::ArrowNativeTypeOp; -use datafusion_common::{exec_err, internal_err, DataFusionError, Result, ScalarValue}; -use datafusion_expr::type_coercion::binary::get_result_type; -use datafusion_expr::Operator; - -/// This type represents a single endpoint of an [`Interval`]. An -/// endpoint can be open (does not include the endpoint) or closed -/// (includes the endpoint). -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct IntervalBound { - pub value: ScalarValue, - /// If true, interval does not include `value` - pub open: bool, -} - -impl IntervalBound { - /// Creates a new `IntervalBound` object using the given value. - pub const fn new(value: ScalarValue, open: bool) -> IntervalBound { - IntervalBound { value, open } - } - - /// Creates a new "open" interval (does not include the `value` - /// bound) - pub const fn new_open(value: ScalarValue) -> IntervalBound { - IntervalBound::new(value, true) - } - - /// Creates a new "closed" interval (includes the `value` - /// bound) - pub const fn new_closed(value: ScalarValue) -> IntervalBound { - IntervalBound::new(value, false) - } - - /// This convenience function creates an unbounded interval endpoint. - pub fn make_unbounded>(data_type: T) -> Result { - ScalarValue::try_from(data_type.borrow()).map(|v| IntervalBound::new(v, true)) - } - - /// This convenience function returns the data type associated with this - /// `IntervalBound`. - pub fn get_datatype(&self) -> DataType { - self.value.data_type() - } - - /// This convenience function checks whether the `IntervalBound` represents - /// an unbounded interval endpoint. - pub fn is_unbounded(&self) -> bool { - self.value.is_null() - } - - /// This function casts the `IntervalBound` to the given data type. - pub(crate) fn cast_to( - &self, - data_type: &DataType, - cast_options: &CastOptions, - ) -> Result { - cast_scalar_value(&self.value, data_type, cast_options) - .map(|value| IntervalBound::new(value, self.open)) - } - - /// This function adds the given `IntervalBound` to this `IntervalBound`. - /// The result is unbounded if either is; otherwise, their values are - /// added. The result is closed if both original bounds are closed, or open - /// otherwise. - pub fn add>( - &self, - other: T, - ) -> Result { - let rhs = other.borrow(); - if self.is_unbounded() || rhs.is_unbounded() { - return IntervalBound::make_unbounded(get_result_type( - &self.get_datatype(), - &Operator::Plus, - &rhs.get_datatype(), - )?); - } - match self.get_datatype() { - DataType::Float64 | DataType::Float32 => { - alter_fp_rounding_mode::(&self.value, &rhs.value, |lhs, rhs| { - lhs.add(rhs) - }) - } - _ => self.value.add(&rhs.value), - } - .map(|v| IntervalBound::new(v, self.open || rhs.open)) - } - - /// This function subtracts the given `IntervalBound` from `self`. - /// The result is unbounded if either is; otherwise, their values are - /// subtracted. The result is closed if both original bounds are closed, - /// or open otherwise. - pub fn sub>( - &self, - other: T, - ) -> Result { - let rhs = other.borrow(); - if self.is_unbounded() || rhs.is_unbounded() { - return IntervalBound::make_unbounded(get_result_type( - &self.get_datatype(), - &Operator::Minus, - &rhs.get_datatype(), - )?); - } - match self.get_datatype() { - DataType::Float64 | DataType::Float32 => { - alter_fp_rounding_mode::(&self.value, &rhs.value, |lhs, rhs| { - lhs.sub(rhs) - }) - } - _ => self.value.sub(&rhs.value), - } - .map(|v| IntervalBound::new(v, self.open || rhs.open)) - } - - /// This function chooses one of the given `IntervalBound`s according to - /// the given function `decide`. The result is unbounded if both are. If - /// only one of the arguments is unbounded, the other one is chosen by - /// default. If neither is unbounded, the function `decide` is used. - pub fn choose( - first: &IntervalBound, - second: &IntervalBound, - decide: fn(&ScalarValue, &ScalarValue) -> Result, - ) -> Result { - Ok(if first.is_unbounded() { - second.clone() - } else if second.is_unbounded() { - first.clone() - } else if first.value != second.value { - let chosen = decide(&first.value, &second.value)?; - if chosen.eq(&first.value) { - first.clone() - } else { - second.clone() - } - } else { - IntervalBound::new(second.value.clone(), first.open || second.open) - }) - } -} - -impl Display for IntervalBound { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "IntervalBound [{}]", self.value) - } -} - -/// This type represents an interval, which is used to calculate reliable -/// bounds for expressions. Currently, we only support addition and -/// subtraction, but more capabilities will be added in the future. -/// Upper/lower bounds having NULL values indicate an unbounded side. For -/// example; [10, 20], [10, ∞), (-∞, 100] and (-∞, ∞) are all valid intervals. -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Interval { - pub lower: IntervalBound, - pub upper: IntervalBound, -} - -impl Default for Interval { - fn default() -> Self { - Interval::new( - IntervalBound::new(ScalarValue::Null, true), - IntervalBound::new(ScalarValue::Null, true), - ) - } -} - -impl Display for Interval { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!( - f, - "{}{}, {}{}", - if self.lower.open { "(" } else { "[" }, - self.lower.value, - self.upper.value, - if self.upper.open { ")" } else { "]" } - ) - } -} - -impl Interval { - /// Creates a new interval object using the given bounds. - /// For boolean intervals, having an open false lower bound is equivalent - /// to having a true closed lower bound. Similarly, open true upper bound - /// is equivalent to having a false closed upper bound. Also for boolean - /// intervals, having an unbounded left endpoint is equivalent to having a - /// false closed lower bound, while having an unbounded right endpoint is - /// equivalent to having a true closed upper bound. Therefore; input - /// parameters to construct an Interval can have different types, but they - /// all result in [false, false], [false, true] or [true, true]. - pub fn new(lower: IntervalBound, upper: IntervalBound) -> Interval { - // Boolean intervals need a special handling. - if let ScalarValue::Boolean(_) = lower.value { - let standardized_lower = match lower.value { - ScalarValue::Boolean(None) if lower.open => { - ScalarValue::Boolean(Some(false)) - } - ScalarValue::Boolean(Some(false)) if lower.open => { - ScalarValue::Boolean(Some(true)) - } - // The rest may include some invalid interval cases. The validation of - // interval construction parameters will be implemented later. - // For now, let's return them unchanged. - _ => lower.value, - }; - let standardized_upper = match upper.value { - ScalarValue::Boolean(None) if upper.open => { - ScalarValue::Boolean(Some(true)) - } - ScalarValue::Boolean(Some(true)) if upper.open => { - ScalarValue::Boolean(Some(false)) - } - _ => upper.value, - }; - Interval { - lower: IntervalBound::new(standardized_lower, false), - upper: IntervalBound::new(standardized_upper, false), - } - } else { - Interval { lower, upper } - } - } - - pub fn make(lower: Option, upper: Option, open: (bool, bool)) -> Interval - where - ScalarValue: From>, - { - Interval::new( - IntervalBound::new(ScalarValue::from(lower), open.0), - IntervalBound::new(ScalarValue::from(upper), open.1), - ) - } - - /// Casts this interval to `data_type` using `cast_options`. - pub(crate) fn cast_to( - &self, - data_type: &DataType, - cast_options: &CastOptions, - ) -> Result { - let lower = self.lower.cast_to(data_type, cast_options)?; - let upper = self.upper.cast_to(data_type, cast_options)?; - Ok(Interval::new(lower, upper)) - } - - /// This function returns the data type of this interval. If both endpoints - /// do not have the same data type, returns an error. - pub fn get_datatype(&self) -> Result { - let lower_type = self.lower.get_datatype(); - let upper_type = self.upper.get_datatype(); - if lower_type == upper_type { - Ok(lower_type) - } else { - internal_err!( - "Interval bounds have different types: {lower_type} != {upper_type}" - ) - } - } - - /// Decide if this interval is certainly greater than, possibly greater than, - /// or can't be greater than `other` by returning [true, true], - /// [false, true] or [false, false] respectively. - pub(crate) fn gt>(&self, other: T) -> Interval { - let rhs = other.borrow(); - let flags = if !self.upper.is_unbounded() - && !rhs.lower.is_unbounded() - && self.upper.value <= rhs.lower.value - { - // Values in this interval are certainly less than or equal to those - // in the given interval. - (false, false) - } else if !self.lower.is_unbounded() - && !rhs.upper.is_unbounded() - && self.lower.value >= rhs.upper.value - && (self.lower.value > rhs.upper.value || self.lower.open || rhs.upper.open) - { - // Values in this interval are certainly greater than those in the - // given interval. - (true, true) - } else { - // All outcomes are possible. - (false, true) - }; - - Interval::make(Some(flags.0), Some(flags.1), (false, false)) - } - - /// Decide if this interval is certainly greater than or equal to, possibly greater than - /// or equal to, or can't be greater than or equal to `other` by returning [true, true], - /// [false, true] or [false, false] respectively. - pub(crate) fn gt_eq>(&self, other: T) -> Interval { - let rhs = other.borrow(); - let flags = if !self.lower.is_unbounded() - && !rhs.upper.is_unbounded() - && self.lower.value >= rhs.upper.value - { - // Values in this interval are certainly greater than or equal to those - // in the given interval. - (true, true) - } else if !self.upper.is_unbounded() - && !rhs.lower.is_unbounded() - && self.upper.value <= rhs.lower.value - && (self.upper.value < rhs.lower.value || self.upper.open || rhs.lower.open) - { - // Values in this interval are certainly less than those in the - // given interval. - (false, false) - } else { - // All outcomes are possible. - (false, true) - }; - - Interval::make(Some(flags.0), Some(flags.1), (false, false)) - } - - /// Decide if this interval is certainly less than, possibly less than, - /// or can't be less than `other` by returning [true, true], - /// [false, true] or [false, false] respectively. - pub(crate) fn lt>(&self, other: T) -> Interval { - other.borrow().gt(self) - } - - /// Decide if this interval is certainly less than or equal to, possibly - /// less than or equal to, or can't be less than or equal to `other` by returning - /// [true, true], [false, true] or [false, false] respectively. - pub(crate) fn lt_eq>(&self, other: T) -> Interval { - other.borrow().gt_eq(self) - } - - /// Decide if this interval is certainly equal to, possibly equal to, - /// or can't be equal to `other` by returning [true, true], - /// [false, true] or [false, false] respectively. - pub(crate) fn equal>(&self, other: T) -> Interval { - let rhs = other.borrow(); - let flags = if !self.lower.is_unbounded() - && (self.lower.value == self.upper.value) - && (rhs.lower.value == rhs.upper.value) - && (self.lower.value == rhs.lower.value) - { - (true, true) - } else if self.gt(rhs) == Interval::CERTAINLY_TRUE - || self.lt(rhs) == Interval::CERTAINLY_TRUE - { - (false, false) - } else { - (false, true) - }; - - Interval::make(Some(flags.0), Some(flags.1), (false, false)) - } - - /// Compute the logical conjunction of this (boolean) interval with the given boolean interval. - pub(crate) fn and>(&self, other: T) -> Result { - let rhs = other.borrow(); - match ( - &self.lower.value, - &self.upper.value, - &rhs.lower.value, - &rhs.upper.value, - ) { - ( - ScalarValue::Boolean(Some(self_lower)), - ScalarValue::Boolean(Some(self_upper)), - ScalarValue::Boolean(Some(other_lower)), - ScalarValue::Boolean(Some(other_upper)), - ) => { - let lower = *self_lower && *other_lower; - let upper = *self_upper && *other_upper; - - Ok(Interval { - lower: IntervalBound::new(ScalarValue::Boolean(Some(lower)), false), - upper: IntervalBound::new(ScalarValue::Boolean(Some(upper)), false), - }) - } - _ => internal_err!("Incompatible types for logical conjunction"), - } - } - - /// Compute the logical negation of this (boolean) interval. - pub(crate) fn not(&self) -> Result { - if !matches!(self.get_datatype()?, DataType::Boolean) { - return internal_err!( - "Cannot apply logical negation to non-boolean interval" - ); - } - if self == &Interval::CERTAINLY_TRUE { - Ok(Interval::CERTAINLY_FALSE) - } else if self == &Interval::CERTAINLY_FALSE { - Ok(Interval::CERTAINLY_TRUE) - } else { - Ok(Interval::UNCERTAIN) - } - } - - /// Compute the intersection of the interval with the given interval. - /// If the intersection is empty, return None. - pub(crate) fn intersect>( - &self, - other: T, - ) -> Result> { - let rhs = other.borrow(); - // If it is evident that the result is an empty interval, - // do not make any calculation and directly return None. - if (!self.lower.is_unbounded() - && !rhs.upper.is_unbounded() - && self.lower.value > rhs.upper.value) - || (!self.upper.is_unbounded() - && !rhs.lower.is_unbounded() - && self.upper.value < rhs.lower.value) - { - // This None value signals an empty interval. - return Ok(None); - } - - let lower = IntervalBound::choose(&self.lower, &rhs.lower, max)?; - let upper = IntervalBound::choose(&self.upper, &rhs.upper, min)?; - - let non_empty = lower.is_unbounded() - || upper.is_unbounded() - || lower.value != upper.value - || (!lower.open && !upper.open); - Ok(non_empty.then_some(Interval::new(lower, upper))) - } - - /// Decide if this interval is certainly contains, possibly contains, - /// or can't can't `other` by returning [true, true], - /// [false, true] or [false, false] respectively. - pub fn contains>(&self, other: T) -> Result { - match self.intersect(other.borrow())? { - Some(intersection) => { - // Need to compare with same bounds close-ness. - if intersection.close_bounds() == other.borrow().clone().close_bounds() { - Ok(Interval::CERTAINLY_TRUE) - } else { - Ok(Interval::UNCERTAIN) - } - } - None => Ok(Interval::CERTAINLY_FALSE), - } - } - - /// Add the given interval (`other`) to this interval. Say we have - /// intervals [a1, b1] and [a2, b2], then their sum is [a1 + a2, b1 + b2]. - /// Note that this represents all possible values the sum can take if - /// one can choose single values arbitrarily from each of the operands. - pub fn add>(&self, other: T) -> Result { - let rhs = other.borrow(); - Ok(Interval::new( - self.lower.add::(&rhs.lower)?, - self.upper.add::(&rhs.upper)?, - )) - } - - /// Subtract the given interval (`other`) from this interval. Say we have - /// intervals [a1, b1] and [a2, b2], then their sum is [a1 - b2, b1 - a2]. - /// Note that this represents all possible values the difference can take - /// if one can choose single values arbitrarily from each of the operands. - pub fn sub>(&self, other: T) -> Result { - let rhs = other.borrow(); - Ok(Interval::new( - self.lower.sub::(&rhs.upper)?, - self.upper.sub::(&rhs.lower)?, - )) - } - - pub const CERTAINLY_FALSE: Interval = Interval { - lower: IntervalBound::new_closed(ScalarValue::Boolean(Some(false))), - upper: IntervalBound::new_closed(ScalarValue::Boolean(Some(false))), - }; - - pub const UNCERTAIN: Interval = Interval { - lower: IntervalBound::new_closed(ScalarValue::Boolean(Some(false))), - upper: IntervalBound::new_closed(ScalarValue::Boolean(Some(true))), - }; - - pub const CERTAINLY_TRUE: Interval = Interval { - lower: IntervalBound::new_closed(ScalarValue::Boolean(Some(true))), - upper: IntervalBound::new_closed(ScalarValue::Boolean(Some(true))), - }; - - /// Returns the cardinality of this interval, which is the number of all - /// distinct points inside it. - pub fn cardinality(&self) -> Result { - match self.get_datatype() { - Ok(data_type) if data_type.is_integer() => { - if let Some(diff) = self.upper.value.distance(&self.lower.value) { - Ok(calculate_cardinality_based_on_bounds( - self.lower.open, - self.upper.open, - diff as u64, - )) - } else { - exec_err!("Cardinality cannot be calculated for {:?}", self) - } - } - // Ordering floating-point numbers according to their binary representations - // coincide with their natural ordering. Therefore, we can consider their - // binary representations as "indices" and subtract them. For details, see: - // https://stackoverflow.com/questions/8875064/how-many-distinct-floating-point-numbers-in-a-specific-range - Ok(data_type) if data_type.is_floating() => { - // If the minimum value is a negative number, we need to - // switch sides to ensure an unsigned result. - let (min, max) = if self.lower.value - < ScalarValue::new_zero(&self.lower.value.data_type())? - { - (self.upper.value.clone(), self.lower.value.clone()) - } else { - (self.lower.value.clone(), self.upper.value.clone()) - }; - - match (min, max) { - ( - ScalarValue::Float32(Some(lower)), - ScalarValue::Float32(Some(upper)), - ) => Ok(calculate_cardinality_based_on_bounds( - self.lower.open, - self.upper.open, - (upper.to_bits().sub_checked(lower.to_bits()))? as u64, - )), - ( - ScalarValue::Float64(Some(lower)), - ScalarValue::Float64(Some(upper)), - ) => Ok(calculate_cardinality_based_on_bounds( - self.lower.open, - self.upper.open, - upper.to_bits().sub_checked(lower.to_bits())?, - )), - _ => exec_err!( - "Cardinality cannot be calculated for the datatype {:?}", - data_type - ), - } - } - // If the cardinality cannot be calculated anyway, give an error. - _ => exec_err!("Cardinality cannot be calculated for {:?}", self), - } - } - - /// This function "closes" this interval; i.e. it modifies the endpoints so - /// that we end up with the narrowest possible closed interval containing - /// the original interval. - pub fn close_bounds(mut self) -> Interval { - if self.lower.open { - // Get next value - self.lower.value = next_value::(self.lower.value); - self.lower.open = false; - } - - if self.upper.open { - // Get previous value - self.upper.value = next_value::(self.upper.value); - self.upper.open = false; - } - - self - } -} - -trait OneTrait: Sized + std::ops::Add + std::ops::Sub { - fn one() -> Self; -} - -macro_rules! impl_OneTrait{ - ($($m:ty),*) => {$( impl OneTrait for $m { fn one() -> Self { 1 as $m } })*} -} -impl_OneTrait! {u8, u16, u32, u64, i8, i16, i32, i64} - -/// This function either increments or decrements its argument, depending on the `INC` value. -/// If `true`, it increments; otherwise it decrements the argument. -fn increment_decrement( - mut val: T, -) -> T { - if INC { - val.add_assign(T::one()); - } else { - val.sub_assign(T::one()); - } - val -} - -macro_rules! check_infinite_bounds { - ($value:expr, $val:expr, $type:ident, $inc:expr) => { - if ($val == $type::MAX && $inc) || ($val == $type::MIN && !$inc) { - return $value; - } - }; -} - -/// This function returns the next/previous value depending on the `ADD` value. -/// If `true`, it returns the next value; otherwise it returns the previous value. -fn next_value(value: ScalarValue) -> ScalarValue { - use ScalarValue::*; - match value { - Float32(Some(val)) => { - let new_float = if INC { next_up(val) } else { next_down(val) }; - Float32(Some(new_float)) - } - Float64(Some(val)) => { - let new_float = if INC { next_up(val) } else { next_down(val) }; - Float64(Some(new_float)) - } - Int8(Some(val)) => { - check_infinite_bounds!(value, val, i8, INC); - Int8(Some(increment_decrement::(val))) - } - Int16(Some(val)) => { - check_infinite_bounds!(value, val, i16, INC); - Int16(Some(increment_decrement::(val))) - } - Int32(Some(val)) => { - check_infinite_bounds!(value, val, i32, INC); - Int32(Some(increment_decrement::(val))) - } - Int64(Some(val)) => { - check_infinite_bounds!(value, val, i64, INC); - Int64(Some(increment_decrement::(val))) - } - UInt8(Some(val)) => { - check_infinite_bounds!(value, val, u8, INC); - UInt8(Some(increment_decrement::(val))) - } - UInt16(Some(val)) => { - check_infinite_bounds!(value, val, u16, INC); - UInt16(Some(increment_decrement::(val))) - } - UInt32(Some(val)) => { - check_infinite_bounds!(value, val, u32, INC); - UInt32(Some(increment_decrement::(val))) - } - UInt64(Some(val)) => { - check_infinite_bounds!(value, val, u64, INC); - UInt64(Some(increment_decrement::(val))) - } - _ => value, // Unsupported datatypes - } -} - -/// This function computes the cardinality ratio of the given intervals. -pub fn cardinality_ratio( - initial_interval: &Interval, - final_interval: &Interval, -) -> Result { - Ok(final_interval.cardinality()? as f64 / initial_interval.cardinality()? as f64) -} - -pub fn apply_operator(op: &Operator, lhs: &Interval, rhs: &Interval) -> Result { - match *op { - Operator::Eq => Ok(lhs.equal(rhs)), - Operator::NotEq => Ok(lhs.equal(rhs).not()?), - Operator::Gt => Ok(lhs.gt(rhs)), - Operator::GtEq => Ok(lhs.gt_eq(rhs)), - Operator::Lt => Ok(lhs.lt(rhs)), - Operator::LtEq => Ok(lhs.lt_eq(rhs)), - Operator::And => lhs.and(rhs), - Operator::Plus => lhs.add(rhs), - Operator::Minus => lhs.sub(rhs), - _ => Ok(Interval::default()), - } -} - -/// Cast scalar value to the given data type using an arrow kernel. -fn cast_scalar_value( - value: &ScalarValue, - data_type: &DataType, - cast_options: &CastOptions, -) -> Result { - let cast_array = cast_with_options(&value.to_array(), data_type, cast_options)?; - ScalarValue::try_from_array(&cast_array, 0) -} - -/// This function calculates the final cardinality result by inspecting the endpoints of the interval. -fn calculate_cardinality_based_on_bounds( - lower_open: bool, - upper_open: bool, - diff: u64, -) -> u64 { - match (lower_open, upper_open) { - (false, false) => diff + 1, - (true, true) => diff - 1, - _ => diff, - } -} - -/// An [Interval] that also tracks null status using a boolean interval. -/// -/// This represents values that may be in a particular range or be null. -/// -/// # Examples -/// -/// ``` -/// use arrow::datatypes::DataType; -/// use datafusion_physical_expr::intervals::{Interval, NullableInterval}; -/// use datafusion_common::ScalarValue; -/// -/// // [1, 2) U {NULL} -/// NullableInterval::MaybeNull { -/// values: Interval::make(Some(1), Some(2), (false, true)), -/// }; -/// -/// // (0, ∞) -/// NullableInterval::NotNull { -/// values: Interval::make(Some(0), None, (true, true)), -/// }; -/// -/// // {NULL} -/// NullableInterval::Null { datatype: DataType::Int32 }; -/// -/// // {4} -/// NullableInterval::from(ScalarValue::Int32(Some(4))); -/// ``` -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum NullableInterval { - /// The value is always null in this interval - /// - /// This is typed so it can be used in physical expressions, which don't do - /// type coercion. - Null { datatype: DataType }, - /// The value may or may not be null in this interval. If it is non null its value is within - /// the specified values interval - MaybeNull { values: Interval }, - /// The value is definitely not null in this interval and is within values - NotNull { values: Interval }, -} - -impl Default for NullableInterval { - fn default() -> Self { - NullableInterval::MaybeNull { - values: Interval::default(), - } - } -} - -impl Display for NullableInterval { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - match self { - Self::Null { .. } => write!(f, "NullableInterval: {{NULL}}"), - Self::MaybeNull { values } => { - write!(f, "NullableInterval: {} U {{NULL}}", values) - } - Self::NotNull { values } => write!(f, "NullableInterval: {}", values), - } - } -} - -impl From for NullableInterval { - /// Create an interval that represents a single value. - fn from(value: ScalarValue) -> Self { - if value.is_null() { - Self::Null { - datatype: value.data_type(), - } - } else { - Self::NotNull { - values: Interval::new( - IntervalBound::new(value.clone(), false), - IntervalBound::new(value, false), - ), - } - } - } -} - -impl NullableInterval { - /// Get the values interval, or None if this interval is definitely null. - pub fn values(&self) -> Option<&Interval> { - match self { - Self::Null { .. } => None, - Self::MaybeNull { values } | Self::NotNull { values } => Some(values), - } - } - - /// Get the data type - pub fn get_datatype(&self) -> Result { - match self { - Self::Null { datatype } => Ok(datatype.clone()), - Self::MaybeNull { values } | Self::NotNull { values } => { - values.get_datatype() - } - } - } - - /// Return true if the value is definitely true (and not null). - pub fn is_certainly_true(&self) -> bool { - match self { - Self::Null { .. } | Self::MaybeNull { .. } => false, - Self::NotNull { values } => values == &Interval::CERTAINLY_TRUE, - } - } - - /// Return true if the value is definitely false (and not null). - pub fn is_certainly_false(&self) -> bool { - match self { - Self::Null { .. } => false, - Self::MaybeNull { .. } => false, - Self::NotNull { values } => values == &Interval::CERTAINLY_FALSE, - } - } - - /// Perform logical negation on a boolean nullable interval. - fn not(&self) -> Result { - match self { - Self::Null { datatype } => Ok(Self::Null { - datatype: datatype.clone(), - }), - Self::MaybeNull { values } => Ok(Self::MaybeNull { - values: values.not()?, - }), - Self::NotNull { values } => Ok(Self::NotNull { - values: values.not()?, - }), - } - } - - /// Apply the given operator to this interval and the given interval. - /// - /// # Examples - /// - /// ``` - /// use datafusion_common::ScalarValue; - /// use datafusion_expr::Operator; - /// use datafusion_physical_expr::intervals::{Interval, NullableInterval}; - /// - /// // 4 > 3 -> true - /// let lhs = NullableInterval::from(ScalarValue::Int32(Some(4))); - /// let rhs = NullableInterval::from(ScalarValue::Int32(Some(3))); - /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap(); - /// assert_eq!(result, NullableInterval::from(ScalarValue::Boolean(Some(true)))); - /// - /// // [1, 3) > NULL -> NULL - /// let lhs = NullableInterval::NotNull { - /// values: Interval::make(Some(1), Some(3), (false, true)), - /// }; - /// let rhs = NullableInterval::from(ScalarValue::Int32(None)); - /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap(); - /// assert_eq!(result.single_value(), Some(ScalarValue::Boolean(None))); - /// - /// // [1, 3] > [2, 4] -> [false, true] - /// let lhs = NullableInterval::NotNull { - /// values: Interval::make(Some(1), Some(3), (false, false)), - /// }; - /// let rhs = NullableInterval::NotNull { - /// values: Interval::make(Some(2), Some(4), (false, false)), - /// }; - /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap(); - /// // Both inputs are valid (non-null), so result must be non-null - /// assert_eq!(result, NullableInterval::NotNull { - /// // Uncertain whether inequality is true or false - /// values: Interval::UNCERTAIN, - /// }); - /// - /// ``` - pub fn apply_operator(&self, op: &Operator, rhs: &Self) -> Result { - match op { - Operator::IsDistinctFrom => { - let values = match (self, rhs) { - // NULL is distinct from NULL -> False - (Self::Null { .. }, Self::Null { .. }) => Interval::CERTAINLY_FALSE, - // x is distinct from y -> x != y, - // if at least one of them is never null. - (Self::NotNull { .. }, _) | (_, Self::NotNull { .. }) => { - let lhs_values = self.values(); - let rhs_values = rhs.values(); - match (lhs_values, rhs_values) { - (Some(lhs_values), Some(rhs_values)) => { - lhs_values.equal(rhs_values).not()? - } - (Some(_), None) | (None, Some(_)) => Interval::CERTAINLY_TRUE, - (None, None) => unreachable!("Null case handled above"), - } - } - _ => Interval::UNCERTAIN, - }; - // IsDistinctFrom never returns null. - Ok(Self::NotNull { values }) - } - Operator::IsNotDistinctFrom => self - .apply_operator(&Operator::IsDistinctFrom, rhs) - .map(|i| i.not())?, - _ => { - if let (Some(left_values), Some(right_values)) = - (self.values(), rhs.values()) - { - let values = apply_operator(op, left_values, right_values)?; - match (self, rhs) { - (Self::NotNull { .. }, Self::NotNull { .. }) => { - Ok(Self::NotNull { values }) - } - _ => Ok(Self::MaybeNull { values }), - } - } else if op.is_comparison_operator() { - Ok(Self::Null { - datatype: DataType::Boolean, - }) - } else { - Ok(Self::Null { - datatype: self.get_datatype()?, - }) - } - } - } - } - - /// Determine if this interval contains the given interval. Returns a boolean - /// interval that is [true, true] if this interval is a superset of the - /// given interval, [false, false] if this interval is disjoint from the - /// given interval, and [false, true] otherwise. - pub fn contains>(&self, other: T) -> Result { - let rhs = other.borrow(); - if let (Some(left_values), Some(right_values)) = (self.values(), rhs.values()) { - let values = left_values.contains(right_values)?; - match (self, rhs) { - (Self::NotNull { .. }, Self::NotNull { .. }) => { - Ok(Self::NotNull { values }) - } - _ => Ok(Self::MaybeNull { values }), - } - } else { - Ok(Self::Null { - datatype: DataType::Boolean, - }) - } - } - - /// If the interval has collapsed to a single value, return that value. - /// - /// Otherwise returns None. - /// - /// # Examples - /// - /// ``` - /// use datafusion_common::ScalarValue; - /// use datafusion_physical_expr::intervals::{Interval, NullableInterval}; - /// - /// let interval = NullableInterval::from(ScalarValue::Int32(Some(4))); - /// assert_eq!(interval.single_value(), Some(ScalarValue::Int32(Some(4)))); - /// - /// let interval = NullableInterval::from(ScalarValue::Int32(None)); - /// assert_eq!(interval.single_value(), Some(ScalarValue::Int32(None))); - /// - /// let interval = NullableInterval::MaybeNull { - /// values: Interval::make(Some(1), Some(4), (false, true)), - /// }; - /// assert_eq!(interval.single_value(), None); - /// ``` - pub fn single_value(&self) -> Option { - match self { - Self::Null { datatype } => { - Some(ScalarValue::try_from(datatype).unwrap_or(ScalarValue::Null)) - } - Self::MaybeNull { values } | Self::NotNull { values } - if values.lower.value == values.upper.value - && !values.lower.is_unbounded() => - { - Some(values.lower.value.clone()) - } - _ => None, - } - } -} - -#[cfg(test)] -mod tests { - use super::next_value; - use crate::intervals::{Interval, IntervalBound}; - - use arrow_schema::DataType; - use datafusion_common::{Result, ScalarValue}; - - fn open_open(lower: Option, upper: Option) -> Interval - where - ScalarValue: From>, - { - Interval::make(lower, upper, (true, true)) - } - - fn open_closed(lower: Option, upper: Option) -> Interval - where - ScalarValue: From>, - { - Interval::make(lower, upper, (true, false)) - } - - fn closed_open(lower: Option, upper: Option) -> Interval - where - ScalarValue: From>, - { - Interval::make(lower, upper, (false, true)) - } - - fn closed_closed(lower: Option, upper: Option) -> Interval - where - ScalarValue: From>, - { - Interval::make(lower, upper, (false, false)) - } - - #[test] - fn intersect_test() -> Result<()> { - let possible_cases = vec![ - (Some(1000_i64), None, None, None, Some(1000_i64), None), - (None, Some(1000_i64), None, None, None, Some(1000_i64)), - (None, None, Some(1000_i64), None, Some(1000_i64), None), - (None, None, None, Some(1000_i64), None, Some(1000_i64)), - ( - Some(1000_i64), - None, - Some(1000_i64), - None, - Some(1000_i64), - None, - ), - ( - None, - Some(1000_i64), - Some(999_i64), - Some(1002_i64), - Some(999_i64), - Some(1000_i64), - ), - (None, None, None, None, None, None), - ]; - - for case in possible_cases { - assert_eq!( - open_open(case.0, case.1).intersect(open_open(case.2, case.3))?, - Some(open_open(case.4, case.5)) - ) - } - - let empty_cases = vec![ - (None, Some(1000_i64), Some(1001_i64), None), - (Some(1001_i64), None, None, Some(1000_i64)), - (None, Some(1000_i64), Some(1001_i64), Some(1002_i64)), - (Some(1001_i64), Some(1002_i64), None, Some(1000_i64)), - ]; - - for case in empty_cases { - assert_eq!( - open_open(case.0, case.1).intersect(open_open(case.2, case.3))?, - None - ) - } - - Ok(()) - } - - #[test] - fn gt_test() { - let cases = vec![ - (Some(1000_i64), None, None, None, false, true), - (None, Some(1000_i64), None, None, false, true), - (None, None, Some(1000_i64), None, false, true), - (None, None, None, Some(1000_i64), false, true), - (None, Some(1000_i64), Some(1000_i64), None, false, false), - (None, Some(1000_i64), Some(1001_i64), None, false, false), - (Some(1000_i64), None, Some(1000_i64), None, false, true), - ( - None, - Some(1000_i64), - Some(1001_i64), - Some(1002_i64), - false, - false, - ), - ( - None, - Some(1000_i64), - Some(999_i64), - Some(1002_i64), - false, - true, - ), - ( - Some(1002_i64), - None, - Some(999_i64), - Some(1002_i64), - true, - true, - ), - ( - Some(1003_i64), - None, - Some(999_i64), - Some(1002_i64), - true, - true, - ), - (None, None, None, None, false, true), - ]; - - for case in cases { - assert_eq!( - open_open(case.0, case.1).gt(open_open(case.2, case.3)), - closed_closed(Some(case.4), Some(case.5)) - ); - } - } - - #[test] - fn lt_test() { - let cases = vec![ - (Some(1000_i64), None, None, None, false, true), - (None, Some(1000_i64), None, None, false, true), - (None, None, Some(1000_i64), None, false, true), - (None, None, None, Some(1000_i64), false, true), - (None, Some(1000_i64), Some(1000_i64), None, true, true), - (None, Some(1000_i64), Some(1001_i64), None, true, true), - (Some(1000_i64), None, Some(1000_i64), None, false, true), - ( - None, - Some(1000_i64), - Some(1001_i64), - Some(1002_i64), - true, - true, - ), - ( - None, - Some(1000_i64), - Some(999_i64), - Some(1002_i64), - false, - true, - ), - (None, None, None, None, false, true), - ]; - - for case in cases { - assert_eq!( - open_open(case.0, case.1).lt(open_open(case.2, case.3)), - closed_closed(Some(case.4), Some(case.5)) - ); - } - } - - #[test] - fn and_test() -> Result<()> { - let cases = vec![ - (false, true, false, false, false, false), - (false, false, false, true, false, false), - (false, true, false, true, false, true), - (false, true, true, true, false, true), - (false, false, false, false, false, false), - (true, true, true, true, true, true), - ]; - - for case in cases { - assert_eq!( - open_open(Some(case.0), Some(case.1)) - .and(open_open(Some(case.2), Some(case.3)))?, - open_open(Some(case.4), Some(case.5)) - ); - } - Ok(()) - } - - #[test] - fn add_test() -> Result<()> { - let cases = vec![ - (Some(1000_i64), None, None, None, None, None), - (None, Some(1000_i64), None, None, None, None), - (None, None, Some(1000_i64), None, None, None), - (None, None, None, Some(1000_i64), None, None), - ( - Some(1000_i64), - None, - Some(1000_i64), - None, - Some(2000_i64), - None, - ), - ( - None, - Some(1000_i64), - Some(999_i64), - Some(1002_i64), - None, - Some(2002_i64), - ), - (None, Some(1000_i64), Some(1000_i64), None, None, None), - ( - Some(2001_i64), - Some(1_i64), - Some(1005_i64), - Some(-999_i64), - Some(3006_i64), - Some(-998_i64), - ), - (None, None, None, None, None, None), - ]; - - for case in cases { - assert_eq!( - open_open(case.0, case.1).add(open_open(case.2, case.3))?, - open_open(case.4, case.5) - ); - } - Ok(()) - } - - #[test] - fn sub_test() -> Result<()> { - let cases = vec![ - (Some(1000_i64), None, None, None, None, None), - (None, Some(1000_i64), None, None, None, None), - (None, None, Some(1000_i64), None, None, None), - (None, None, None, Some(1000_i64), None, None), - (Some(1000_i64), None, Some(1000_i64), None, None, None), - ( - None, - Some(1000_i64), - Some(999_i64), - Some(1002_i64), - None, - Some(1_i64), - ), - ( - None, - Some(1000_i64), - Some(1000_i64), - None, - None, - Some(0_i64), - ), - ( - Some(2001_i64), - Some(1000_i64), - Some(1005), - Some(999_i64), - Some(1002_i64), - Some(-5_i64), - ), - (None, None, None, None, None, None), - ]; - - for case in cases { - assert_eq!( - open_open(case.0, case.1).sub(open_open(case.2, case.3))?, - open_open(case.4, case.5) - ); - } - Ok(()) - } - - #[test] - fn sub_test_various_bounds() -> Result<()> { - let cases = vec![ - ( - closed_closed(Some(100_i64), Some(200_i64)), - closed_open(Some(200_i64), None), - open_closed(None, Some(0_i64)), - ), - ( - closed_open(Some(100_i64), Some(200_i64)), - open_closed(Some(300_i64), Some(150_i64)), - closed_open(Some(-50_i64), Some(-100_i64)), - ), - ( - closed_open(Some(100_i64), Some(200_i64)), - open_open(Some(200_i64), None), - open_open(None, Some(0_i64)), - ), - ( - closed_closed(Some(1_i64), Some(1_i64)), - closed_closed(Some(11_i64), Some(11_i64)), - closed_closed(Some(-10_i64), Some(-10_i64)), - ), - ]; - for case in cases { - assert_eq!(case.0.sub(case.1)?, case.2) - } - Ok(()) - } - - #[test] - fn add_test_various_bounds() -> Result<()> { - let cases = vec![ - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_closed(None, Some(200_i64)), - open_closed(None, Some(400_i64)), - ), - ( - closed_open(Some(100_i64), Some(200_i64)), - closed_open(Some(-300_i64), Some(150_i64)), - closed_open(Some(-200_i64), Some(350_i64)), - ), - ( - closed_open(Some(100_i64), Some(200_i64)), - open_open(Some(200_i64), None), - open_open(Some(300_i64), None), - ), - ( - closed_closed(Some(1_i64), Some(1_i64)), - closed_closed(Some(11_i64), Some(11_i64)), - closed_closed(Some(12_i64), Some(12_i64)), - ), - ]; - for case in cases { - assert_eq!(case.0.add(case.1)?, case.2) - } - Ok(()) - } - - #[test] - fn lt_test_various_bounds() -> Result<()> { - let cases = vec![ - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_closed(None, Some(100_i64)), - closed_closed(Some(false), Some(false)), - ), - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_open(None, Some(100_i64)), - closed_closed(Some(false), Some(false)), - ), - ( - open_open(Some(100_i64), Some(200_i64)), - closed_closed(Some(0_i64), Some(100_i64)), - closed_closed(Some(false), Some(false)), - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_closed(Some(1_i64), Some(2_i64)), - closed_closed(Some(false), Some(false)), - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(false), Some(false)), - ), - ( - closed_closed(Some(1_i64), Some(1_i64)), - open_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(true), Some(true)), - ), - ]; - for case in cases { - assert_eq!(case.0.lt(case.1), case.2) - } - Ok(()) - } - - #[test] - fn gt_test_various_bounds() -> Result<()> { - let cases = vec![ - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_closed(None, Some(100_i64)), - closed_closed(Some(false), Some(true)), - ), - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_open(None, Some(100_i64)), - closed_closed(Some(true), Some(true)), - ), - ( - open_open(Some(100_i64), Some(200_i64)), - closed_closed(Some(0_i64), Some(100_i64)), - closed_closed(Some(true), Some(true)), - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_closed(Some(1_i64), Some(2_i64)), - closed_closed(Some(false), Some(true)), - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(true), Some(true)), - ), - ( - closed_closed(Some(1_i64), Some(1_i64)), - open_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(false), Some(false)), - ), - ]; - for case in cases { - assert_eq!(case.0.gt(case.1), case.2) - } - Ok(()) - } - - #[test] - fn lt_eq_test_various_bounds() -> Result<()> { - let cases = vec![ - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_closed(None, Some(100_i64)), - closed_closed(Some(false), Some(true)), - ), - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_open(None, Some(100_i64)), - closed_closed(Some(false), Some(false)), - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_closed(Some(1_i64), Some(2_i64)), - closed_closed(Some(false), Some(true)), - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(false), Some(false)), - ), - ( - closed_closed(Some(1_i64), Some(1_i64)), - closed_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(true), Some(true)), - ), - ( - closed_closed(Some(1_i64), Some(1_i64)), - open_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(true), Some(true)), - ), - ]; - for case in cases { - assert_eq!(case.0.lt_eq(case.1), case.2) - } - Ok(()) - } - - #[test] - fn gt_eq_test_various_bounds() -> Result<()> { - let cases = vec![ - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_closed(None, Some(100_i64)), - closed_closed(Some(true), Some(true)), - ), - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_open(None, Some(100_i64)), - closed_closed(Some(true), Some(true)), - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_closed(Some(1_i64), Some(2_i64)), - closed_closed(Some(true), Some(true)), - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(true), Some(true)), - ), - ( - closed_closed(Some(1_i64), Some(1_i64)), - closed_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(false), Some(true)), - ), - ( - closed_closed(Some(1_i64), Some(1_i64)), - open_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(false), Some(false)), - ), - ]; - for case in cases { - assert_eq!(case.0.gt_eq(case.1), case.2) - } - Ok(()) - } - - #[test] - fn intersect_test_various_bounds() -> Result<()> { - let cases = vec![ - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_closed(None, Some(100_i64)), - Some(closed_closed(Some(100_i64), Some(100_i64))), - ), - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_open(None, Some(100_i64)), - None, - ), - ( - open_open(Some(100_i64), Some(200_i64)), - closed_closed(Some(0_i64), Some(100_i64)), - None, - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_closed(Some(1_i64), Some(2_i64)), - Some(closed_closed(Some(2_i64), Some(2_i64))), - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_open(Some(1_i64), Some(2_i64)), - None, - ), - ( - closed_closed(Some(1_i64), Some(1_i64)), - open_open(Some(1_i64), Some(2_i64)), - None, - ), - ( - closed_closed(Some(1_i64), Some(3_i64)), - open_open(Some(1_i64), Some(2_i64)), - Some(open_open(Some(1_i64), Some(2_i64))), - ), - ]; - for case in cases { - assert_eq!(case.0.intersect(case.1)?, case.2) - } - Ok(()) - } - - // This function tests if valid constructions produce standardized objects - // ([false, false], [false, true], [true, true]) for boolean intervals. - #[test] - fn non_standard_interval_constructs() { - use ScalarValue::Boolean; - let cases = vec![ - ( - IntervalBound::new(Boolean(None), true), - IntervalBound::new(Boolean(Some(true)), false), - closed_closed(Some(false), Some(true)), - ), - ( - IntervalBound::new(Boolean(None), true), - IntervalBound::new(Boolean(Some(true)), true), - closed_closed(Some(false), Some(false)), - ), - ( - IntervalBound::new(Boolean(Some(false)), false), - IntervalBound::new(Boolean(None), true), - closed_closed(Some(false), Some(true)), - ), - ( - IntervalBound::new(Boolean(Some(true)), false), - IntervalBound::new(Boolean(None), true), - closed_closed(Some(true), Some(true)), - ), - ( - IntervalBound::new(Boolean(None), true), - IntervalBound::new(Boolean(None), true), - closed_closed(Some(false), Some(true)), - ), - ( - IntervalBound::new(Boolean(Some(false)), true), - IntervalBound::new(Boolean(None), true), - closed_closed(Some(true), Some(true)), - ), - ]; - - for case in cases { - assert_eq!(Interval::new(case.0, case.1), case.2) - } - } - - macro_rules! capture_mode_change { - ($TYPE:ty) => { - paste::item! { - capture_mode_change_helper!([], - [], - $TYPE); - } - }; - } - - macro_rules! capture_mode_change_helper { - ($TEST_FN_NAME:ident, $CREATE_FN_NAME:ident, $TYPE:ty) => { - fn $CREATE_FN_NAME(lower: $TYPE, upper: $TYPE) -> Interval { - Interval::make(Some(lower as $TYPE), Some(upper as $TYPE), (true, true)) - } - - fn $TEST_FN_NAME(input: ($TYPE, $TYPE), expect_low: bool, expect_high: bool) { - assert!(expect_low || expect_high); - let interval1 = $CREATE_FN_NAME(input.0, input.0); - let interval2 = $CREATE_FN_NAME(input.1, input.1); - let result = interval1.add(&interval2).unwrap(); - let without_fe = $CREATE_FN_NAME(input.0 + input.1, input.0 + input.1); - assert!( - (!expect_low || result.lower.value < without_fe.lower.value) - && (!expect_high || result.upper.value > without_fe.upper.value) - ); - } - }; - } - - capture_mode_change!(f32); - capture_mode_change!(f64); - - #[cfg(all( - any(target_arch = "x86_64", target_arch = "aarch64"), - not(target_os = "windows") - ))] - #[test] - fn test_add_intervals_lower_affected_f32() { - // Lower is affected - let lower = f32::from_bits(1073741887); //1000000000000000000000000111111 - let upper = f32::from_bits(1098907651); //1000001100000000000000000000011 - capture_mode_change_f32((lower, upper), true, false); - - // Upper is affected - let lower = f32::from_bits(1072693248); //111111111100000000000000000000 - let upper = f32::from_bits(715827883); //101010101010101010101010101011 - capture_mode_change_f32((lower, upper), false, true); - - // Lower is affected - let lower = 1.0; // 0x3FF0000000000000 - let upper = 0.3; // 0x3FD3333333333333 - capture_mode_change_f64((lower, upper), true, false); - - // Upper is affected - let lower = 1.4999999999999998; // 0x3FF7FFFFFFFFFFFF - let upper = 0.000_000_000_000_000_022_044_604_925_031_31; // 0x3C796A6B413BB21F - capture_mode_change_f64((lower, upper), false, true); - } - - #[cfg(any( - not(any(target_arch = "x86_64", target_arch = "aarch64")), - target_os = "windows" - ))] - #[test] - fn test_next_impl_add_intervals_f64() { - let lower = 1.5; - let upper = 1.5; - capture_mode_change_f64((lower, upper), true, true); - - let lower = 1.5; - let upper = 1.5; - capture_mode_change_f32((lower, upper), true, true); - } - - #[test] - fn test_cardinality_of_intervals() -> Result<()> { - // In IEEE 754 standard for floating-point arithmetic, if we keep the sign and exponent fields same, - // we can represent 4503599627370496 different numbers by changing the mantissa - // (4503599627370496 = 2^52, since there are 52 bits in mantissa, and 2^23 = 8388608 for f32). - let distinct_f64 = 4503599627370496; - let distinct_f32 = 8388608; - let intervals = [ - Interval::new( - IntervalBound::new(ScalarValue::from(0.25), false), - IntervalBound::new(ScalarValue::from(0.50), true), - ), - Interval::new( - IntervalBound::new(ScalarValue::from(0.5), false), - IntervalBound::new(ScalarValue::from(1.0), true), - ), - Interval::new( - IntervalBound::new(ScalarValue::from(1.0), false), - IntervalBound::new(ScalarValue::from(2.0), true), - ), - Interval::new( - IntervalBound::new(ScalarValue::from(32.0), false), - IntervalBound::new(ScalarValue::from(64.0), true), - ), - Interval::new( - IntervalBound::new(ScalarValue::from(-0.50), false), - IntervalBound::new(ScalarValue::from(-0.25), true), - ), - Interval::new( - IntervalBound::new(ScalarValue::from(-32.0), false), - IntervalBound::new(ScalarValue::from(-16.0), true), - ), - ]; - for interval in intervals { - assert_eq!(interval.cardinality()?, distinct_f64); - } - - let intervals = [ - Interval::new( - IntervalBound::new(ScalarValue::from(0.25_f32), false), - IntervalBound::new(ScalarValue::from(0.50_f32), true), - ), - Interval::new( - IntervalBound::new(ScalarValue::from(-1_f32), false), - IntervalBound::new(ScalarValue::from(-0.5_f32), true), - ), - ]; - for interval in intervals { - assert_eq!(interval.cardinality()?, distinct_f32); - } - - let interval = Interval::new( - IntervalBound::new(ScalarValue::from(-0.0625), false), - IntervalBound::new(ScalarValue::from(0.0625), true), - ); - assert_eq!(interval.cardinality()?, distinct_f64 * 2_048); - - let interval = Interval::new( - IntervalBound::new(ScalarValue::from(-0.0625_f32), false), - IntervalBound::new(ScalarValue::from(0.0625_f32), true), - ); - assert_eq!(interval.cardinality()?, distinct_f32 * 256); - - Ok(()) - } - - #[test] - fn test_next_value() -> Result<()> { - // integer increment / decrement - let zeros = vec![ - ScalarValue::new_zero(&DataType::UInt8)?, - ScalarValue::new_zero(&DataType::UInt16)?, - ScalarValue::new_zero(&DataType::UInt32)?, - ScalarValue::new_zero(&DataType::UInt64)?, - ScalarValue::new_zero(&DataType::Int8)?, - ScalarValue::new_zero(&DataType::Int8)?, - ScalarValue::new_zero(&DataType::Int8)?, - ScalarValue::new_zero(&DataType::Int8)?, - ]; - - let ones = vec![ - ScalarValue::new_one(&DataType::UInt8)?, - ScalarValue::new_one(&DataType::UInt16)?, - ScalarValue::new_one(&DataType::UInt32)?, - ScalarValue::new_one(&DataType::UInt64)?, - ScalarValue::new_one(&DataType::Int8)?, - ScalarValue::new_one(&DataType::Int8)?, - ScalarValue::new_one(&DataType::Int8)?, - ScalarValue::new_one(&DataType::Int8)?, - ]; - - zeros.into_iter().zip(ones).for_each(|(z, o)| { - assert_eq!(next_value::(z.clone()), o); - assert_eq!(next_value::(o), z); - }); - - // floating value increment / decrement - let values = vec![ - ScalarValue::new_zero(&DataType::Float32)?, - ScalarValue::new_zero(&DataType::Float64)?, - ]; - - let eps = vec![ - ScalarValue::Float32(Some(1e-6)), - ScalarValue::Float64(Some(1e-6)), - ]; - - values.into_iter().zip(eps).for_each(|(v, e)| { - assert!(next_value::(v.clone()).sub(v.clone()).unwrap().lt(&e)); - assert!(v.clone().sub(next_value::(v)).unwrap().lt(&e)); - }); - - // Min / Max values do not change for integer values - let min = vec![ - ScalarValue::UInt64(Some(u64::MIN)), - ScalarValue::Int8(Some(i8::MIN)), - ]; - let max = vec![ - ScalarValue::UInt64(Some(u64::MAX)), - ScalarValue::Int8(Some(i8::MAX)), - ]; - - min.into_iter().zip(max).for_each(|(min, max)| { - assert_eq!(next_value::(max.clone()), max); - assert_eq!(next_value::(min.clone()), min); - }); - - // Min / Max values results in infinity for floating point values - assert_eq!( - next_value::(ScalarValue::Float32(Some(f32::MAX))), - ScalarValue::Float32(Some(f32::INFINITY)) - ); - assert_eq!( - next_value::(ScalarValue::Float64(Some(f64::MIN))), - ScalarValue::Float64(Some(f64::NEG_INFINITY)) - ); - - Ok(()) - } - - #[test] - fn test_interval_display() { - let interval = Interval::new( - IntervalBound::new(ScalarValue::from(0.25_f32), true), - IntervalBound::new(ScalarValue::from(0.50_f32), false), - ); - assert_eq!(format!("{}", interval), "(0.25, 0.5]"); - - let interval = Interval::new( - IntervalBound::new(ScalarValue::from(0.25_f32), false), - IntervalBound::new(ScalarValue::from(0.50_f32), true), - ); - assert_eq!(format!("{}", interval), "[0.25, 0.5)"); - - let interval = Interval::new( - IntervalBound::new(ScalarValue::from(0.25_f32), true), - IntervalBound::new(ScalarValue::from(0.50_f32), true), - ); - assert_eq!(format!("{}", interval), "(0.25, 0.5)"); - - let interval = Interval::new( - IntervalBound::new(ScalarValue::from(0.25_f32), false), - IntervalBound::new(ScalarValue::from(0.50_f32), false), - ); - assert_eq!(format!("{}", interval), "[0.25, 0.5]"); - } -} diff --git a/datafusion/physical-expr/src/intervals/mod.rs b/datafusion/physical-expr/src/intervals/mod.rs index b89d1c59dc64..9752ca27b5a3 100644 --- a/datafusion/physical-expr/src/intervals/mod.rs +++ b/datafusion/physical-expr/src/intervals/mod.rs @@ -18,10 +18,5 @@ //! Interval arithmetic and constraint propagation library pub mod cp_solver; -pub mod interval_aritmetic; -pub mod rounding; pub mod test_utils; pub mod utils; - -pub use cp_solver::ExprIntervalGraph; -pub use interval_aritmetic::*; diff --git a/datafusion/physical-expr/src/intervals/utils.rs b/datafusion/physical-expr/src/intervals/utils.rs index be3b17771303..03d13632104d 100644 --- a/datafusion/physical-expr/src/intervals/utils.rs +++ b/datafusion/physical-expr/src/intervals/utils.rs @@ -19,14 +19,16 @@ use std::sync::Arc; -use super::{Interval, IntervalBound}; use crate::{ - expressions::{BinaryExpr, CastExpr, Column, Literal}, + expressions::{BinaryExpr, CastExpr, Column, Literal, NegativeExpr}, PhysicalExpr, }; -use arrow_schema::DataType; -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use arrow_schema::{DataType, SchemaRef}; +use datafusion_common::{ + internal_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, +}; +use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::Operator; const MDN_DAY_MASK: i128 = 0xFFFF_FFFF_0000_0000_0000_0000; @@ -37,24 +39,42 @@ const DT_MS_MASK: i64 = 0xFFFF_FFFF; /// Currently, we do not support all [`PhysicalExpr`]s for interval calculations. /// We do not support every type of [`Operator`]s either. Over time, this check /// will relax as more types of `PhysicalExpr`s and `Operator`s are supported. -/// Currently, [`CastExpr`], [`BinaryExpr`], [`Column`] and [`Literal`] are supported. -pub fn check_support(expr: &Arc) -> bool { +/// Currently, [`CastExpr`], [`NegativeExpr`], [`BinaryExpr`], [`Column`] and [`Literal`] are supported. +pub fn check_support(expr: &Arc, schema: &SchemaRef) -> bool { let expr_any = expr.as_any(); - let expr_supported = if let Some(binary_expr) = expr_any.downcast_ref::() - { + if let Some(binary_expr) = expr_any.downcast_ref::() { is_operator_supported(binary_expr.op()) + && check_support(binary_expr.left(), schema) + && check_support(binary_expr.right(), schema) + } else if let Some(column) = expr_any.downcast_ref::() { + if let Ok(field) = schema.field_with_name(column.name()) { + is_datatype_supported(field.data_type()) + } else { + return false; + } + } else if let Some(literal) = expr_any.downcast_ref::() { + if let Ok(dt) = literal.data_type(schema) { + is_datatype_supported(&dt) + } else { + return false; + } + } else if let Some(cast) = expr_any.downcast_ref::() { + check_support(cast.expr(), schema) + } else if let Some(negative) = expr_any.downcast_ref::() { + check_support(negative.arg(), schema) } else { - expr_any.is::() || expr_any.is::() || expr_any.is::() - }; - expr_supported && expr.children().iter().all(check_support) + false + } } // This function returns the inverse operator of the given operator. -pub fn get_inverse_op(op: Operator) -> Operator { +pub fn get_inverse_op(op: Operator) -> Result { match op { - Operator::Plus => Operator::Minus, - Operator::Minus => Operator::Plus, - _ => unreachable!(), + Operator::Plus => Ok(Operator::Minus), + Operator::Minus => Ok(Operator::Plus), + Operator::Multiply => Ok(Operator::Divide), + Operator::Divide => Ok(Operator::Multiply), + _ => internal_err!("Interval arithmetic does not support the operator {}", op), } } @@ -70,6 +90,8 @@ pub fn is_operator_supported(op: &Operator) -> bool { | &Operator::Lt | &Operator::LtEq | &Operator::Eq + | &Operator::Multiply + | &Operator::Divide ) } @@ -93,36 +115,26 @@ pub fn is_datatype_supported(data_type: &DataType) -> bool { /// Converts an [`Interval`] of time intervals to one of `Duration`s, if applicable. Otherwise, returns [`None`]. pub fn convert_interval_type_to_duration(interval: &Interval) -> Option { if let (Some(lower), Some(upper)) = ( - convert_interval_bound_to_duration(&interval.lower), - convert_interval_bound_to_duration(&interval.upper), + convert_interval_bound_to_duration(interval.lower()), + convert_interval_bound_to_duration(interval.upper()), ) { - Some(Interval::new(lower, upper)) + Interval::try_new(lower, upper).ok() } else { None } } -/// Converts an [`IntervalBound`] containing a time interval to one containing a `Duration`, if applicable. Otherwise, returns [`None`]. +/// Converts an [`ScalarValue`] containing a time interval to one containing a `Duration`, if applicable. Otherwise, returns [`None`]. fn convert_interval_bound_to_duration( - interval_bound: &IntervalBound, -) -> Option { - match interval_bound.value { - ScalarValue::IntervalMonthDayNano(Some(mdn)) => { - interval_mdn_to_duration_ns(&mdn).ok().map(|duration| { - IntervalBound::new( - ScalarValue::DurationNanosecond(Some(duration)), - interval_bound.open, - ) - }) - } - ScalarValue::IntervalDayTime(Some(dt)) => { - interval_dt_to_duration_ms(&dt).ok().map(|duration| { - IntervalBound::new( - ScalarValue::DurationMillisecond(Some(duration)), - interval_bound.open, - ) - }) - } + interval_bound: &ScalarValue, +) -> Option { + match interval_bound { + ScalarValue::IntervalMonthDayNano(Some(mdn)) => interval_mdn_to_duration_ns(mdn) + .ok() + .map(|duration| ScalarValue::DurationNanosecond(Some(duration))), + ScalarValue::IntervalDayTime(Some(dt)) => interval_dt_to_duration_ms(dt) + .ok() + .map(|duration| ScalarValue::DurationMillisecond(Some(duration))), _ => None, } } @@ -130,28 +142,32 @@ fn convert_interval_bound_to_duration( /// Converts an [`Interval`] of `Duration`s to one of time intervals, if applicable. Otherwise, returns [`None`]. pub fn convert_duration_type_to_interval(interval: &Interval) -> Option { if let (Some(lower), Some(upper)) = ( - convert_duration_bound_to_interval(&interval.lower), - convert_duration_bound_to_interval(&interval.upper), + convert_duration_bound_to_interval(interval.lower()), + convert_duration_bound_to_interval(interval.upper()), ) { - Some(Interval::new(lower, upper)) + Interval::try_new(lower, upper).ok() } else { None } } -/// Converts an [`IntervalBound`] containing a `Duration` to one containing a time interval, if applicable. Otherwise, returns [`None`]. +/// Converts a [`ScalarValue`] containing a `Duration` to one containing a time interval, if applicable. Otherwise, returns [`None`]. fn convert_duration_bound_to_interval( - interval_bound: &IntervalBound, -) -> Option { - match interval_bound.value { - ScalarValue::DurationNanosecond(Some(duration)) => Some(IntervalBound::new( - ScalarValue::new_interval_mdn(0, 0, duration), - interval_bound.open, - )), - ScalarValue::DurationMillisecond(Some(duration)) => Some(IntervalBound::new( - ScalarValue::new_interval_dt(0, duration as i32), - interval_bound.open, - )), + interval_bound: &ScalarValue, +) -> Option { + match interval_bound { + ScalarValue::DurationNanosecond(Some(duration)) => { + Some(ScalarValue::new_interval_mdn(0, 0, *duration)) + } + ScalarValue::DurationMicrosecond(Some(duration)) => { + Some(ScalarValue::new_interval_mdn(0, 0, *duration * 1000)) + } + ScalarValue::DurationMillisecond(Some(duration)) => { + Some(ScalarValue::new_interval_dt(0, *duration as i32)) + } + ScalarValue::DurationSecond(Some(duration)) => { + Some(ScalarValue::new_interval_dt(0, *duration as i32 * 1000)) + } _ => None, } } @@ -164,14 +180,13 @@ fn interval_mdn_to_duration_ns(mdn: &i128) -> Result { let nanoseconds = mdn & MDN_NS_MASK; if months == 0 && days == 0 { - nanoseconds.try_into().map_err(|_| { - DataFusionError::Internal("Resulting duration exceeds i64::MAX".to_string()) - }) + nanoseconds + .try_into() + .map_err(|_| internal_datafusion_err!("Resulting duration exceeds i64::MAX")) } else { - Err(DataFusionError::Internal( + internal_err!( "The interval cannot have a non-zero month or day value for duration convertibility" - .to_string(), - )) + ) } } @@ -184,9 +199,8 @@ fn interval_dt_to_duration_ms(dt: &i64) -> Result { if days == 0 { Ok(milliseconds) } else { - Err(DataFusionError::Internal( + internal_err!( "The interval cannot have a non-zero day value for duration convertibility" - .to_string(), - )) + ) } } diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index e83dee2e6c80..fffa8f602d87 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -28,7 +28,6 @@ pub mod equivalence; pub mod execution_props; pub mod expressions; pub mod functions; -pub mod hash_utils; pub mod intervals; pub mod math_expressions; mod partitioning; @@ -54,23 +53,16 @@ pub use aggregate::groups_accumulator::{ }; pub use aggregate::AggregateExpr; pub use analysis::{analyze, AnalysisContext, ExprBoundaries}; -pub use equivalence::{ - add_offset_to_lex_ordering, ordering_equivalence_properties_helper, - project_equivalence_properties, project_ordering_equivalence_properties, - EquivalenceProperties, EquivalentClass, OrderingEquivalenceProperties, - OrderingEquivalentClass, -}; - +pub use equivalence::EquivalenceProperties; pub use partitioning::{Distribution, Partitioning}; -pub use physical_expr::{physical_exprs_contains, PhysicalExpr, PhysicalExprRef}; +pub use physical_expr::{ + physical_exprs_bag_equal, physical_exprs_contains, physical_exprs_equal, + PhysicalExpr, PhysicalExprRef, +}; pub use planner::create_physical_expr; pub use scalar_function::ScalarFunctionExpr; pub use sort_expr::{ - LexOrdering, LexOrderingRef, LexOrderingReq, PhysicalSortExpr, + LexOrdering, LexOrderingRef, LexRequirement, LexRequirementRef, PhysicalSortExpr, PhysicalSortRequirement, }; -pub use sort_properties::update_ordering; -pub use utils::{ - expr_list_eq_any_order, expr_list_eq_strict_order, - normalize_out_expr_with_columns_map, reverse_order_bys, split_conjunction, -}; +pub use utils::{reverse_order_bys, split_conjunction}; diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs index a9dc0bd58f15..af66862aecc5 100644 --- a/datafusion/physical-expr/src/math_expressions.rs +++ b/datafusion/physical-expr/src/math_expressions.rs @@ -19,8 +19,8 @@ use arrow::array::ArrayRef; use arrow::array::{ - BooleanArray, Decimal128Array, Float32Array, Float64Array, Int16Array, Int32Array, - Int64Array, Int8Array, + BooleanArray, Decimal128Array, Decimal256Array, Float32Array, Float64Array, + Int16Array, Int32Array, Int64Array, Int8Array, }; use arrow::datatypes::DataType; use arrow::error::ArrowError; @@ -701,6 +701,18 @@ macro_rules! make_try_abs_function { }}; } +macro_rules! make_decimal_abs_function { + ($ARRAY_TYPE:ident) => {{ + |args: &[ArrayRef]| { + let array = downcast_arg!(&args[0], "abs arg", $ARRAY_TYPE); + let res: $ARRAY_TYPE = array + .unary(|x| x.wrapping_abs()) + .with_data_type(args[0].data_type().clone()); + Ok(Arc::new(res) as ArrayRef) + } + }}; +} + /// Abs SQL function /// Return different implementations based on input datatype to reduce branches during execution pub(super) fn create_abs_function( @@ -723,20 +735,26 @@ pub(super) fn create_abs_function( | DataType::UInt32 | DataType::UInt64 => Ok(|args: &[ArrayRef]| Ok(args[0].clone())), - // Decimal should keep the same precision and scale by using `with_data_type()`. - // https://github.com/apache/arrow-rs/issues/4644 - DataType::Decimal128(_, _) => Ok(|args: &[ArrayRef]| { - let array = downcast_arg!(&args[0], "abs arg", Decimal128Array); - let res: Decimal128Array = array - .unary(i128::abs) - .with_data_type(args[0].data_type().clone()); - Ok(Arc::new(res) as ArrayRef) - }), + // Decimal types + DataType::Decimal128(_, _) => Ok(make_decimal_abs_function!(Decimal128Array)), + DataType::Decimal256(_, _) => Ok(make_decimal_abs_function!(Decimal256Array)), other => not_impl_err!("Unsupported data type {other:?} for function abs"), } } +/// abs() SQL function implementation +pub fn abs_invoke(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return internal_err!("abs function requires 1 argument, got {}", args.len()); + } + + let input_data_type = args[0].data_type(); + let abs_fun = create_abs_function(input_data_type)?; + + abs_fun(args) +} + #[cfg(test)] mod tests { @@ -751,7 +769,8 @@ mod tests { let args = vec![ColumnarValue::Array(Arc::new(NullArray::new(1)))]; let array = random(&args) .expect("failed to initialize function random") - .into_array(1); + .into_array(1) + .expect("Failed to convert to array"); let floats = as_float64_array(&array).expect("failed to initialize function random"); diff --git a/datafusion/physical-expr/src/partitioning.rs b/datafusion/physical-expr/src/partitioning.rs index 773eac40dc8a..301f12e9aa2e 100644 --- a/datafusion/physical-expr/src/partitioning.rs +++ b/datafusion/physical-expr/src/partitioning.rs @@ -15,14 +15,95 @@ // specific language governing permissions and limitations // under the License. -//! [`Partitioning`] and [`Distribution`] for physical expressions +//! [`Partitioning`] and [`Distribution`] for `ExecutionPlans` use std::fmt; use std::sync::Arc; -use crate::{expr_list_eq_strict_order, EquivalenceProperties, PhysicalExpr}; +use crate::{physical_exprs_equal, EquivalenceProperties, PhysicalExpr}; -/// Partitioning schemes supported by operators. +/// Output partitioning supported by [`ExecutionPlan`]s. +/// +/// When `executed`, `ExecutionPlan`s produce one or more independent stream of +/// data batches in parallel, referred to as partitions. The streams are Rust +/// `async` [`Stream`]s (a special kind of future). The number of output +/// partitions varies based on the input and the operation performed. +/// +/// For example, an `ExecutionPlan` that has output partitioning of 3 will +/// produce 3 distinct output streams as the result of calling +/// `ExecutionPlan::execute(0)`, `ExecutionPlan::execute(1)`, and +/// `ExecutionPlan::execute(2)`, as shown below: +/// +/// ```text +/// ... ... ... +/// ... ▲ ▲ ▲ +/// │ │ │ +/// ▲ │ │ │ +/// │ │ │ │ +/// │ ┌───┴────┐ ┌───┴────┐ ┌───┴────┐ +/// ┌────────────────────┐ │ Stream │ │ Stream │ │ Stream │ +/// │ ExecutionPlan │ │ (0) │ │ (1) │ │ (2) │ +/// └────────────────────┘ └────────┘ └────────┘ └────────┘ +/// ▲ ▲ ▲ ▲ +/// │ │ │ │ +/// ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ │ │ │ +/// Input │ │ │ │ +/// └ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ │ │ │ +/// ▲ ┌ ─ ─ ─ ─ ┌ ─ ─ ─ ─ ┌ ─ ─ ─ ─ +/// │ Input │ Input │ Input │ +/// │ │ Stream │ Stream │ Stream +/// (0) │ (1) │ (2) │ +/// ... └ ─ ▲ ─ ─ └ ─ ▲ ─ ─ └ ─ ▲ ─ ─ +/// │ │ │ +/// │ │ │ +/// │ │ │ +/// +/// ExecutionPlan with 1 input 3 (async) streams, one for each +/// that has 3 partitions, which itself output partition +/// has 3 output partitions +/// ``` +/// +/// It is common (but not required) that an `ExecutionPlan` has the same number +/// of input partitions as output partitions. However, some plans have different +/// numbers such as the `RepartitionExec` that redistributes batches from some +/// number of inputs to some number of outputs +/// +/// ```text +/// ... ... ... ... +/// +/// ▲ ▲ ▲ +/// ▲ │ │ │ +/// │ │ │ │ +/// ┌────────┴───────────┐ │ │ │ +/// │ RepartitionExec │ ┌────┴───┐ ┌────┴───┐ ┌────┴───┐ +/// └────────────────────┘ │ Stream │ │ Stream │ │ Stream │ +/// ▲ │ (0) │ │ (1) │ │ (2) │ +/// │ └────────┘ └────────┘ └────────┘ +/// │ ▲ ▲ ▲ +/// ... │ │ │ +/// └──────────┐│┌──────────┘ +/// │││ +/// │││ +/// RepartitionExec with one input +/// that has 3 partitions, but 3 (async) streams, that internally +/// itself has only 1 output partition pull from the same input stream +/// ... +/// ``` +/// +/// # Additional Examples +/// +/// A simple `FileScanExec` might produce one output stream (partition) for each +/// file (note the actual DataFusion file scaners can read individual files in +/// parallel, potentially producing multiple partitions per file) +/// +/// Plans such as `SortPreservingMerge` produce a single output stream +/// (1 output partition) by combining some number of input streams (input partitions) +/// +/// Plans such as `FilterExec` produce the same number of output streams +/// (partitions) as input streams (partitions). +/// +/// [`ExecutionPlan`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.ExecutionPlan.html +/// [`Stream`]: https://docs.rs/futures/latest/futures/stream/trait.Stream.html #[derive(Debug, Clone)] pub enum Partitioning { /// Allocate batches using a round-robin algorithm and the specified number of partitions @@ -66,7 +147,7 @@ impl Partitioning { pub fn satisfy EquivalenceProperties>( &self, required: Distribution, - equal_properties: F, + eq_properties: F, ) -> bool { match required { Distribution::UnspecifiedDistribution => true, @@ -78,31 +159,28 @@ impl Partitioning { // then we need to have the partition count and hash functions validation. Partitioning::Hash(partition_exprs, _) => { let fast_match = - expr_list_eq_strict_order(&required_exprs, partition_exprs); + physical_exprs_equal(&required_exprs, partition_exprs); // If the required exprs do not match, need to leverage the eq_properties provided by the child - // and normalize both exprs based on the eq_properties + // and normalize both exprs based on the equivalent groups. if !fast_match { - let eq_properties = equal_properties(); - let eq_classes = eq_properties.classes(); - if !eq_classes.is_empty() { + let eq_properties = eq_properties(); + let eq_groups = eq_properties.eq_group(); + if !eq_groups.is_empty() { let normalized_required_exprs = required_exprs .iter() - .map(|e| eq_properties.normalize_expr(e.clone())) + .map(|e| eq_groups.normalize_expr(e.clone())) .collect::>(); let normalized_partition_exprs = partition_exprs .iter() - .map(|e| eq_properties.normalize_expr(e.clone())) + .map(|e| eq_groups.normalize_expr(e.clone())) .collect::>(); - expr_list_eq_strict_order( + return physical_exprs_equal( &normalized_required_exprs, &normalized_partition_exprs, - ) - } else { - fast_match + ); } - } else { - fast_match } + fast_match } _ => false, } @@ -120,7 +198,7 @@ impl PartialEq for Partitioning { Partitioning::RoundRobinBatch(count2), ) if count1 == count2 => true, (Partitioning::Hash(exprs1, count1), Partitioning::Hash(exprs2, count2)) - if expr_list_eq_strict_order(exprs1, exprs2) && (count1 == count2) => + if physical_exprs_equal(exprs1, exprs2) && (count1 == count2) => { true } @@ -129,7 +207,8 @@ impl PartialEq for Partitioning { } } -/// Distribution schemes +/// How data is distributed amongst partitions. See [`Partitioning`] for more +/// details. #[derive(Debug, Clone)] pub enum Distribution { /// Unspecified distribution @@ -142,7 +221,7 @@ pub enum Distribution { } impl Distribution { - /// Creates a Partitioning for this Distribution to satisfy itself + /// Creates a `Partitioning` that satisfies this `Distribution` pub fn create_partitioning(&self, partition_count: usize) -> Partitioning { match self { Distribution::UnspecifiedDistribution => { @@ -158,15 +237,13 @@ impl Distribution { #[cfg(test)] mod tests { - use crate::expressions::Column; + use std::sync::Arc; use super::*; - use arrow::datatypes::DataType; - use arrow::datatypes::Field; - use arrow::datatypes::Schema; - use datafusion_common::Result; + use crate::expressions::Column; - use std::sync::Arc; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::Result; #[test] fn partitioning_satisfy_distribution() -> Result<()> { diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs index 81702d8bfae0..a8d1e3638a17 100644 --- a/datafusion/physical-expr/src/physical_expr.rs +++ b/datafusion/physical-expr/src/physical_expr.rs @@ -15,7 +15,11 @@ // specific language governing permissions and limitations // under the License. -use crate::intervals::Interval; +use std::any::Any; +use std::fmt::{Debug, Display}; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + use crate::sort_properties::SortProperties; use crate::utils::scatter; @@ -25,17 +29,15 @@ use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::utils::DataPtr; use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; +use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::ColumnarValue; -use std::any::Any; -use std::fmt::{Debug, Display}; -use std::hash::{Hash, Hasher}; -use std::sync::Arc; +use itertools::izip; /// Expression that can be evaluated against a RecordBatch /// A Physical expression knows its type, nullability and how to evaluate itself. pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq { - /// Returns the physical expression as [`Any`](std::any::Any) so that it can be + /// Returns the physical expression as [`Any`] so that it can be /// downcast to a specific implementation. fn as_any(&self) -> &dyn Any; /// Get the data type of this expression, given the schema of the input @@ -54,13 +56,12 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq { let tmp_batch = filter_record_batch(batch, selection)?; let tmp_result = self.evaluate(&tmp_batch)?; - // All values from the `selection` filter are true. + if batch.num_rows() == tmp_batch.num_rows() { - return Ok(tmp_result); - } - if let ColumnarValue::Array(a) = tmp_result { - let result = scatter(selection, a.as_ref())?; - Ok(ColumnarValue::Array(result)) + // All values from the `selection` filter are true. + Ok(tmp_result) + } else if let ColumnarValue::Array(a) = tmp_result { + scatter(selection, a.as_ref()).map(ColumnarValue::Array) } else { Ok(tmp_result) } @@ -75,21 +76,53 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq { children: Vec>, ) -> Result>; - /// Computes bounds for the expression using interval arithmetic. + /// Computes the output interval for the expression, given the input + /// intervals. + /// + /// # Arguments + /// + /// * `children` are the intervals for the children (inputs) of this + /// expression. + /// + /// # Example + /// + /// If the expression is `a + b`, and the input intervals are `a: [1, 2]` + /// and `b: [3, 4]`, then the output interval would be `[4, 6]`. fn evaluate_bounds(&self, _children: &[&Interval]) -> Result { not_impl_err!("Not implemented for {self}") } - /// Updates/shrinks bounds for the expression using interval arithmetic. - /// If constraint propagation reveals an infeasibility, returns [None] for - /// the child causing infeasibility. If none of the children intervals - /// change, may return an empty vector instead of cloning `children`. + /// Updates bounds for child expressions, given a known interval for this + /// expression. + /// + /// This is used to propagate constraints down through an expression tree. + /// + /// # Arguments + /// + /// * `interval` is the currently known interval for this expression. + /// * `children` are the current intervals for the children of this expression. + /// + /// # Returns + /// + /// A `Vec` of new intervals for the children, in order. + /// + /// If constraint propagation reveals an infeasibility for any child, returns + /// [`None`]. If none of the children intervals change as a result of propagation, + /// may return an empty vector instead of cloning `children`. This is the default + /// (and conservative) return value. + /// + /// # Example + /// + /// If the expression is `a + b`, the current `interval` is `[4, 5]` and the + /// inputs `a` and `b` are respectively given as `[0, 2]` and `[-∞, 4]`, then + /// propagation would would return `[0, 2]` and `[2, 4]` as `b` must be at + /// least `2` to make the output at least `4`. fn propagate_constraints( &self, _interval: &Interval, _children: &[&Interval], - ) -> Result>> { - not_impl_err!("Not implemented for {self}") + ) -> Result>> { + Ok(Some(vec![])) } /// Update the hash `state` with this expression requirements from @@ -182,8 +215,8 @@ pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any { } } -/// It is similar to contains method of vector. -/// Finds whether `expr` is among `physical_exprs`. +/// This function is similar to the `contains` method of `Vec`. It finds +/// whether `expr` is among `physical_exprs`. pub fn physical_exprs_contains( physical_exprs: &[Arc], expr: &Arc, @@ -192,3 +225,216 @@ pub fn physical_exprs_contains( .iter() .any(|physical_expr| physical_expr.eq(expr)) } + +/// Checks whether the given physical expression slices are equal. +pub fn physical_exprs_equal( + lhs: &[Arc], + rhs: &[Arc], +) -> bool { + lhs.len() == rhs.len() && izip!(lhs, rhs).all(|(lhs, rhs)| lhs.eq(rhs)) +} + +/// Checks whether the given physical expression slices are equal in the sense +/// of bags (multi-sets), disregarding their orderings. +pub fn physical_exprs_bag_equal( + lhs: &[Arc], + rhs: &[Arc], +) -> bool { + // TODO: Once we can use `HashMap`s with `Arc`, this + // function should use a `HashMap` to reduce computational complexity. + if lhs.len() == rhs.len() { + let mut rhs_vec = rhs.to_vec(); + for expr in lhs { + if let Some(idx) = rhs_vec.iter().position(|e| expr.eq(e)) { + rhs_vec.swap_remove(idx); + } else { + return false; + } + } + true + } else { + false + } +} + +/// This utility function removes duplicates from the given `exprs` vector. +/// Note that this function does not necessarily preserve its input ordering. +pub fn deduplicate_physical_exprs(exprs: &mut Vec>) { + // TODO: Once we can use `HashSet`s with `Arc`, this + // function should use a `HashSet` to reduce computational complexity. + // See issue: https://github.com/apache/arrow-datafusion/issues/8027 + let mut idx = 0; + while idx < exprs.len() { + let mut rest_idx = idx + 1; + while rest_idx < exprs.len() { + if exprs[idx].eq(&exprs[rest_idx]) { + exprs.swap_remove(rest_idx); + } else { + rest_idx += 1; + } + } + idx += 1; + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::expressions::{Column, Literal}; + use crate::physical_expr::{ + deduplicate_physical_exprs, physical_exprs_bag_equal, physical_exprs_contains, + physical_exprs_equal, PhysicalExpr, + }; + + use datafusion_common::ScalarValue; + + #[test] + fn test_physical_exprs_contains() { + let lit_true = Arc::new(Literal::new(ScalarValue::Boolean(Some(true)))) + as Arc; + let lit_false = Arc::new(Literal::new(ScalarValue::Boolean(Some(false)))) + as Arc; + let lit4 = + Arc::new(Literal::new(ScalarValue::Int32(Some(4)))) as Arc; + let lit2 = + Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc; + let lit1 = + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc; + let col_a_expr = Arc::new(Column::new("a", 0)) as Arc; + let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; + let col_c_expr = Arc::new(Column::new("c", 2)) as Arc; + + // lit(true), lit(false), lit(4), lit(2), Col(a), Col(b) + let physical_exprs: Vec> = vec![ + lit_true.clone(), + lit_false.clone(), + lit4.clone(), + lit2.clone(), + col_a_expr.clone(), + col_b_expr.clone(), + ]; + // below expressions are inside physical_exprs + assert!(physical_exprs_contains(&physical_exprs, &lit_true)); + assert!(physical_exprs_contains(&physical_exprs, &lit2)); + assert!(physical_exprs_contains(&physical_exprs, &col_b_expr)); + + // below expressions are not inside physical_exprs + assert!(!physical_exprs_contains(&physical_exprs, &col_c_expr)); + assert!(!physical_exprs_contains(&physical_exprs, &lit1)); + } + + #[test] + fn test_physical_exprs_equal() { + let lit_true = Arc::new(Literal::new(ScalarValue::Boolean(Some(true)))) + as Arc; + let lit_false = Arc::new(Literal::new(ScalarValue::Boolean(Some(false)))) + as Arc; + let lit1 = + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc; + let lit2 = + Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc; + let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; + + let vec1 = vec![lit_true.clone(), lit_false.clone()]; + let vec2 = vec![lit_true.clone(), col_b_expr.clone()]; + let vec3 = vec![lit2.clone(), lit1.clone()]; + let vec4 = vec![lit_true.clone(), lit_false.clone()]; + + // these vectors are same + assert!(physical_exprs_equal(&vec1, &vec1)); + assert!(physical_exprs_equal(&vec1, &vec4)); + assert!(physical_exprs_bag_equal(&vec1, &vec1)); + assert!(physical_exprs_bag_equal(&vec1, &vec4)); + + // these vectors are different + assert!(!physical_exprs_equal(&vec1, &vec2)); + assert!(!physical_exprs_equal(&vec1, &vec3)); + assert!(!physical_exprs_bag_equal(&vec1, &vec2)); + assert!(!physical_exprs_bag_equal(&vec1, &vec3)); + } + + #[test] + fn test_physical_exprs_set_equal() { + let list1: Vec> = vec![ + Arc::new(Column::new("a", 0)), + Arc::new(Column::new("a", 0)), + Arc::new(Column::new("b", 1)), + ]; + let list2: Vec> = vec![ + Arc::new(Column::new("b", 1)), + Arc::new(Column::new("b", 1)), + Arc::new(Column::new("a", 0)), + ]; + assert!(!physical_exprs_bag_equal( + list1.as_slice(), + list2.as_slice() + )); + assert!(!physical_exprs_bag_equal( + list2.as_slice(), + list1.as_slice() + )); + assert!(!physical_exprs_equal(list1.as_slice(), list2.as_slice())); + assert!(!physical_exprs_equal(list2.as_slice(), list1.as_slice())); + + let list3: Vec> = vec![ + Arc::new(Column::new("a", 0)), + Arc::new(Column::new("b", 1)), + Arc::new(Column::new("c", 2)), + Arc::new(Column::new("a", 0)), + Arc::new(Column::new("b", 1)), + ]; + let list4: Vec> = vec![ + Arc::new(Column::new("b", 1)), + Arc::new(Column::new("b", 1)), + Arc::new(Column::new("a", 0)), + Arc::new(Column::new("c", 2)), + Arc::new(Column::new("a", 0)), + ]; + assert!(physical_exprs_bag_equal(list3.as_slice(), list4.as_slice())); + assert!(physical_exprs_bag_equal(list4.as_slice(), list3.as_slice())); + assert!(physical_exprs_bag_equal(list3.as_slice(), list3.as_slice())); + assert!(physical_exprs_bag_equal(list4.as_slice(), list4.as_slice())); + assert!(!physical_exprs_equal(list3.as_slice(), list4.as_slice())); + assert!(!physical_exprs_equal(list4.as_slice(), list3.as_slice())); + assert!(physical_exprs_bag_equal(list3.as_slice(), list3.as_slice())); + assert!(physical_exprs_bag_equal(list4.as_slice(), list4.as_slice())); + } + + #[test] + fn test_deduplicate_physical_exprs() { + let lit_true = &(Arc::new(Literal::new(ScalarValue::Boolean(Some(true)))) + as Arc); + let lit_false = &(Arc::new(Literal::new(ScalarValue::Boolean(Some(false)))) + as Arc); + let lit4 = &(Arc::new(Literal::new(ScalarValue::Int32(Some(4)))) + as Arc); + let lit2 = &(Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) + as Arc); + let col_a_expr = &(Arc::new(Column::new("a", 0)) as Arc); + let col_b_expr = &(Arc::new(Column::new("b", 1)) as Arc); + + // First vector in the tuple is arguments, second one is the expected value. + let test_cases = vec![ + // ---------- TEST CASE 1----------// + ( + vec![ + lit_true, lit_false, lit4, lit2, col_a_expr, col_a_expr, col_b_expr, + lit_true, lit2, + ], + vec![lit_true, lit_false, lit4, lit2, col_a_expr, col_b_expr], + ), + // ---------- TEST CASE 2----------// + ( + vec![lit_true, lit_true, lit_false, lit4], + vec![lit_true, lit4, lit_false], + ), + ]; + for (exprs, expected) in test_cases { + let mut exprs = exprs.into_iter().cloned().collect::>(); + let expected = expected.into_iter().cloned().collect::>(); + deduplicate_physical_exprs(&mut exprs); + assert!(physical_exprs_equal(&exprs, &expected)); + } + } +} diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 9a74c2ca64d1..9c212cb81f6b 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -29,10 +29,10 @@ use datafusion_common::{ exec_err, internal_err, not_impl_err, plan_err, DFSchema, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::expr::{Alias, Cast, InList, ScalarFunction, ScalarUDF}; +use datafusion_expr::expr::{Alias, Cast, InList, ScalarFunction}; use datafusion_expr::{ binary_expr, Between, BinaryExpr, Expr, GetFieldAccess, GetIndexedField, Like, - Operator, TryCast, + Operator, ScalarFunctionDefinition, TryCast, }; use std::sync::Arc; @@ -348,35 +348,37 @@ pub fn create_physical_expr( ))) } - Expr::ScalarFunction(ScalarFunction { fun, args }) => { - let physical_args = args + Expr::ScalarFunction(ScalarFunction { func_def, args }) => { + let mut physical_args = args .iter() .map(|e| { create_physical_expr(e, input_dfschema, input_schema, execution_props) }) .collect::>>()?; - functions::create_physical_expr( - fun, - &physical_args, - input_schema, - execution_props, - ) - } - Expr::ScalarUDF(ScalarUDF { fun, args }) => { - let mut physical_args = vec![]; - for e in args { - physical_args.push(create_physical_expr( - e, - input_dfschema, - input_schema, - execution_props, - )?); - } - // udfs with zero params expect null array as input - if args.is_empty() { - physical_args.push(Arc::new(Literal::new(ScalarValue::Null))); + match func_def { + ScalarFunctionDefinition::BuiltIn(fun) => { + functions::create_physical_expr( + fun, + &physical_args, + input_schema, + execution_props, + ) + } + ScalarFunctionDefinition::UDF(fun) => { + // udfs with zero params expect null array as input + if args.is_empty() { + physical_args.push(Arc::new(Literal::new(ScalarValue::Null))); + } + udf::create_physical_expr( + fun.clone().as_ref(), + &physical_args, + input_schema, + ) + } + ScalarFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") + } } - udf::create_physical_expr(fun.clone().as_ref(), &physical_args, input_schema) } Expr::Between(Between { expr, @@ -448,3 +450,37 @@ pub fn create_physical_expr( } } } + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::{ArrayRef, BooleanArray, RecordBatch, StringArray}; + use arrow_schema::{DataType, Field, Schema}; + use datafusion_common::{DFSchema, Result}; + use datafusion_expr::{col, left, Literal}; + + #[test] + fn test_create_physical_expr_scalar_input_output() -> Result<()> { + let expr = col("letter").eq(left("APACHE".lit(), 1i64.lit())); + + let schema = Schema::new(vec![Field::new("letter", DataType::Utf8, false)]); + let df_schema = DFSchema::try_from_qualified_schema("data", &schema)?; + let p = create_physical_expr(&expr, &df_schema, &schema, &ExecutionProps::new())?; + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(StringArray::from_iter_values(vec![ + "A", "B", "C", "D", + ]))], + )?; + let result = p.evaluate(&batch)?; + let result = result.into_array(4).expect("Failed to convert to array"); + + assert_eq!( + &result, + &(Arc::new(BooleanArray::from(vec![true, false, false, false,])) as ArrayRef) + ); + + Ok(()) + } +} diff --git a/datafusion/physical-expr/src/regex_expressions.rs b/datafusion/physical-expr/src/regex_expressions.rs index 41cd01949595..b778fd86c24b 100644 --- a/datafusion/physical-expr/src/regex_expressions.rs +++ b/datafusion/physical-expr/src/regex_expressions.rs @@ -25,8 +25,9 @@ use arrow::array::{ new_null_array, Array, ArrayDataBuilder, ArrayRef, BufferBuilder, GenericStringArray, OffsetSizeTrait, }; -use arrow::compute; -use datafusion_common::plan_err; +use arrow_array::builder::{GenericStringBuilder, ListBuilder}; +use arrow_schema::ArrowError; +use datafusion_common::{arrow_datafusion_err, plan_err}; use datafusion_common::{ cast::as_generic_string_array, internal_err, DataFusionError, Result, }; @@ -58,7 +59,7 @@ pub fn regexp_match(args: &[ArrayRef]) -> Result { 2 => { let values = as_generic_string_array::(&args[0])?; let regex = as_generic_string_array::(&args[1])?; - compute::regexp_match(values, regex, None).map_err(DataFusionError::ArrowError) + _regexp_match(values, regex, None).map_err(|e| arrow_datafusion_err!(e)) } 3 => { let values = as_generic_string_array::(&args[0])?; @@ -69,7 +70,7 @@ pub fn regexp_match(args: &[ArrayRef]) -> Result { Some(f) if f.iter().any(|s| s == Some("g")) => { plan_err!("regexp_match() does not support the \"global\" option") }, - _ => compute::regexp_match(values, regex, flags).map_err(DataFusionError::ArrowError), + _ => _regexp_match(values, regex, flags).map_err(|e| arrow_datafusion_err!(e)), } } other => internal_err!( @@ -78,6 +79,83 @@ pub fn regexp_match(args: &[ArrayRef]) -> Result { } } +/// TODO: Remove this once it is included in arrow-rs new release. +/// +fn _regexp_match( + array: &GenericStringArray, + regex_array: &GenericStringArray, + flags_array: Option<&GenericStringArray>, +) -> std::result::Result { + let mut patterns: std::collections::HashMap = + std::collections::HashMap::new(); + let builder: GenericStringBuilder = + GenericStringBuilder::with_capacity(0, 0); + let mut list_builder = ListBuilder::new(builder); + + let complete_pattern = match flags_array { + Some(flags) => Box::new(regex_array.iter().zip(flags.iter()).map( + |(pattern, flags)| { + pattern.map(|pattern| match flags { + Some(value) => format!("(?{value}){pattern}"), + None => pattern.to_string(), + }) + }, + )) as Box>>, + None => Box::new( + regex_array + .iter() + .map(|pattern| pattern.map(|pattern| pattern.to_string())), + ), + }; + + array + .iter() + .zip(complete_pattern) + .map(|(value, pattern)| { + match (value, pattern) { + // Required for Postgres compatibility: + // SELECT regexp_match('foobarbequebaz', ''); = {""} + (Some(_), Some(pattern)) if pattern == *"" => { + list_builder.values().append_value(""); + list_builder.append(true); + } + (Some(value), Some(pattern)) => { + let existing_pattern = patterns.get(&pattern); + let re = match existing_pattern { + Some(re) => re, + None => { + let re = Regex::new(pattern.as_str()).map_err(|e| { + ArrowError::ComputeError(format!( + "Regular expression did not compile: {e:?}" + )) + })?; + patterns.insert(pattern.clone(), re); + patterns.get(&pattern).unwrap() + } + }; + match re.captures(value) { + Some(caps) => { + let mut iter = caps.iter(); + if caps.len() > 1 { + iter.next(); + } + for m in iter.flatten() { + list_builder.values().append_value(m.as_str()); + } + + list_builder.append(true); + } + None => list_builder.append(false), + } + } + _ => list_builder.append(false), + } + Ok(()) + }) + .collect::, ArrowError>>()?; + Ok(Arc::new(list_builder.finish())) +} + /// replace POSIX capture groups (like \1) with Rust Regex group (like ${1}) /// used by regexp_replace fn regex_replace_posix_groups(replacement: &str) -> String { @@ -116,12 +194,12 @@ pub fn regexp_replace(args: &[ArrayRef]) -> Result // if patterns hashmap already has regexp then use else else create and return let re = match patterns.get(pattern) { - Some(re) => Ok(re.clone()), + Some(re) => Ok(re), None => { match Regex::new(pattern) { Ok(re) => { - patterns.insert(pattern.to_string(), re.clone()); - Ok(re) + patterns.insert(pattern.to_string(), re); + Ok(patterns.get(pattern).unwrap()) }, Err(err) => Err(DataFusionError::External(Box::new(err))), } @@ -162,12 +240,12 @@ pub fn regexp_replace(args: &[ArrayRef]) -> Result // if patterns hashmap already has regexp then use else else create and return let re = match patterns.get(&pattern) { - Some(re) => Ok(re.clone()), + Some(re) => Ok(re), None => { match Regex::new(pattern.as_str()) { Ok(re) => { - patterns.insert(pattern, re.clone()); - Ok(re) + patterns.insert(pattern.clone(), re); + Ok(patterns.get(&pattern).unwrap()) }, Err(err) => Err(DataFusionError::External(Box::new(err))), } diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 517553f90fb2..0a9d69720e19 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -29,24 +29,23 @@ //! This module also has a set of coercion rules to improve user experience: if an argument i32 is passed //! to a function that supports f64, it is coerced to f64. +use std::any::Any; +use std::fmt::{self, Debug, Formatter}; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + use crate::functions::out_ordering; -use crate::functions::FuncMonotonicity; -use crate::physical_expr::down_cast_any_ref; +use crate::physical_expr::{down_cast_any_ref, physical_exprs_equal}; use crate::sort_properties::SortProperties; -use crate::utils::expr_list_eq_strict_order; use crate::PhysicalExpr; + use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::Result; -use datafusion_expr::expr_vec_fmt; -use datafusion_expr::BuiltinScalarFunction; -use datafusion_expr::ColumnarValue; -use datafusion_expr::ScalarFunctionImplementation; -use std::any::Any; -use std::fmt::Debug; -use std::fmt::{self, Formatter}; -use std::hash::{Hash, Hasher}; -use std::sync::Arc; +use datafusion_expr::{ + expr_vec_fmt, BuiltinScalarFunction, ColumnarValue, FuncMonotonicity, + ScalarFunctionImplementation, +}; /// Physical expression of a scalar function pub struct ScalarFunctionExpr { @@ -78,14 +77,14 @@ impl ScalarFunctionExpr { name: &str, fun: ScalarFunctionImplementation, args: Vec>, - return_type: &DataType, + return_type: DataType, monotonicity: Option, ) -> Self { Self { fun, name: name.to_owned(), args, - return_type: return_type.clone(), + return_type, monotonicity, } } @@ -109,6 +108,11 @@ impl ScalarFunctionExpr { pub fn return_type(&self) -> &DataType { &self.return_type } + + /// Monotonicity information of the function + pub fn monotonicity(&self) -> &Option { + &self.monotonicity + } } impl fmt::Display for ScalarFunctionExpr { @@ -137,7 +141,10 @@ impl PhysicalExpr for ScalarFunctionExpr { let inputs = match (self.args.len(), self.name.parse::()) { // MakeArray support zero argument but has the different behavior from the array with one null. (0, Ok(scalar_fun)) - if scalar_fun.supports_zero_argument() + if scalar_fun + .signature() + .type_signature + .supports_zero_argument() && scalar_fun != BuiltinScalarFunction::MakeArray => { vec![ColumnarValue::create_null_array(batch.num_rows())] @@ -166,7 +173,7 @@ impl PhysicalExpr for ScalarFunctionExpr { &self.name, self.fun.clone(), children, - self.return_type(), + self.return_type().clone(), self.monotonicity.clone(), ))) } @@ -194,7 +201,7 @@ impl PartialEq for ScalarFunctionExpr { .downcast_ref::() .map(|x| { self.name == x.name - && expr_list_eq_strict_order(&self.args, &x.args) + && physical_exprs_equal(&self.args, &x.args) && self.return_type == x.return_type }) .unwrap_or(false) diff --git a/datafusion/physical-expr/src/sort_expr.rs b/datafusion/physical-expr/src/sort_expr.rs index 83d32dfeec17..914d76f9261a 100644 --- a/datafusion/physical-expr/src/sort_expr.rs +++ b/datafusion/physical-expr/src/sort_expr.rs @@ -17,6 +17,7 @@ //! Sort expressions +use std::fmt::Display; use std::hash::{Hash, Hasher}; use std::sync::Arc; @@ -24,8 +25,8 @@ use crate::PhysicalExpr; use arrow::compute::kernels::sort::{SortColumn, SortOptions}; use arrow::record_batch::RecordBatch; -use datafusion_common::plan_err; -use datafusion_common::{DataFusionError, Result}; +use arrow_schema::Schema; +use datafusion_common::Result; use datafusion_expr::ColumnarValue; /// Represents Sort operation for a column in a RecordBatch @@ -64,11 +65,7 @@ impl PhysicalSortExpr { let value_to_sort = self.expr.evaluate(batch)?; let array_to_sort = match value_to_sort { ColumnarValue::Array(array) => array, - ColumnarValue::Scalar(scalar) => { - return plan_err!( - "Sort operation is not applicable to scalar value {scalar}" - ); - } + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(batch.num_rows())?, }; Ok(SortColumn { values: array_to_sort, @@ -76,18 +73,46 @@ impl PhysicalSortExpr { }) } - /// Check whether sort expression satisfies [`PhysicalSortRequirement`]. - /// - /// If sort options is Some in `PhysicalSortRequirement`, `expr` - /// and `options` field are compared for equality. - /// - /// If sort options is None in `PhysicalSortRequirement`, only - /// `expr` is compared for equality. - pub fn satisfy(&self, requirement: &PhysicalSortRequirement) -> bool { + /// Checks whether this sort expression satisfies the given `requirement`. + /// If sort options are unspecified in `requirement`, only expressions are + /// compared for inequality. + pub fn satisfy( + &self, + requirement: &PhysicalSortRequirement, + schema: &Schema, + ) -> bool { + // If the column is not nullable, NULLS FIRST/LAST is not important. + let nullable = self.expr.nullable(schema).unwrap_or(true); self.expr.eq(&requirement.expr) - && requirement - .options - .map_or(true, |opts| self.options == opts) + && if nullable { + requirement + .options + .map_or(true, |opts| self.options == opts) + } else { + requirement + .options + .map_or(true, |opts| self.options.descending == opts.descending) + } + } + + /// Returns a [`Display`]able list of `PhysicalSortExpr`. + pub fn format_list(input: &[PhysicalSortExpr]) -> impl Display + '_ { + struct DisplayableList<'a>(&'a [PhysicalSortExpr]); + impl<'a> Display for DisplayableList<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let mut first = true; + for sort_expr in self.0 { + if first { + first = false; + } else { + write!(f, ",")?; + } + write!(f, "{}", sort_expr)?; + } + Ok(()) + } + } + DisplayableList(input) } } @@ -227,11 +252,18 @@ fn to_str(options: &SortOptions) -> &str { } } -///`LexOrdering` is a type alias for lexicographical ordering definition`Vec` +///`LexOrdering` is an alias for the type `Vec`, which represents +/// a lexicographical ordering. pub type LexOrdering = Vec; -///`LexOrderingRef` is a type alias for lexicographical ordering reference &`[PhysicalSortExpr]` +///`LexOrderingRef` is an alias for the type &`[PhysicalSortExpr]`, which represents +/// a reference to a lexicographical ordering. pub type LexOrderingRef<'a> = &'a [PhysicalSortExpr]; -///`LexOrderingReq` is a type alias for lexicographical ordering requirement definition`Vec` -pub type LexOrderingReq = Vec; +///`LexRequirement` is an alias for the type `Vec`, which +/// represents a lexicographical ordering requirement. +pub type LexRequirement = Vec; + +///`LexRequirementRef` is an alias for the type &`[PhysicalSortRequirement]`, which +/// represents a reference to a lexicographical ordering requirement. +pub type LexRequirementRef<'a> = &'a [PhysicalSortRequirement]; diff --git a/datafusion/physical-expr/src/sort_properties.rs b/datafusion/physical-expr/src/sort_properties.rs index 001b86e60a86..0205f85dced4 100644 --- a/datafusion/physical-expr/src/sort_properties.rs +++ b/datafusion/physical-expr/src/sort_properties.rs @@ -15,19 +15,14 @@ // specific language governing permissions and limitations // under the License. +use std::borrow::Cow; use std::{ops::Neg, sync::Arc}; -use crate::expressions::Column; -use crate::utils::get_indices_of_matching_sort_exprs_with_order_eq; -use crate::{ - EquivalenceProperties, OrderingEquivalenceProperties, PhysicalExpr, PhysicalSortExpr, -}; - use arrow_schema::SortOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; -use datafusion_common::Result; -use itertools::Itertools; +use crate::PhysicalExpr; +use datafusion_common::tree_node::TreeNode; +use datafusion_common::Result; /// To propagate [`SortOptions`] across the [`PhysicalExpr`], it is insufficient /// to simply use `Option`: There must be a differentiation between @@ -40,11 +35,12 @@ use itertools::Itertools; /// sorted data; however the ((a_ordered + 999) + c_ordered) expression can. Therefore, /// we need two different variants for literals and unordered columns as literals are /// often more ordering-friendly under most mathematical operations. -#[derive(PartialEq, Debug, Clone, Copy)] +#[derive(PartialEq, Debug, Clone, Copy, Default)] pub enum SortProperties { /// Use the ordinary [`SortOptions`] struct to represent ordered data: Ordered(SortOptions), // This alternative represents unordered data: + #[default] Unordered, // Singleton is used for single-valued literal numbers: Singleton, @@ -103,7 +99,7 @@ impl SortProperties { } } - pub fn and(&self, rhs: &Self) -> Self { + pub fn and_or(&self, rhs: &Self) -> Self { match (self, rhs) { (Self::Ordered(lhs), Self::Ordered(rhs)) if lhs.descending == rhs.descending => @@ -152,126 +148,47 @@ impl Neg for SortProperties { /// It encapsulates the orderings (`state`) associated with the expression (`expr`), and /// orderings of the children expressions (`children_states`). The [`ExprOrdering`] of a parent /// expression is determined based on the [`ExprOrdering`] states of its children expressions. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct ExprOrdering { pub expr: Arc, - pub state: Option, - pub children_states: Option>, + pub state: SortProperties, + pub children: Vec, } impl ExprOrdering { + /// Creates a new [`ExprOrdering`] with [`SortProperties::Unordered`] states + /// for `expr` and its children. pub fn new(expr: Arc) -> Self { + let children = expr.children(); Self { expr, - state: None, - children_states: None, + state: Default::default(), + children: children.into_iter().map(Self::new).collect(), } } - pub fn children(&self) -> Vec { - self.expr - .children() - .into_iter() - .map(|e| ExprOrdering::new(e)) - .collect() - } - - pub fn new_with_children( - children_states: Vec, - parent_expr: Arc, - ) -> Self { - Self { - expr: parent_expr, - state: None, - children_states: Some(children_states), - } + /// Get a reference to each child state. + pub fn children_state(&self) -> Vec { + self.children.iter().map(|c| c.state).collect() } } impl TreeNode for ExprOrdering { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in self.children() { - match op(&child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.children.iter().map(Cow::Borrowed).collect() } - fn map_children(self, transform: F) -> Result + fn map_children(mut self, transform: F) -> Result where F: FnMut(Self) -> Result, { - let children = self.children(); - if children.is_empty() { - Ok(self) - } else { - Ok(ExprOrdering::new_with_children( - children - .into_iter() - .map(transform) - .map_ok(|c| c.state.unwrap_or(SortProperties::Unordered)) - .collect::>>()?, - self.expr, - )) + if !self.children.is_empty() { + self.children = self + .children + .into_iter() + .map(transform) + .collect::>()?; } + Ok(self) } } - -/// Calculates the [`SortProperties`] of a given [`ExprOrdering`] node. -/// The node is either a leaf node, or an intermediate node: -/// - If it is a leaf node, the children states are `None`. We directly find -/// the order of the node by looking at the given sort expression and equivalence -/// properties if it is a `Column` leaf, or we mark it as unordered. In the case -/// of a `Literal` leaf, we mark it as singleton so that it can cooperate with -/// some ordered columns at the upper steps. -/// - If it is an intermediate node, the children states matter. Each `PhysicalExpr` -/// and operator has its own rules about how to propagate the children orderings. -/// However, before the children order propagation, it is checked that whether -/// the intermediate node can be directly matched with the sort expression. If there -/// is a match, the sort expression emerges at that node immediately, discarding -/// the order coming from the children. -pub fn update_ordering( - mut node: ExprOrdering, - sort_expr: &PhysicalSortExpr, - equal_properties: &EquivalenceProperties, - ordering_equal_properties: &OrderingEquivalenceProperties, -) -> Result> { - // If we can directly match a sort expr with the current node, we can set - // its state and return early. - // TODO: If there is a PhysicalExpr other than a Column at this node (e.g. - // a BinaryExpr like a + b), and there is an ordering equivalence of - // it (let's say like c + d), we actually can find it at this step. - if sort_expr.expr.eq(&node.expr) { - node.state = Some(SortProperties::Ordered(sort_expr.options)); - return Ok(Transformed::Yes(node)); - } - - if let Some(children_sort_options) = &node.children_states { - // We have an intermediate (non-leaf) node, account for its children: - node.state = Some(node.expr.get_ordering(children_sort_options)); - } else if let Some(column) = node.expr.as_any().downcast_ref::() { - // We have a Column, which is one of the two possible leaf node types: - node.state = get_indices_of_matching_sort_exprs_with_order_eq( - &[sort_expr.clone()], - &[column.clone()], - equal_properties, - ordering_equal_properties, - ) - .map(|(sort_options, _)| { - SortProperties::Ordered(SortOptions { - descending: sort_options[0].descending, - nulls_first: sort_options[0].nulls_first, - }) - }); - } else { - // We have a Literal, which is the other possible leaf node type: - node.state = Some(node.expr.get_ordering(&[])); - } - Ok(Transformed::Yes(node)) -} diff --git a/datafusion/physical-expr/src/string_expressions.rs b/datafusion/physical-expr/src/string_expressions.rs index e6a3d5c331a5..7d9fecf61407 100644 --- a/datafusion/physical-expr/src/string_expressions.rs +++ b/datafusion/physical-expr/src/string_expressions.rs @@ -23,11 +23,12 @@ use arrow::{ array::{ - Array, ArrayRef, BooleanArray, GenericStringArray, Int32Array, OffsetSizeTrait, - StringArray, + Array, ArrayRef, BooleanArray, GenericStringArray, Int32Array, Int64Array, + OffsetSizeTrait, StringArray, }, datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType}, }; +use datafusion_common::utils::datafusion_strsim; use datafusion_common::{ cast::{ as_generic_string_array, as_int64_array, as_primitive_array, as_string_array, @@ -36,8 +37,11 @@ use datafusion_common::{ }; use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_expr::ColumnarValue; -use std::iter; use std::sync::Arc; +use std::{ + fmt::{Display, Formatter}, + iter, +}; use uuid::Uuid; /// applies a unary expression to `args[0]` that is expected to be downcastable to @@ -132,53 +136,6 @@ pub fn ascii(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } -/// Removes the longest string containing only characters in characters (a space by default) from the start and end of string. -/// btrim('xyxtrimyyx', 'xyz') = 'trim' -pub fn btrim(args: &[ArrayRef]) -> Result { - match args.len() { - 1 => { - let string_array = as_generic_string_array::(&args[0])?; - - let result = string_array - .iter() - .map(|string| { - string.map(|string: &str| { - string.trim_start_matches(' ').trim_end_matches(' ') - }) - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) - } - 2 => { - let string_array = as_generic_string_array::(&args[0])?; - let characters_array = as_generic_string_array::(&args[1])?; - - let result = string_array - .iter() - .zip(characters_array.iter()) - .map(|(string, characters)| match (string, characters) { - (None, _) => None, - (_, None) => None, - (Some(string), Some(characters)) => { - let chars: Vec = characters.chars().collect(); - Some( - string - .trim_start_matches(&chars[..]) - .trim_end_matches(&chars[..]), - ) - } - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) - } - other => internal_err!( - "btrim was called with {other} arguments. It requires at least 1 and at most 2." - ), - } -} - /// Returns the character with the given code. chr(0) is disallowed because text data types cannot store that character. /// chr(65) = 'A' pub fn chr(args: &[ArrayRef]) -> Result { @@ -345,44 +302,95 @@ pub fn lower(args: &[ColumnarValue]) -> Result { handle(args, |string| string.to_ascii_lowercase(), "lower") } -/// Removes the longest string containing only characters in characters (a space by default) from the start of string. -/// ltrim('zzzytest', 'xyz') = 'test' -pub fn ltrim(args: &[ArrayRef]) -> Result { +enum TrimType { + Left, + Right, + Both, +} + +impl Display for TrimType { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + TrimType::Left => write!(f, "ltrim"), + TrimType::Right => write!(f, "rtrim"), + TrimType::Both => write!(f, "btrim"), + } + } +} + +fn general_trim( + args: &[ArrayRef], + trim_type: TrimType, +) -> Result { + let func = match trim_type { + TrimType::Left => |input, pattern: &str| { + let pattern = pattern.chars().collect::>(); + str::trim_start_matches::<&[char]>(input, pattern.as_ref()) + }, + TrimType::Right => |input, pattern: &str| { + let pattern = pattern.chars().collect::>(); + str::trim_end_matches::<&[char]>(input, pattern.as_ref()) + }, + TrimType::Both => |input, pattern: &str| { + let pattern = pattern.chars().collect::>(); + str::trim_end_matches::<&[char]>( + str::trim_start_matches::<&[char]>(input, pattern.as_ref()), + pattern.as_ref(), + ) + }, + }; + + let string_array = as_generic_string_array::(&args[0])?; + match args.len() { 1 => { - let string_array = as_generic_string_array::(&args[0])?; - let result = string_array .iter() - .map(|string| string.map(|string: &str| string.trim_start_matches(' '))) + .map(|string| string.map(|string: &str| func(string, " "))) .collect::>(); Ok(Arc::new(result) as ArrayRef) } 2 => { - let string_array = as_generic_string_array::(&args[0])?; let characters_array = as_generic_string_array::(&args[1])?; let result = string_array .iter() .zip(characters_array.iter()) .map(|(string, characters)| match (string, characters) { - (Some(string), Some(characters)) => { - let chars: Vec = characters.chars().collect(); - Some(string.trim_start_matches(&chars[..])) - } + (Some(string), Some(characters)) => Some(func(string, characters)), _ => None, }) .collect::>(); Ok(Arc::new(result) as ArrayRef) } - other => internal_err!( - "ltrim was called with {other} arguments. It requires at least 1 and at most 2." - ), + other => { + internal_err!( + "{trim_type} was called with {other} arguments. It requires at least 1 and at most 2." + ) + } } } +/// Returns the longest string with leading and trailing characters removed. If the characters are not specified, whitespace is removed. +/// btrim('xyxtrimyyx', 'xyz') = 'trim' +pub fn btrim(args: &[ArrayRef]) -> Result { + general_trim::(args, TrimType::Both) +} + +/// Returns the longest string with leading characters removed. If the characters are not specified, whitespace is removed. +/// ltrim('zzzytest', 'xyz') = 'test' +pub fn ltrim(args: &[ArrayRef]) -> Result { + general_trim::(args, TrimType::Left) +} + +/// Returns the longest string with trailing characters removed. If the characters are not specified, whitespace is removed. +/// rtrim('testxxzx', 'xyz') = 'test' +pub fn rtrim(args: &[ArrayRef]) -> Result { + general_trim::(args, TrimType::Right) +} + /// Repeats string the specified number of times. /// repeat('Pg', 4) = 'PgPgPgPg' pub fn repeat(args: &[ArrayRef]) -> Result { @@ -421,44 +429,6 @@ pub fn replace(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } -/// Removes the longest string containing only characters in characters (a space by default) from the end of string. -/// rtrim('testxxzx', 'xyz') = 'test' -pub fn rtrim(args: &[ArrayRef]) -> Result { - match args.len() { - 1 => { - let string_array = as_generic_string_array::(&args[0])?; - - let result = string_array - .iter() - .map(|string| string.map(|string: &str| string.trim_end_matches(' '))) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) - } - 2 => { - let string_array = as_generic_string_array::(&args[0])?; - let characters_array = as_generic_string_array::(&args[1])?; - - let result = string_array - .iter() - .zip(characters_array.iter()) - .map(|(string, characters)| match (string, characters) { - (Some(string), Some(characters)) => { - let chars: Vec = characters.chars().collect(); - Some(string.trim_end_matches(&chars[..])) - } - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) - } - other => internal_err!( - "rtrim was called with {other} arguments. It requires at least 1 and at most 2." - ), - } -} - /// Splits string at occurrences of delimiter and returns the n'th field (counting from one). /// split_part('abc~@~def~@~ghi', '~@~', 2) = 'def' pub fn split_part(args: &[ArrayRef]) -> Result { @@ -553,11 +523,149 @@ pub fn uuid(args: &[ColumnarValue]) -> Result { Ok(ColumnarValue::Array(Arc::new(array))) } +/// OVERLAY(string1 PLACING string2 FROM integer FOR integer2) +/// Replaces a substring of string1 with string2 starting at the integer bit +/// pgsql overlay('Txxxxas' placing 'hom' from 2 for 4) → Thomas +/// overlay('Txxxxas' placing 'hom' from 2) -> Thomxas, without for option, str2's len is instead +pub fn overlay(args: &[ArrayRef]) -> Result { + match args.len() { + 3 => { + let string_array = as_generic_string_array::(&args[0])?; + let characters_array = as_generic_string_array::(&args[1])?; + let pos_num = as_int64_array(&args[2])?; + + let result = string_array + .iter() + .zip(characters_array.iter()) + .zip(pos_num.iter()) + .map(|((string, characters), start_pos)| { + match (string, characters, start_pos) { + (Some(string), Some(characters), Some(start_pos)) => { + let string_len = string.chars().count(); + let characters_len = characters.chars().count(); + let replace_len = characters_len as i64; + let mut res = + String::with_capacity(string_len.max(characters_len)); + + //as sql replace index start from 1 while string index start from 0 + if start_pos > 1 && start_pos - 1 < string_len as i64 { + let start = (start_pos - 1) as usize; + res.push_str(&string[..start]); + } + res.push_str(characters); + // if start + replace_len - 1 >= string_length, just to string end + if start_pos + replace_len - 1 < string_len as i64 { + let end = (start_pos + replace_len - 1) as usize; + res.push_str(&string[end..]); + } + Ok(Some(res)) + } + _ => Ok(None), + } + }) + .collect::>>()?; + Ok(Arc::new(result) as ArrayRef) + } + 4 => { + let string_array = as_generic_string_array::(&args[0])?; + let characters_array = as_generic_string_array::(&args[1])?; + let pos_num = as_int64_array(&args[2])?; + let len_num = as_int64_array(&args[3])?; + + let result = string_array + .iter() + .zip(characters_array.iter()) + .zip(pos_num.iter()) + .zip(len_num.iter()) + .map(|(((string, characters), start_pos), len)| { + match (string, characters, start_pos, len) { + (Some(string), Some(characters), Some(start_pos), Some(len)) => { + let string_len = string.chars().count(); + let characters_len = characters.chars().count(); + let replace_len = len.min(string_len as i64); + let mut res = + String::with_capacity(string_len.max(characters_len)); + + //as sql replace index start from 1 while string index start from 0 + if start_pos > 1 && start_pos - 1 < string_len as i64 { + let start = (start_pos - 1) as usize; + res.push_str(&string[..start]); + } + res.push_str(characters); + // if start + replace_len - 1 >= string_length, just to string end + if start_pos + replace_len - 1 < string_len as i64 { + let end = (start_pos + replace_len - 1) as usize; + res.push_str(&string[end..]); + } + Ok(Some(res)) + } + _ => Ok(None), + } + }) + .collect::>>()?; + Ok(Arc::new(result) as ArrayRef) + } + other => { + internal_err!( + "overlay was called with {other} arguments. It requires 3 or 4." + ) + } + } +} + +///Returns the Levenshtein distance between the two given strings. +/// LEVENSHTEIN('kitten', 'sitting') = 3 +pub fn levenshtein(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return Err(DataFusionError::Internal(format!( + "levenshtein function requires two arguments, got {}", + args.len() + ))); + } + let str1_array = as_generic_string_array::(&args[0])?; + let str2_array = as_generic_string_array::(&args[1])?; + match args[0].data_type() { + DataType::Utf8 => { + let result = str1_array + .iter() + .zip(str2_array.iter()) + .map(|(string1, string2)| match (string1, string2) { + (Some(string1), Some(string2)) => { + Some(datafusion_strsim::levenshtein(string1, string2) as i32) + } + _ => None, + }) + .collect::(); + Ok(Arc::new(result) as ArrayRef) + } + DataType::LargeUtf8 => { + let result = str1_array + .iter() + .zip(str2_array.iter()) + .map(|(string1, string2)| match (string1, string2) { + (Some(string1), Some(string2)) => { + Some(datafusion_strsim::levenshtein(string1, string2) as i64) + } + _ => None, + }) + .collect::(); + Ok(Arc::new(result) as ArrayRef) + } + other => { + internal_err!( + "levenshtein was called with {other} datatype arguments. It requires Utf8 or LargeUtf8." + ) + } + } +} + #[cfg(test)] mod tests { use crate::string_expressions; use arrow::{array::Int32Array, datatypes::Int32Type}; + use arrow_array::Int64Array; + use datafusion_common::cast::as_int32_array; use super::*; @@ -599,4 +707,36 @@ mod tests { Ok(()) } + + #[test] + fn to_overlay() -> Result<()> { + let string = + Arc::new(StringArray::from(vec!["123", "abcdefg", "xyz", "Txxxxas"])); + let replace_string = + Arc::new(StringArray::from(vec!["abc", "qwertyasdfg", "ijk", "hom"])); + let start = Arc::new(Int64Array::from(vec![4, 1, 1, 2])); // start + let end = Arc::new(Int64Array::from(vec![5, 7, 2, 4])); // replace len + + let res = overlay::(&[string, replace_string, start, end]).unwrap(); + let result = as_generic_string_array::(&res).unwrap(); + let expected = StringArray::from(vec!["abc", "qwertyasdfg", "ijkz", "Thomas"]); + assert_eq!(&expected, result); + + Ok(()) + } + + #[test] + fn to_levenshtein() -> Result<()> { + let string1_array = + Arc::new(StringArray::from(vec!["123", "abc", "xyz", "kitten"])); + let string2_array = + Arc::new(StringArray::from(vec!["321", "def", "zyx", "sitting"])); + let res = levenshtein::(&[string1_array, string2_array]).unwrap(); + let result = + as_int32_array(&res).expect("failed to initialized function levenshtein"); + let expected = Int32Array::from(vec![2, 3, 2, 3]); + assert_eq!(&expected, result); + + Ok(()) + } } diff --git a/datafusion/physical-expr/src/struct_expressions.rs b/datafusion/physical-expr/src/struct_expressions.rs index baa29d668e90..b0ccb2a3ccb6 100644 --- a/datafusion/physical-expr/src/struct_expressions.rs +++ b/datafusion/physical-expr/src/struct_expressions.rs @@ -18,8 +18,8 @@ //! Struct expressions use arrow::array::*; -use arrow::datatypes::{DataType, Field}; -use datafusion_common::{exec_err, not_impl_err, DataFusionError, Result}; +use arrow::datatypes::Field; +use datafusion_common::{exec_err, DataFusionError, Result}; use datafusion_expr::ColumnarValue; use std::sync::Arc; @@ -34,31 +34,14 @@ fn array_struct(args: &[ArrayRef]) -> Result { .enumerate() .map(|(i, arg)| { let field_name = format!("c{i}"); - match arg.data_type() { - DataType::Utf8 - | DataType::LargeUtf8 - | DataType::Boolean - | DataType::Float32 - | DataType::Float64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 => Ok(( - Arc::new(Field::new( - field_name.as_str(), - arg.data_type().clone(), - true, - )), - arg.clone(), + Ok(( + Arc::new(Field::new( + field_name.as_str(), + arg.data_type().clone(), + true, )), - data_type => { - not_impl_err!("Struct is not implemented for type '{data_type:?}'.") - } - } + arg.clone(), + )) }) .collect::>>()?; @@ -67,13 +50,15 @@ fn array_struct(args: &[ArrayRef]) -> Result { /// put values in a struct array. pub fn struct_expr(values: &[ColumnarValue]) -> Result { - let arrays: Vec = values + let arrays = values .iter() - .map(|x| match x { - ColumnarValue::Array(array) => array.clone(), - ColumnarValue::Scalar(scalar) => scalar.to_array().clone(), + .map(|x| { + Ok(match x { + ColumnarValue::Array(array) => array.clone(), + ColumnarValue::Scalar(scalar) => scalar.to_array()?.clone(), + }) }) - .collect(); + .collect::>>()?; Ok(ColumnarValue::Array(array_struct(arrays.as_slice())?)) } @@ -93,7 +78,8 @@ mod tests { ]; let struc = struct_expr(&args) .expect("failed to initialize function struct") - .into_array(1); + .into_array(1) + .expect("Failed to convert to array"); let result = as_struct_array(&struc).expect("failed to initialize function struct"); assert_eq!( diff --git a/datafusion/physical-expr/src/udf.rs b/datafusion/physical-expr/src/udf.rs index af1e77cbf566..9daa9eb173dd 100644 --- a/datafusion/physical-expr/src/udf.rs +++ b/datafusion/physical-expr/src/udf.rs @@ -35,10 +35,10 @@ pub fn create_physical_expr( .collect::>>()?; Ok(Arc::new(ScalarFunctionExpr::new( - &fun.name, - fun.fun.clone(), + fun.name(), + fun.fun(), input_phy_exprs.to_vec(), - (fun.return_type)(&input_exprs_types)?.as_ref(), + fun.return_type(&input_exprs_types)?, None, ))) } diff --git a/datafusion/physical-expr/src/unicode_expressions.rs b/datafusion/physical-expr/src/unicode_expressions.rs index e28700a25ce4..240efe4223c3 100644 --- a/datafusion/physical-expr/src/unicode_expressions.rs +++ b/datafusion/physical-expr/src/unicode_expressions.rs @@ -455,3 +455,107 @@ pub fn translate(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } + +/// Returns the substring from str before count occurrences of the delimiter delim. If count is positive, everything to the left of the final delimiter (counting from the left) is returned. If count is negative, everything to the right of the final delimiter (counting from the right) is returned. +/// SUBSTRING_INDEX('www.apache.org', '.', 1) = www +/// SUBSTRING_INDEX('www.apache.org', '.', 2) = www.apache +/// SUBSTRING_INDEX('www.apache.org', '.', -2) = apache.org +/// SUBSTRING_INDEX('www.apache.org', '.', -1) = org +pub fn substr_index(args: &[ArrayRef]) -> Result { + if args.len() != 3 { + return internal_err!( + "substr_index was called with {} arguments. It requires 3.", + args.len() + ); + } + + let string_array = as_generic_string_array::(&args[0])?; + let delimiter_array = as_generic_string_array::(&args[1])?; + let count_array = as_int64_array(&args[2])?; + + let result = string_array + .iter() + .zip(delimiter_array.iter()) + .zip(count_array.iter()) + .map(|((string, delimiter), n)| match (string, delimiter, n) { + (Some(string), Some(delimiter), Some(n)) => { + let mut res = String::new(); + match n { + 0 => { + "".to_string(); + } + _other => { + if n > 0 { + let idx = string + .split(delimiter) + .take(n as usize) + .fold(0, |len, x| len + x.len() + delimiter.len()) + - delimiter.len(); + res.push_str(if idx >= string.len() { + string + } else { + &string[..idx] + }); + } else { + let idx = (string.split(delimiter).take((-n) as usize).fold( + string.len() as isize, + |len, x| { + len - x.len() as isize - delimiter.len() as isize + }, + ) + delimiter.len() as isize) + as usize; + res.push_str(if idx >= string.len() { + string + } else { + &string[idx..] + }); + } + } + } + Some(res) + } + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +///Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings +///A string list is a string composed of substrings separated by , characters. +pub fn find_in_set(args: &[ArrayRef]) -> Result +where + T::Native: OffsetSizeTrait, +{ + if args.len() != 2 { + return internal_err!( + "find_in_set was called with {} arguments. It requires 2.", + args.len() + ); + } + + let str_array: &GenericStringArray = + as_generic_string_array::(&args[0])?; + let str_list_array: &GenericStringArray = + as_generic_string_array::(&args[1])?; + + let result = str_array + .iter() + .zip(str_list_array.iter()) + .map(|(string, str_list)| match (string, str_list) { + (Some(string), Some(str_list)) => { + let mut res = 0; + let str_set: Vec<&str> = str_list.split(',').collect(); + for (idx, str) in str_set.iter().enumerate() { + if str == &string { + res = idx + 1; + break; + } + } + T::Native::from_usize(res) + } + _ => None, + }) + .collect::>(); + Ok(Arc::new(result) as ArrayRef) +} diff --git a/datafusion/physical-expr/src/utils.rs b/datafusion/physical-expr/src/utils.rs deleted file mode 100644 index b2a6bb5ca6d2..000000000000 --- a/datafusion/physical-expr/src/utils.rs +++ /dev/null @@ -1,1839 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use crate::equivalence::{EquivalenceProperties, OrderingEquivalenceProperties}; -use crate::expressions::{BinaryExpr, Column, UnKnownColumn}; -use crate::sort_properties::{ExprOrdering, SortProperties}; -use crate::update_ordering; -use crate::{PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement}; - -use arrow::array::{make_array, Array, ArrayRef, BooleanArray, MutableArrayData}; -use arrow::compute::{and_kleene, is_not_null, SlicesIterator}; -use arrow::datatypes::SchemaRef; -use arrow_schema::SortOptions; -use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeRewriter, VisitRecursion, -}; -use datafusion_common::utils::longest_consecutive_prefix; -use datafusion_common::Result; -use datafusion_expr::Operator; - -use itertools::Itertools; -use petgraph::graph::NodeIndex; -use petgraph::stable_graph::StableGraph; -use std::borrow::Borrow; -use std::collections::HashMap; -use std::collections::HashSet; -use std::sync::Arc; - -/// Compare the two expr lists are equal no matter the order. -/// For example two InListExpr can be considered to be equals no matter the order: -/// -/// In('a','b','c') == In('c','b','a') -pub fn expr_list_eq_any_order( - list1: &[Arc], - list2: &[Arc], -) -> bool { - if list1.len() == list2.len() { - let mut expr_vec1 = list1.to_vec(); - let mut expr_vec2 = list2.to_vec(); - while let Some(expr1) = expr_vec1.pop() { - if let Some(idx) = expr_vec2.iter().position(|expr2| expr1.eq(expr2)) { - expr_vec2.swap_remove(idx); - } else { - break; - } - } - expr_vec1.is_empty() && expr_vec2.is_empty() - } else { - false - } -} - -/// Strictly compare the two expr lists are equal in the given order. -pub fn expr_list_eq_strict_order( - list1: &[Arc], - list2: &[Arc], -) -> bool { - list1.len() == list2.len() && list1.iter().zip(list2.iter()).all(|(e1, e2)| e1.eq(e2)) -} - -/// Assume the predicate is in the form of CNF, split the predicate to a Vec of PhysicalExprs. -/// -/// For example, split "a1 = a2 AND b1 <= b2 AND c1 != c2" into ["a1 = a2", "b1 <= b2", "c1 != c2"] -pub fn split_conjunction( - predicate: &Arc, -) -> Vec<&Arc> { - split_conjunction_impl(predicate, vec![]) -} - -fn split_conjunction_impl<'a>( - predicate: &'a Arc, - mut exprs: Vec<&'a Arc>, -) -> Vec<&'a Arc> { - match predicate.as_any().downcast_ref::() { - Some(binary) => match binary.op() { - Operator::And => { - let exprs = split_conjunction_impl(binary.left(), exprs); - split_conjunction_impl(binary.right(), exprs) - } - _ => { - exprs.push(predicate); - exprs - } - }, - None => { - exprs.push(predicate); - exprs - } - } -} - -/// Normalize the output expressions based on Columns Map. -/// -/// If there is a mapping in Columns Map, replace the Column in the output expressions with the 1st Column in the Columns Map. -/// Otherwise, replace the Column with a place holder of [UnKnownColumn] -/// -pub fn normalize_out_expr_with_columns_map( - expr: Arc, - columns_map: &HashMap>, -) -> Arc { - expr.clone() - .transform(&|expr| { - let normalized_form = match expr.as_any().downcast_ref::() { - Some(column) => columns_map - .get(column) - .map(|c| Arc::new(c[0].clone()) as _) - .or_else(|| Some(Arc::new(UnKnownColumn::new(column.name())) as _)), - None => None, - }; - Ok(if let Some(normalized_form) = normalized_form { - Transformed::Yes(normalized_form) - } else { - Transformed::No(expr) - }) - }) - .unwrap_or(expr) -} - -/// Transform `sort_exprs` vector, to standardized version using `eq_properties` and `ordering_eq_properties` -/// Assume `eq_properties` states that `Column a` and `Column b` are aliases. -/// Also assume `ordering_eq_properties` states that ordering `vec![d ASC]` and `vec![a ASC, c ASC]` are -/// ordering equivalent (in the sense that both describe the ordering of the table). -/// If the `sort_exprs` input to this function were `vec![b ASC, c ASC]`, -/// This function converts `sort_exprs` `vec![b ASC, c ASC]` to first `vec![a ASC, c ASC]` after considering `eq_properties` -/// Then converts `vec![a ASC, c ASC]` to `vec![d ASC]` after considering `ordering_eq_properties`. -/// Standardized version `vec![d ASC]` is used in subsequent operations. -fn normalize_sort_exprs( - sort_exprs: &[PhysicalSortExpr], - eq_properties: &EquivalenceProperties, - ordering_eq_properties: &OrderingEquivalenceProperties, -) -> Vec { - let sort_requirements = PhysicalSortRequirement::from_sort_exprs(sort_exprs.iter()); - let normalized_exprs = normalize_sort_requirements( - &sort_requirements, - eq_properties, - ordering_eq_properties, - ); - PhysicalSortRequirement::to_sort_exprs(normalized_exprs) -} - -/// Transform `sort_reqs` vector, to standardized version using `eq_properties` and `ordering_eq_properties` -/// Assume `eq_properties` states that `Column a` and `Column b` are aliases. -/// Also assume `ordering_eq_properties` states that ordering `vec![d ASC]` and `vec![a ASC, c ASC]` are -/// ordering equivalent (in the sense that both describe the ordering of the table). -/// If the `sort_reqs` input to this function were `vec![b Some(ASC), c None]`, -/// This function converts `sort_exprs` `vec![b Some(ASC), c None]` to first `vec![a Some(ASC), c None]` after considering `eq_properties` -/// Then converts `vec![a Some(ASC), c None]` to `vec![d Some(ASC)]` after considering `ordering_eq_properties`. -/// Standardized version `vec![d Some(ASC)]` is used in subsequent operations. -fn normalize_sort_requirements( - sort_reqs: &[PhysicalSortRequirement], - eq_properties: &EquivalenceProperties, - ordering_eq_properties: &OrderingEquivalenceProperties, -) -> Vec { - let normalized_sort_reqs = eq_properties.normalize_sort_requirements(sort_reqs); - ordering_eq_properties.normalize_sort_requirements(&normalized_sort_reqs) -} - -/// Checks whether given ordering requirements are satisfied by provided [PhysicalSortExpr]s. -pub fn ordering_satisfy< - F: FnOnce() -> EquivalenceProperties, - F2: FnOnce() -> OrderingEquivalenceProperties, ->( - provided: Option<&[PhysicalSortExpr]>, - required: Option<&[PhysicalSortExpr]>, - equal_properties: F, - ordering_equal_properties: F2, -) -> bool { - match (provided, required) { - (_, None) => true, - (None, Some(_)) => false, - (Some(provided), Some(required)) => ordering_satisfy_concrete( - provided, - required, - equal_properties, - ordering_equal_properties, - ), - } -} - -/// Checks whether the required [`PhysicalSortExpr`]s are satisfied by the -/// provided [`PhysicalSortExpr`]s. -pub fn ordering_satisfy_concrete< - F: FnOnce() -> EquivalenceProperties, - F2: FnOnce() -> OrderingEquivalenceProperties, ->( - provided: &[PhysicalSortExpr], - required: &[PhysicalSortExpr], - equal_properties: F, - ordering_equal_properties: F2, -) -> bool { - let oeq_properties = ordering_equal_properties(); - let eq_properties = equal_properties(); - let required_normalized = - normalize_sort_exprs(required, &eq_properties, &oeq_properties); - let provided_normalized = - normalize_sort_exprs(provided, &eq_properties, &oeq_properties); - if required_normalized.len() > provided_normalized.len() { - return false; - } - required_normalized - .into_iter() - .zip(provided_normalized) - .all(|(req, given)| given == req) -} - -/// Checks whether the given [`PhysicalSortRequirement`]s are satisfied by the -/// provided [`PhysicalSortExpr`]s. -pub fn ordering_satisfy_requirement< - F: FnOnce() -> EquivalenceProperties, - F2: FnOnce() -> OrderingEquivalenceProperties, ->( - provided: Option<&[PhysicalSortExpr]>, - required: Option<&[PhysicalSortRequirement]>, - equal_properties: F, - ordering_equal_properties: F2, -) -> bool { - match (provided, required) { - (_, None) => true, - (None, Some(_)) => false, - (Some(provided), Some(required)) => ordering_satisfy_requirement_concrete( - provided, - required, - equal_properties, - ordering_equal_properties, - ), - } -} - -/// Checks whether the given [`PhysicalSortRequirement`]s are satisfied by the -/// provided [`PhysicalSortExpr`]s. -pub fn ordering_satisfy_requirement_concrete< - F: FnOnce() -> EquivalenceProperties, - F2: FnOnce() -> OrderingEquivalenceProperties, ->( - provided: &[PhysicalSortExpr], - required: &[PhysicalSortRequirement], - equal_properties: F, - ordering_equal_properties: F2, -) -> bool { - let oeq_properties = ordering_equal_properties(); - let eq_properties = equal_properties(); - let required_normalized = - normalize_sort_requirements(required, &eq_properties, &oeq_properties); - let provided_normalized = - normalize_sort_exprs(provided, &eq_properties, &oeq_properties); - if required_normalized.len() > provided_normalized.len() { - return false; - } - required_normalized - .into_iter() - .zip(provided_normalized) - .all(|(req, given)| given.satisfy(&req)) -} - -/// Checks whether the given [`PhysicalSortRequirement`]s are equal or more -/// specific than the provided [`PhysicalSortRequirement`]s. -pub fn requirements_compatible< - F: FnOnce() -> OrderingEquivalenceProperties, - F2: FnOnce() -> EquivalenceProperties, ->( - provided: Option<&[PhysicalSortRequirement]>, - required: Option<&[PhysicalSortRequirement]>, - ordering_equal_properties: F, - equal_properties: F2, -) -> bool { - match (provided, required) { - (_, None) => true, - (None, Some(_)) => false, - (Some(provided), Some(required)) => requirements_compatible_concrete( - provided, - required, - ordering_equal_properties, - equal_properties, - ), - } -} - -/// Checks whether the given [`PhysicalSortRequirement`]s are equal or more -/// specific than the provided [`PhysicalSortRequirement`]s. -fn requirements_compatible_concrete< - F: FnOnce() -> OrderingEquivalenceProperties, - F2: FnOnce() -> EquivalenceProperties, ->( - provided: &[PhysicalSortRequirement], - required: &[PhysicalSortRequirement], - ordering_equal_properties: F, - equal_properties: F2, -) -> bool { - let oeq_properties = ordering_equal_properties(); - let eq_properties = equal_properties(); - - let required_normalized = - normalize_sort_requirements(required, &eq_properties, &oeq_properties); - let provided_normalized = - normalize_sort_requirements(provided, &eq_properties, &oeq_properties); - if required_normalized.len() > provided_normalized.len() { - return false; - } - required_normalized - .into_iter() - .zip(provided_normalized) - .all(|(req, given)| given.compatible(&req)) -} - -/// This function maps back requirement after ProjectionExec -/// to the Executor for its input. -// Specifically, `ProjectionExec` changes index of `Column`s in the schema of its input executor. -// This function changes requirement given according to ProjectionExec schema to the requirement -// according to schema of input executor to the ProjectionExec. -// For instance, Column{"a", 0} would turn to Column{"a", 1}. Please note that this function assumes that -// name of the Column is unique. If we have a requirement such that Column{"a", 0}, Column{"a", 1}. -// This function will produce incorrect result (It will only emit single Column as a result). -pub fn map_columns_before_projection( - parent_required: &[Arc], - proj_exprs: &[(Arc, String)], -) -> Vec> { - let column_mapping = proj_exprs - .iter() - .filter_map(|(expr, name)| { - expr.as_any() - .downcast_ref::() - .map(|column| (name.clone(), column.clone())) - }) - .collect::>(); - parent_required - .iter() - .filter_map(|r| { - r.as_any() - .downcast_ref::() - .and_then(|c| column_mapping.get(c.name())) - }) - .map(|e| Arc::new(e.clone()) as _) - .collect() -} - -/// This function returns all `Arc`s inside the given -/// `PhysicalSortExpr` sequence. -pub fn convert_to_expr>( - sequence: impl IntoIterator, -) -> Vec> { - sequence - .into_iter() - .map(|elem| elem.borrow().expr.clone()) - .collect() -} - -/// This function finds the indices of `targets` within `items`, taking into -/// account equivalences according to `equal_properties`. -pub fn get_indices_of_matching_exprs EquivalenceProperties>( - targets: &[Arc], - items: &[Arc], - equal_properties: F, -) -> Vec { - let eq_properties = equal_properties(); - let normalized_items = eq_properties.normalize_exprs(items); - let normalized_targets = eq_properties.normalize_exprs(targets); - get_indices_of_exprs_strict(normalized_targets, &normalized_items) -} - -/// This function finds the indices of `targets` within `items` using strict -/// equality. -pub fn get_indices_of_exprs_strict>>( - targets: impl IntoIterator, - items: &[Arc], -) -> Vec { - targets - .into_iter() - .filter_map(|target| items.iter().position(|e| e.eq(target.borrow()))) - .collect() -} - -#[derive(Clone, Debug)] -pub struct ExprTreeNode { - expr: Arc, - data: Option, - child_nodes: Vec>, -} - -impl ExprTreeNode { - pub fn new(expr: Arc) -> Self { - ExprTreeNode { - expr, - data: None, - child_nodes: vec![], - } - } - - pub fn expression(&self) -> &Arc { - &self.expr - } - - pub fn children(&self) -> Vec> { - self.expr - .children() - .into_iter() - .map(ExprTreeNode::new) - .collect() - } -} - -impl TreeNode for ExprTreeNode { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in self.children() { - match op(&child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) - } - - fn map_children(mut self, transform: F) -> Result - where - F: FnMut(Self) -> Result, - { - self.child_nodes = self - .children() - .into_iter() - .map(transform) - .collect::>>()?; - Ok(self) - } -} - -/// This struct facilitates the [TreeNodeRewriter] mechanism to convert a -/// [PhysicalExpr] tree into a DAEG (i.e. an expression DAG) by collecting -/// identical expressions in one node. Caller specifies the node type in the -/// DAEG via the `constructor` argument, which constructs nodes in the DAEG -/// from the [ExprTreeNode] ancillary object. -struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode) -> T> { - // The resulting DAEG (expression DAG). - graph: StableGraph, - // A vector of visited expression nodes and their corresponding node indices. - visited_plans: Vec<(Arc, NodeIndex)>, - // A function to convert an input expression node to T. - constructor: &'a F, -} - -impl<'a, T, F: Fn(&ExprTreeNode) -> T> TreeNodeRewriter - for PhysicalExprDAEGBuilder<'a, T, F> -{ - type N = ExprTreeNode; - // This method mutates an expression node by transforming it to a physical expression - // and adding it to the graph. The method returns the mutated expression node. - fn mutate( - &mut self, - mut node: ExprTreeNode, - ) -> Result> { - // Get the expression associated with the input expression node. - let expr = &node.expr; - - // Check if the expression has already been visited. - let node_idx = match self.visited_plans.iter().find(|(e, _)| expr.eq(e)) { - // If the expression has been visited, return the corresponding node index. - Some((_, idx)) => *idx, - // If the expression has not been visited, add a new node to the graph and - // add edges to its child nodes. Add the visited expression to the vector - // of visited expressions and return the newly created node index. - None => { - let node_idx = self.graph.add_node((self.constructor)(&node)); - for expr_node in node.child_nodes.iter() { - self.graph.add_edge(node_idx, expr_node.data.unwrap(), 0); - } - self.visited_plans.push((expr.clone(), node_idx)); - node_idx - } - }; - // Set the data field of the input expression node to the corresponding node index. - node.data = Some(node_idx); - // Return the mutated expression node. - Ok(node) - } -} - -// A function that builds a directed acyclic graph of physical expression trees. -pub fn build_dag( - expr: Arc, - constructor: &F, -) -> Result<(NodeIndex, StableGraph)> -where - F: Fn(&ExprTreeNode) -> T, -{ - // Create a new expression tree node from the input expression. - let init = ExprTreeNode::new(expr); - // Create a new `PhysicalExprDAEGBuilder` instance. - let mut builder = PhysicalExprDAEGBuilder { - graph: StableGraph::::new(), - visited_plans: Vec::<(Arc, NodeIndex)>::new(), - constructor, - }; - // Use the builder to transform the expression tree node into a DAG. - let root = init.rewrite(&mut builder)?; - // Return a tuple containing the root node index and the DAG. - Ok((root.data.unwrap(), builder.graph)) -} - -/// Recursively extract referenced [`Column`]s within a [`PhysicalExpr`]. -pub fn collect_columns(expr: &Arc) -> HashSet { - let mut columns = HashSet::::new(); - expr.apply(&mut |expr| { - if let Some(column) = expr.as_any().downcast_ref::() { - if !columns.iter().any(|c| c.eq(column)) { - columns.insert(column.clone()); - } - } - Ok(VisitRecursion::Continue) - }) - // pre_visit always returns OK, so this will always too - .expect("no way to return error during recursion"); - columns -} - -/// Re-assign column indices referenced in predicate according to given schema. -/// This may be helpful when dealing with projections. -pub fn reassign_predicate_columns( - pred: Arc, - schema: &SchemaRef, - ignore_not_found: bool, -) -> Result> { - pred.transform_down(&|expr| { - let expr_any = expr.as_any(); - - if let Some(column) = expr_any.downcast_ref::() { - let index = match schema.index_of(column.name()) { - Ok(idx) => idx, - Err(_) if ignore_not_found => usize::MAX, - Err(e) => return Err(e.into()), - }; - return Ok(Transformed::Yes(Arc::new(Column::new( - column.name(), - index, - )))); - } - Ok(Transformed::No(expr)) - }) -} - -/// Reverses the ORDER BY expression, which is useful during equivalent window -/// expression construction. For instance, 'ORDER BY a ASC, NULLS LAST' turns into -/// 'ORDER BY a DESC, NULLS FIRST'. -pub fn reverse_order_bys(order_bys: &[PhysicalSortExpr]) -> Vec { - order_bys - .iter() - .map(|e| PhysicalSortExpr { - expr: e.expr.clone(), - options: !e.options, - }) - .collect() -} - -/// Find the finer requirement among `req1` and `req2` -/// If `None`, this means that `req1` and `req2` are not compatible -/// e.g there is no requirement that satisfies both -pub fn get_finer_ordering< - 'a, - F: Fn() -> EquivalenceProperties, - F2: Fn() -> OrderingEquivalenceProperties, ->( - req1: &'a [PhysicalSortExpr], - req2: &'a [PhysicalSortExpr], - eq_properties: F, - ordering_eq_properties: F2, -) -> Option<&'a [PhysicalSortExpr]> { - if ordering_satisfy_concrete(req1, req2, &eq_properties, &ordering_eq_properties) { - // Finer requirement is `provided`, since it satisfies the other: - return Some(req1); - } - if ordering_satisfy_concrete(req2, req1, &eq_properties, &ordering_eq_properties) { - // Finer requirement is `req`, since it satisfies the other: - return Some(req2); - } - // Neither `provided` nor `req` satisfies one another, they are incompatible. - None -} - -/// Scatter `truthy` array by boolean mask. When the mask evaluates `true`, next values of `truthy` -/// are taken, when the mask evaluates `false` values null values are filled. -/// -/// # Arguments -/// * `mask` - Boolean values used to determine where to put the `truthy` values -/// * `truthy` - All values of this array are to scatter according to `mask` into final result. -pub fn scatter(mask: &BooleanArray, truthy: &dyn Array) -> Result { - let truthy = truthy.to_data(); - - // update the mask so that any null values become false - // (SlicesIterator doesn't respect nulls) - let mask = and_kleene(mask, &is_not_null(mask)?)?; - - let mut mutable = MutableArrayData::new(vec![&truthy], true, mask.len()); - - // the SlicesIterator slices only the true values. So the gaps left by this iterator we need to - // fill with falsy values - - // keep track of how much is filled - let mut filled = 0; - // keep track of current position we have in truthy array - let mut true_pos = 0; - - SlicesIterator::new(&mask).for_each(|(start, end)| { - // the gap needs to be filled with nulls - if start > filled { - mutable.extend_nulls(start - filled); - } - // fill with truthy values - let len = end - start; - mutable.extend(0, true_pos, true_pos + len); - true_pos += len; - filled = end; - }); - // the remaining part is falsy - if filled < mask.len() { - mutable.extend_nulls(mask.len() - filled); - } - - let data = mutable.freeze(); - Ok(make_array(data)) -} - -/// Return indices of each item in `required_exprs` inside `provided_exprs`. -/// All the items should be found inside `provided_exprs`. Found indices will -/// be a permutation of the range 0, 1, ..., N. For example, \[2,1,0\] is valid -/// (\[0,1,2\] is consecutive), but \[3,1,0\] is not valid (\[0,1,3\] is not -/// consecutive). -fn get_lexicographical_match_indices( - required_exprs: &[Arc], - provided_exprs: &[Arc], -) -> Option> { - let indices_of_equality = get_indices_of_exprs_strict(required_exprs, provided_exprs); - let mut ordered_indices = indices_of_equality.clone(); - ordered_indices.sort(); - let n_match = indices_of_equality.len(); - let first_n = longest_consecutive_prefix(ordered_indices); - (n_match == required_exprs.len() && first_n == n_match && n_match > 0) - .then_some(indices_of_equality) -} - -/// Attempts to find a full match between the required columns to be ordered (lexicographically), and -/// the provided sort options (lexicographically), while considering equivalence properties. -/// -/// It starts by normalizing members of both the required columns and the provided sort options. -/// If a full match is found, returns the sort options and indices of the matches. If no full match is found, -/// the function proceeds to check against ordering equivalence properties. If still no full match is found, -/// the function returns `None`. -pub fn get_indices_of_matching_sort_exprs_with_order_eq( - provided_sorts: &[PhysicalSortExpr], - required_columns: &[Column], - eq_properties: &EquivalenceProperties, - order_eq_properties: &OrderingEquivalenceProperties, -) -> Option<(Vec, Vec)> { - // Create a vector of `PhysicalSortRequirement`s from the required columns: - let sort_requirement_on_requirements = required_columns - .iter() - .map(|required_column| PhysicalSortRequirement { - expr: Arc::new(required_column.clone()) as _, - options: None, - }) - .collect::>(); - - let normalized_required = normalize_sort_requirements( - &sort_requirement_on_requirements, - eq_properties, - &OrderingEquivalenceProperties::new(order_eq_properties.schema()), - ); - let normalized_provided = normalize_sort_requirements( - &PhysicalSortRequirement::from_sort_exprs(provided_sorts.iter()), - eq_properties, - &OrderingEquivalenceProperties::new(order_eq_properties.schema()), - ); - - let provided_sorts = normalized_provided - .iter() - .map(|req| req.expr.clone()) - .collect::>(); - - let normalized_required_expr = normalized_required - .iter() - .map(|req| req.expr.clone()) - .collect::>(); - - if let Some(indices_of_equality) = - get_lexicographical_match_indices(&normalized_required_expr, &provided_sorts) - { - return Some(( - indices_of_equality - .iter() - .filter_map(|index| normalized_provided[*index].options) - .collect(), - indices_of_equality, - )); - } - - // We did not find all the expressions, consult ordering equivalence properties: - if let Some(oeq_class) = order_eq_properties.oeq_class() { - let head = oeq_class.head(); - for ordering in oeq_class.others().iter().chain(std::iter::once(head)) { - let order_eq_class_exprs = convert_to_expr(ordering); - if let Some(indices_of_equality) = get_lexicographical_match_indices( - &normalized_required_expr, - &order_eq_class_exprs, - ) { - return Some(( - indices_of_equality - .iter() - .map(|index| ordering[*index].options) - .collect(), - indices_of_equality, - )); - } - } - } - // If no match found, return `None`: - None -} - -/// Calculates the output orderings for a set of expressions within the context of a given -/// execution plan. The resulting orderings are all in the type of [`Column`], since these -/// expressions become [`Column`] after the projection step. The expressions having an alias -/// are renamed with those aliases in the returned [`PhysicalSortExpr`]'s. If an expression -/// is found to be unordered, the corresponding entry in the output vector is `None`. -/// -/// # Arguments -/// -/// * `expr` - A slice of tuples containing expressions and their corresponding aliases. -/// -/// * `input_output_ordering` - Output ordering of the input plan. -/// -/// * `input_equal_properties` - Equivalence properties of the columns in the input plan. -/// -/// * `input_ordering_equal_properties` - Ordering equivalence properties of the columns in the input plan. -/// -/// # Returns -/// -/// A `Result` containing a vector of optional [`PhysicalSortExpr`]'s. Each element of the -/// vector corresponds to an expression from the input slice. If an expression can be ordered, -/// the corresponding entry is `Some(PhysicalSortExpr)`. If an expression cannot be ordered, -/// the entry is `None`. -pub fn find_orderings_of_exprs( - expr: &[(Arc, String)], - input_output_ordering: Option<&[PhysicalSortExpr]>, - input_equal_properties: EquivalenceProperties, - input_ordering_equal_properties: OrderingEquivalenceProperties, -) -> Result>> { - let mut orderings: Vec> = vec![]; - if let Some(leading_ordering) = - input_output_ordering.and_then(|output_ordering| output_ordering.first()) - { - for (index, (expression, name)) in expr.iter().enumerate() { - let initial_expr = ExprOrdering::new(expression.clone()); - let transformed = initial_expr.transform_up(&|expr| { - update_ordering( - expr, - leading_ordering, - &input_equal_properties, - &input_ordering_equal_properties, - ) - })?; - if let Some(SortProperties::Ordered(sort_options)) = transformed.state { - orderings.push(Some(PhysicalSortExpr { - expr: Arc::new(Column::new(name, index)), - options: sort_options, - })); - } else { - orderings.push(None); - } - } - } else { - orderings.extend(expr.iter().map(|_| None)); - } - Ok(orderings) -} - -/// Merge left and right sort expressions, checking for duplicates. -pub fn merge_vectors( - left: &[PhysicalSortExpr], - right: &[PhysicalSortExpr], -) -> Vec { - left.iter() - .cloned() - .chain(right.iter().cloned()) - .unique() - .collect() -} - -#[cfg(test)] -mod tests { - use std::fmt::{Display, Formatter}; - use std::ops::Not; - use std::sync::Arc; - - use super::*; - use crate::equivalence::OrderingEquivalenceProperties; - use crate::expressions::{binary, cast, col, in_list, lit, Column, Literal}; - use crate::{OrderingEquivalentClass, PhysicalSortExpr}; - - use arrow::compute::SortOptions; - use arrow_array::Int32Array; - use arrow_schema::{DataType, Field, Schema}; - use datafusion_common::cast::{as_boolean_array, as_int32_array}; - use datafusion_common::{Result, ScalarValue}; - - use petgraph::visit::Bfs; - - #[derive(Clone)] - struct DummyProperty { - expr_type: String, - } - - /// This is a dummy node in the DAEG; it stores a reference to the actual - /// [PhysicalExpr] as well as a dummy property. - #[derive(Clone)] - struct PhysicalExprDummyNode { - pub expr: Arc, - pub property: DummyProperty, - } - - impl Display for PhysicalExprDummyNode { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.expr) - } - } - - fn make_dummy_node(node: &ExprTreeNode) -> PhysicalExprDummyNode { - let expr = node.expression().clone(); - let dummy_property = if expr.as_any().is::() { - "Binary" - } else if expr.as_any().is::() { - "Column" - } else if expr.as_any().is::() { - "Literal" - } else { - "Other" - } - .to_owned(); - PhysicalExprDummyNode { - expr, - property: DummyProperty { - expr_type: dummy_property, - }, - } - } - - // Generate a schema which consists of 5 columns (a, b, c, d, e) - fn create_test_schema() -> Result { - let a = Field::new("a", DataType::Int32, true); - let b = Field::new("b", DataType::Int32, true); - let c = Field::new("c", DataType::Int32, true); - let d = Field::new("d", DataType::Int32, true); - let e = Field::new("e", DataType::Int32, true); - let f = Field::new("f", DataType::Int32, true); - let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f])); - - Ok(schema) - } - - fn create_test_params() -> Result<( - SchemaRef, - EquivalenceProperties, - OrderingEquivalenceProperties, - )> { - // Assume schema satisfies ordering a ASC NULLS LAST - // and d ASC NULLS LAST, b ASC NULLS LAST and e DESC NULLS FIRST, f ASC NULLS LAST, g ASC NULLS LAST - // Assume that column a and c are aliases. - let col_a = &Column::new("a", 0); - let col_b = &Column::new("b", 1); - let col_c = &Column::new("c", 2); - let col_d = &Column::new("d", 3); - let col_e = &Column::new("e", 4); - let col_f = &Column::new("f", 5); - let col_g = &Column::new("g", 6); - let option1 = SortOptions { - descending: false, - nulls_first: false, - }; - let option2 = SortOptions { - descending: true, - nulls_first: true, - }; - let test_schema = create_test_schema()?; - let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); - eq_properties.add_equal_conditions((col_a, col_c)); - let mut ordering_eq_properties = - OrderingEquivalenceProperties::new(test_schema.clone()); - ordering_eq_properties.add_equal_conditions(( - &vec![PhysicalSortExpr { - expr: Arc::new(col_a.clone()), - options: option1, - }], - &vec![ - PhysicalSortExpr { - expr: Arc::new(col_d.clone()), - options: option1, - }, - PhysicalSortExpr { - expr: Arc::new(col_b.clone()), - options: option1, - }, - ], - )); - ordering_eq_properties.add_equal_conditions(( - &vec![PhysicalSortExpr { - expr: Arc::new(col_a.clone()), - options: option1, - }], - &vec![ - PhysicalSortExpr { - expr: Arc::new(col_e.clone()), - options: option2, - }, - PhysicalSortExpr { - expr: Arc::new(col_f.clone()), - options: option1, - }, - PhysicalSortExpr { - expr: Arc::new(col_g.clone()), - options: option1, - }, - ], - )); - Ok((test_schema, eq_properties, ordering_eq_properties)) - } - - #[test] - fn test_build_dag() -> Result<()> { - let schema = Schema::new(vec![ - Field::new("0", DataType::Int32, true), - Field::new("1", DataType::Int32, true), - Field::new("2", DataType::Int32, true), - ]); - let expr = binary( - cast( - binary( - col("0", &schema)?, - Operator::Plus, - col("1", &schema)?, - &schema, - )?, - &schema, - DataType::Int64, - )?, - Operator::Gt, - binary( - cast(col("2", &schema)?, &schema, DataType::Int64)?, - Operator::Plus, - lit(ScalarValue::Int64(Some(10))), - &schema, - )?, - &schema, - )?; - let mut vector_dummy_props = vec![]; - let (root, graph) = build_dag(expr, &make_dummy_node)?; - let mut bfs = Bfs::new(&graph, root); - while let Some(node_index) = bfs.next(&graph) { - let node = &graph[node_index]; - vector_dummy_props.push(node.property.clone()); - } - - assert_eq!( - vector_dummy_props - .iter() - .filter(|property| property.expr_type == "Binary") - .count(), - 3 - ); - assert_eq!( - vector_dummy_props - .iter() - .filter(|property| property.expr_type == "Column") - .count(), - 3 - ); - assert_eq!( - vector_dummy_props - .iter() - .filter(|property| property.expr_type == "Literal") - .count(), - 1 - ); - assert_eq!( - vector_dummy_props - .iter() - .filter(|property| property.expr_type == "Other") - .count(), - 2 - ); - Ok(()) - } - - #[test] - fn test_convert_to_expr() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::UInt64, false)]); - let sort_expr = vec![PhysicalSortExpr { - expr: col("a", &schema)?, - options: Default::default(), - }]; - assert!(convert_to_expr(&sort_expr)[0].eq(&sort_expr[0].expr)); - Ok(()) - } - - #[test] - fn test_get_indices_of_matching_exprs() { - let empty_schema = &Arc::new(Schema::empty()); - let equal_properties = || EquivalenceProperties::new(empty_schema.clone()); - let list1: Vec> = vec![ - Arc::new(Column::new("a", 0)), - Arc::new(Column::new("b", 1)), - Arc::new(Column::new("c", 2)), - Arc::new(Column::new("d", 3)), - ]; - let list2: Vec> = vec![ - Arc::new(Column::new("b", 1)), - Arc::new(Column::new("c", 2)), - Arc::new(Column::new("a", 0)), - ]; - assert_eq!( - get_indices_of_matching_exprs(&list1, &list2, equal_properties), - vec![2, 0, 1] - ); - assert_eq!( - get_indices_of_matching_exprs(&list2, &list1, equal_properties), - vec![1, 2, 0] - ); - } - - #[test] - fn expr_list_eq_test() -> Result<()> { - let list1: Vec> = vec![ - Arc::new(Column::new("a", 0)), - Arc::new(Column::new("a", 0)), - Arc::new(Column::new("b", 1)), - ]; - let list2: Vec> = vec![ - Arc::new(Column::new("b", 1)), - Arc::new(Column::new("b", 1)), - Arc::new(Column::new("a", 0)), - ]; - assert!(!expr_list_eq_any_order(list1.as_slice(), list2.as_slice())); - assert!(!expr_list_eq_any_order(list2.as_slice(), list1.as_slice())); - - assert!(!expr_list_eq_strict_order( - list1.as_slice(), - list2.as_slice() - )); - assert!(!expr_list_eq_strict_order( - list2.as_slice(), - list1.as_slice() - )); - - let list3: Vec> = vec![ - Arc::new(Column::new("a", 0)), - Arc::new(Column::new("b", 1)), - Arc::new(Column::new("c", 2)), - Arc::new(Column::new("a", 0)), - Arc::new(Column::new("b", 1)), - ]; - let list4: Vec> = vec![ - Arc::new(Column::new("b", 1)), - Arc::new(Column::new("b", 1)), - Arc::new(Column::new("a", 0)), - Arc::new(Column::new("c", 2)), - Arc::new(Column::new("a", 0)), - ]; - assert!(expr_list_eq_any_order(list3.as_slice(), list4.as_slice())); - assert!(expr_list_eq_any_order(list4.as_slice(), list3.as_slice())); - assert!(expr_list_eq_any_order(list3.as_slice(), list3.as_slice())); - assert!(expr_list_eq_any_order(list4.as_slice(), list4.as_slice())); - - assert!(!expr_list_eq_strict_order( - list3.as_slice(), - list4.as_slice() - )); - assert!(!expr_list_eq_strict_order( - list4.as_slice(), - list3.as_slice() - )); - assert!(expr_list_eq_any_order(list3.as_slice(), list3.as_slice())); - assert!(expr_list_eq_any_order(list4.as_slice(), list4.as_slice())); - - Ok(()) - } - - #[test] - fn test_ordering_satisfy() -> Result<()> { - let crude = vec![PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: SortOptions::default(), - }]; - let crude = Some(&crude[..]); - let finer = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: SortOptions::default(), - }, - ]; - let finer = Some(&finer[..]); - let empty_schema = &Arc::new(Schema::empty()); - assert!(ordering_satisfy( - finer, - crude, - || { EquivalenceProperties::new(empty_schema.clone()) }, - || { OrderingEquivalenceProperties::new(empty_schema.clone()) }, - )); - assert!(!ordering_satisfy( - crude, - finer, - || { EquivalenceProperties::new(empty_schema.clone()) }, - || { OrderingEquivalenceProperties::new(empty_schema.clone()) }, - )); - Ok(()) - } - - #[test] - fn test_ordering_satisfy_with_equivalence() -> Result<()> { - let col_a = &Column::new("a", 0); - let col_b = &Column::new("b", 1); - let col_c = &Column::new("c", 2); - let col_d = &Column::new("d", 3); - let col_e = &Column::new("e", 4); - let col_f = &Column::new("f", 5); - let col_g = &Column::new("g", 6); - let option1 = SortOptions { - descending: false, - nulls_first: false, - }; - let option2 = SortOptions { - descending: true, - nulls_first: true, - }; - // The schema is ordered by a ASC NULLS LAST, b ASC NULLS LAST - let provided = vec![ - PhysicalSortExpr { - expr: Arc::new(col_a.clone()), - options: option1, - }, - PhysicalSortExpr { - expr: Arc::new(col_b.clone()), - options: option1, - }, - ]; - let provided = Some(&provided[..]); - let (_test_schema, eq_properties, ordering_eq_properties) = create_test_params()?; - // First element in the tuple stores vector of requirement, second element is the expected return value for ordering_satisfy function - let requirements = vec![ - // `a ASC NULLS LAST`, expects `ordering_satisfy` to be `true`, since existing ordering `a ASC NULLS LAST, b ASC NULLS LAST` satisfies it - (vec![(col_a, option1)], true), - (vec![(col_a, option2)], false), - // Test whether equivalence works as expected - (vec![(col_c, option1)], true), - (vec![(col_c, option2)], false), - // Test whether ordering equivalence works as expected - (vec![(col_d, option1)], true), - (vec![(col_d, option1), (col_b, option1)], true), - (vec![(col_d, option2), (col_b, option1)], false), - ( - vec![(col_e, option2), (col_f, option1), (col_g, option1)], - true, - ), - (vec![(col_e, option2), (col_f, option1)], true), - (vec![(col_e, option1), (col_f, option1)], false), - (vec![(col_e, option2), (col_b, option1)], false), - (vec![(col_e, option1), (col_b, option1)], false), - ( - vec![ - (col_d, option1), - (col_b, option1), - (col_d, option1), - (col_b, option1), - ], - true, - ), - ( - vec![ - (col_d, option1), - (col_b, option1), - (col_e, option2), - (col_f, option1), - ], - true, - ), - ( - vec![ - (col_d, option1), - (col_b, option1), - (col_e, option2), - (col_b, option1), - ], - true, - ), - ( - vec![ - (col_d, option1), - (col_b, option1), - (col_d, option2), - (col_b, option1), - ], - true, - ), - ( - vec![ - (col_d, option1), - (col_b, option1), - (col_e, option1), - (col_f, option1), - ], - false, - ), - ( - vec![ - (col_d, option1), - (col_b, option1), - (col_e, option1), - (col_b, option1), - ], - false, - ), - (vec![(col_d, option1), (col_e, option2)], true), - ]; - - for (cols, expected) in requirements { - let err_msg = format!("Error in test case:{cols:?}"); - let required = cols - .into_iter() - .map(|(col, options)| PhysicalSortExpr { - expr: Arc::new(col.clone()), - options, - }) - .collect::>(); - - let required = Some(&required[..]); - assert_eq!( - ordering_satisfy( - provided, - required, - || eq_properties.clone(), - || ordering_eq_properties.clone(), - ), - expected, - "{err_msg}" - ); - } - Ok(()) - } - - fn convert_to_requirement( - in_data: &[(&Column, Option)], - ) -> Vec { - in_data - .iter() - .map(|(col, options)| { - PhysicalSortRequirement::new(Arc::new((*col).clone()) as _, *options) - }) - .collect::>() - } - - #[test] - fn test_normalize_sort_reqs() -> Result<()> { - let col_a = &Column::new("a", 0); - let col_b = &Column::new("b", 1); - let col_c = &Column::new("c", 2); - let col_d = &Column::new("d", 3); - let col_e = &Column::new("e", 4); - let col_f = &Column::new("f", 5); - let option1 = SortOptions { - descending: false, - nulls_first: false, - }; - let option2 = SortOptions { - descending: true, - nulls_first: true, - }; - // First element in the tuple stores vector of requirement, second element is the expected return value for ordering_satisfy function - let requirements = vec![ - (vec![(col_a, Some(option1))], vec![(col_a, Some(option1))]), - (vec![(col_a, Some(option2))], vec![(col_a, Some(option2))]), - (vec![(col_a, None)], vec![(col_a, Some(option1))]), - // Test whether equivalence works as expected - (vec![(col_c, Some(option1))], vec![(col_a, Some(option1))]), - (vec![(col_c, None)], vec![(col_a, Some(option1))]), - // Test whether ordering equivalence works as expected - ( - vec![(col_d, Some(option1)), (col_b, Some(option1))], - vec![(col_a, Some(option1))], - ), - ( - vec![(col_d, None), (col_b, None)], - vec![(col_a, Some(option1))], - ), - ( - vec![(col_e, Some(option2)), (col_f, Some(option1))], - vec![(col_a, Some(option1))], - ), - // We should be able to normalize in compatible requirements also (not exactly equal) - ( - vec![(col_e, Some(option2)), (col_f, None)], - vec![(col_a, Some(option1))], - ), - ( - vec![(col_e, None), (col_f, None)], - vec![(col_a, Some(option1))], - ), - ]; - - let (_test_schema, eq_properties, ordering_eq_properties) = create_test_params()?; - for (reqs, expected_normalized) in requirements.into_iter() { - let req = convert_to_requirement(&reqs); - let expected_normalized = convert_to_requirement(&expected_normalized); - - assert_eq!( - normalize_sort_requirements( - &req, - &eq_properties, - &ordering_eq_properties, - ), - expected_normalized - ); - } - Ok(()) - } - - #[test] - fn test_reassign_predicate_columns_in_list() { - let int_field = Field::new("should_not_matter", DataType::Int64, true); - let dict_field = Field::new( - "id", - DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), - true, - ); - let schema_small = Arc::new(Schema::new(vec![dict_field.clone()])); - let schema_big = Arc::new(Schema::new(vec![int_field, dict_field])); - let pred = in_list( - Arc::new(Column::new_with_schema("id", &schema_big).unwrap()), - vec![lit(ScalarValue::Dictionary( - Box::new(DataType::Int32), - Box::new(ScalarValue::from("2")), - ))], - &false, - &schema_big, - ) - .unwrap(); - - let actual = reassign_predicate_columns(pred, &schema_small, false).unwrap(); - - let expected = in_list( - Arc::new(Column::new_with_schema("id", &schema_small).unwrap()), - vec![lit(ScalarValue::Dictionary( - Box::new(DataType::Int32), - Box::new(ScalarValue::from("2")), - ))], - &false, - &schema_small, - ) - .unwrap(); - - assert_eq!(actual.as_ref(), expected.as_any()); - } - - #[test] - fn test_normalize_expr_with_equivalence() -> Result<()> { - let col_a = &Column::new("a", 0); - let col_b = &Column::new("b", 1); - let col_c = &Column::new("c", 2); - let _col_d = &Column::new("d", 3); - let _col_e = &Column::new("e", 4); - // Assume that column a and c are aliases. - let (_test_schema, eq_properties, _ordering_eq_properties) = - create_test_params()?; - - let col_a_expr = Arc::new(col_a.clone()) as Arc; - let col_b_expr = Arc::new(col_b.clone()) as Arc; - let col_c_expr = Arc::new(col_c.clone()) as Arc; - // Test cases for equivalence normalization, - // First entry in the tuple is argument, second entry is expected result after normalization. - let expressions = vec![ - // Normalized version of the column a and c should go to a (since a is head) - (&col_a_expr, &col_a_expr), - (&col_c_expr, &col_a_expr), - // Cannot normalize column b - (&col_b_expr, &col_b_expr), - ]; - for (expr, expected_eq) in expressions { - assert!( - expected_eq.eq(&eq_properties.normalize_expr(expr.clone())), - "error in test: expr: {expr:?}" - ); - } - - Ok(()) - } - - #[test] - fn test_normalize_sort_requirement_with_equivalence() -> Result<()> { - let col_a = &Column::new("a", 0); - let _col_b = &Column::new("b", 1); - let col_c = &Column::new("c", 2); - let col_d = &Column::new("d", 3); - let _col_e = &Column::new("e", 4); - let option1 = SortOptions { - descending: false, - nulls_first: false, - }; - // Assume that column a and c are aliases. - let (_test_schema, eq_properties, _ordering_eq_properties) = - create_test_params()?; - - // Test cases for equivalence normalization - // First entry in the tuple is PhysicalExpr, second entry is its ordering, third entry is result after normalization. - let expressions = vec![ - (&col_a, Some(option1), &col_a, Some(option1)), - (&col_c, Some(option1), &col_a, Some(option1)), - (&col_c, None, &col_a, None), - // Cannot normalize column d, since it is not in equivalence properties. - (&col_d, Some(option1), &col_d, Some(option1)), - ]; - for (expr, sort_options, expected_col, expected_options) in - expressions.into_iter() - { - let expected = PhysicalSortRequirement::new( - Arc::new((*expected_col).clone()) as _, - expected_options, - ); - let arg = PhysicalSortRequirement::new( - Arc::new((*expr).clone()) as _, - sort_options, - ); - assert!( - expected.eq(&eq_properties.normalize_sort_requirement(arg.clone())), - "error in test: expr: {expr:?}, sort_options: {sort_options:?}" - ); - } - - Ok(()) - } - - #[test] - fn test_ordering_satisfy_different_lengths() -> Result<()> { - let col_a = &Column::new("a", 0); - let col_b = &Column::new("b", 1); - let col_c = &Column::new("c", 2); - let col_d = &Column::new("d", 3); - let col_e = &Column::new("e", 4); - let test_schema = create_test_schema()?; - let option1 = SortOptions { - descending: false, - nulls_first: false, - }; - // Column a and c are aliases. - let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); - eq_properties.add_equal_conditions((col_a, col_c)); - - // Column a and e are ordering equivalent (e.g global ordering of the table can be described both as a ASC and e ASC.) - let mut ordering_eq_properties = OrderingEquivalenceProperties::new(test_schema); - ordering_eq_properties.add_equal_conditions(( - &vec![PhysicalSortExpr { - expr: Arc::new(col_a.clone()), - options: option1, - }], - &vec![PhysicalSortExpr { - expr: Arc::new(col_e.clone()), - options: option1, - }], - )); - let sort_req_a = PhysicalSortExpr { - expr: Arc::new((col_a).clone()) as _, - options: option1, - }; - let sort_req_b = PhysicalSortExpr { - expr: Arc::new((col_b).clone()) as _, - options: option1, - }; - let sort_req_c = PhysicalSortExpr { - expr: Arc::new((col_c).clone()) as _, - options: option1, - }; - let sort_req_d = PhysicalSortExpr { - expr: Arc::new((col_d).clone()) as _, - options: option1, - }; - let sort_req_e = PhysicalSortExpr { - expr: Arc::new((col_e).clone()) as _, - options: option1, - }; - - assert!(ordering_satisfy_concrete( - // After normalization would be a ASC, b ASC, d ASC - &[sort_req_a.clone(), sort_req_b.clone(), sort_req_d.clone()], - // After normalization would be a ASC, b ASC, d ASC - &[ - sort_req_c.clone(), - sort_req_b.clone(), - sort_req_a.clone(), - sort_req_d.clone(), - sort_req_e.clone(), - ], - || eq_properties.clone(), - || ordering_eq_properties.clone(), - )); - - assert!(!ordering_satisfy_concrete( - // After normalization would be a ASC, b ASC - &[sort_req_a.clone(), sort_req_b.clone()], - // After normalization would be a ASC, b ASC, d ASC - &[ - sort_req_c.clone(), - sort_req_b.clone(), - sort_req_a.clone(), - sort_req_d.clone(), - sort_req_e.clone(), - ], - || eq_properties.clone(), - || ordering_eq_properties.clone(), - )); - - assert!(!ordering_satisfy_concrete( - // After normalization would be a ASC, b ASC, d ASC - &[sort_req_a.clone(), sort_req_b.clone(), sort_req_d.clone()], - // After normalization would be a ASC, d ASC, b ASC - &[sort_req_c, sort_req_d, sort_req_a, sort_req_b, sort_req_e,], - || eq_properties.clone(), - || ordering_eq_properties.clone(), - )); - - Ok(()) - } - - #[test] - fn test_collect_columns() -> Result<()> { - let expr1 = Arc::new(Column::new("col1", 2)) as _; - let mut expected = HashSet::new(); - expected.insert(Column::new("col1", 2)); - assert_eq!(collect_columns(&expr1), expected); - - let expr2 = Arc::new(Column::new("col2", 5)) as _; - let mut expected = HashSet::new(); - expected.insert(Column::new("col2", 5)); - assert_eq!(collect_columns(&expr2), expected); - - let expr3 = Arc::new(BinaryExpr::new(expr1, Operator::Plus, expr2)) as _; - let mut expected = HashSet::new(); - expected.insert(Column::new("col1", 2)); - expected.insert(Column::new("col2", 5)); - assert_eq!(collect_columns(&expr3), expected); - Ok(()) - } - - #[test] - fn scatter_int() -> Result<()> { - let truthy = Arc::new(Int32Array::from(vec![1, 10, 11, 100])); - let mask = BooleanArray::from(vec![true, true, false, false, true]); - - // the output array is expected to be the same length as the mask array - let expected = - Int32Array::from_iter(vec![Some(1), Some(10), None, None, Some(11)]); - let result = scatter(&mask, truthy.as_ref())?; - let result = as_int32_array(&result)?; - - assert_eq!(&expected, result); - Ok(()) - } - - #[test] - fn scatter_int_end_with_false() -> Result<()> { - let truthy = Arc::new(Int32Array::from(vec![1, 10, 11, 100])); - let mask = BooleanArray::from(vec![true, false, true, false, false, false]); - - // output should be same length as mask - let expected = - Int32Array::from_iter(vec![Some(1), None, Some(10), None, None, None]); - let result = scatter(&mask, truthy.as_ref())?; - let result = as_int32_array(&result)?; - - assert_eq!(&expected, result); - Ok(()) - } - - #[test] - fn scatter_with_null_mask() -> Result<()> { - let truthy = Arc::new(Int32Array::from(vec![1, 10, 11])); - let mask: BooleanArray = vec![Some(false), None, Some(true), Some(true), None] - .into_iter() - .collect(); - - // output should treat nulls as though they are false - let expected = Int32Array::from_iter(vec![None, None, Some(1), Some(10), None]); - let result = scatter(&mask, truthy.as_ref())?; - let result = as_int32_array(&result)?; - - assert_eq!(&expected, result); - Ok(()) - } - - #[test] - fn scatter_boolean() -> Result<()> { - let truthy = Arc::new(BooleanArray::from(vec![false, false, false, true])); - let mask = BooleanArray::from(vec![true, true, false, false, true]); - - // the output array is expected to be the same length as the mask array - let expected = BooleanArray::from_iter(vec![ - Some(false), - Some(false), - None, - None, - Some(false), - ]); - let result = scatter(&mask, truthy.as_ref())?; - let result = as_boolean_array(&result)?; - - assert_eq!(&expected, result); - Ok(()) - } - - #[test] - fn test_get_indices_of_matching_sort_exprs_with_order_eq() -> Result<()> { - let sort_options = SortOptions::default(); - let sort_options_not = SortOptions::default().not(); - - let provided_sorts = [ - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: sort_options_not, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: sort_options, - }, - ]; - let required_columns = [Column::new("b", 1), Column::new("a", 0)]; - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - ]); - let equal_properties = EquivalenceProperties::new(Arc::new(schema.clone())); - let ordering_equal_properties = - OrderingEquivalenceProperties::new(Arc::new(schema)); - assert_eq!( - get_indices_of_matching_sort_exprs_with_order_eq( - &provided_sorts, - &required_columns, - &equal_properties, - &ordering_equal_properties, - ), - Some((vec![sort_options_not, sort_options], vec![0, 1])) - ); - - // required columns are provided in the equivalence classes - let provided_sorts = [PhysicalSortExpr { - expr: Arc::new(Column::new("c", 2)), - options: sort_options, - }]; - let required_columns = [Column::new("b", 1), Column::new("a", 0)]; - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("c", DataType::Int32, true), - ]); - let equal_properties = EquivalenceProperties::new(Arc::new(schema.clone())); - let mut ordering_equal_properties = - OrderingEquivalenceProperties::new(Arc::new(schema)); - ordering_equal_properties.add_equal_conditions(( - &vec![PhysicalSortExpr { - expr: Arc::new(Column::new("c", 2)), - options: sort_options, - }], - &vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: sort_options_not, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: sort_options, - }, - ], - )); - assert_eq!( - get_indices_of_matching_sort_exprs_with_order_eq( - &provided_sorts, - &required_columns, - &equal_properties, - &ordering_equal_properties, - ), - Some((vec![sort_options_not, sort_options], vec![0, 1])) - ); - - // not satisfied orders - let provided_sorts = [ - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: sort_options_not, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("c", 2)), - options: sort_options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: sort_options, - }, - ]; - let required_columns = [Column::new("b", 1), Column::new("a", 0)]; - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("c", DataType::Int32, true), - ]); - let equal_properties = EquivalenceProperties::new(Arc::new(schema.clone())); - let ordering_equal_properties = - OrderingEquivalenceProperties::new(Arc::new(schema)); - assert_eq!( - get_indices_of_matching_sort_exprs_with_order_eq( - &provided_sorts, - &required_columns, - &equal_properties, - &ordering_equal_properties, - ), - None - ); - - Ok(()) - } - - #[test] - fn test_normalize_ordering_equivalence_classes() -> Result<()> { - let sort_options = SortOptions::default(); - - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("c", DataType::Int32, true), - ]); - let mut equal_properties = EquivalenceProperties::new(Arc::new(schema.clone())); - let mut expected_oeq = OrderingEquivalenceProperties::new(Arc::new(schema)); - - equal_properties - .add_equal_conditions((&Column::new("a", 0), &Column::new("c", 2))); - let head = vec![PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: sort_options, - }]; - let others = vec![vec![PhysicalSortExpr { - expr: Arc::new(Column::new("c", 2)), - options: sort_options, - }]]; - let oeq_class = OrderingEquivalentClass::new(head, others); - - expected_oeq.add_equal_conditions(( - &vec![PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: sort_options, - }], - &vec![PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: sort_options, - }], - )); - - let normalized_oeq_class = - oeq_class.normalize_with_equivalence_properties(&equal_properties); - let expected = expected_oeq.oeq_class().unwrap(); - assert!( - normalized_oeq_class.head().eq(expected.head()) - && normalized_oeq_class.others().eq(expected.others()) - ); - - Ok(()) - } - - #[test] - fn project_empty_output_ordering() -> Result<()> { - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("c", DataType::Int32, true), - ]); - let orderings = find_orderings_of_exprs( - &[ - (Arc::new(Column::new("b", 1)), "b_new".to_string()), - (Arc::new(Column::new("a", 0)), "a_new".to_string()), - ], - Some(&[PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: SortOptions::default(), - }]), - EquivalenceProperties::new(Arc::new(schema.clone())), - OrderingEquivalenceProperties::new(Arc::new(schema.clone())), - )?; - - assert_eq!( - vec![ - Some(PhysicalSortExpr { - expr: Arc::new(Column::new("b_new", 0)), - options: SortOptions::default(), - }), - None, - ], - orderings - ); - - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("c", DataType::Int32, true), - ]); - let orderings = find_orderings_of_exprs( - &[ - (Arc::new(Column::new("c", 2)), "c_new".to_string()), - (Arc::new(Column::new("b", 1)), "b_new".to_string()), - ], - Some(&[]), - EquivalenceProperties::new(Arc::new(schema.clone())), - OrderingEquivalenceProperties::new(Arc::new(schema)), - )?; - - assert_eq!(vec![None, None], orderings); - - Ok(()) - } -} diff --git a/datafusion/physical-expr/src/utils/guarantee.rs b/datafusion/physical-expr/src/utils/guarantee.rs new file mode 100644 index 000000000000..0aee2af67fdd --- /dev/null +++ b/datafusion/physical-expr/src/utils/guarantee.rs @@ -0,0 +1,856 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`LiteralGuarantee`] predicate analysis to determine if a column is a +//! constant. + +use crate::utils::split_disjunction; +use crate::{split_conjunction, PhysicalExpr}; +use datafusion_common::{Column, ScalarValue}; +use datafusion_expr::Operator; +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; + +/// Represents a guarantee that must be true for a boolean expression to +/// evaluate to `true`. +/// +/// The guarantee takes the form of a column and a set of literal (constant) +/// [`ScalarValue`]s. For the expression to evaluate to `true`, the column *must +/// satisfy* the guarantee(s). +/// +/// To satisfy the guarantee, depending on [`Guarantee`], the values in the +/// column must either: +/// +/// 1. be ONLY one of that set +/// 2. NOT be ANY of that set +/// +/// # Uses `LiteralGuarantee`s +/// +/// `LiteralGuarantee`s can be used to simplify filter expressions and skip data +/// files (e.g. row groups in parquet files) by proving expressions can not +/// possibly evaluate to `true`. For example, if we have a guarantee that `a` +/// must be in (`1`) for a filter to evaluate to `true`, then we can skip any +/// partition where we know that `a` never has the value of `1`. +/// +/// **Important**: If a `LiteralGuarantee` is not satisfied, the relevant +/// expression is *guaranteed* to evaluate to `false` or `null`. **However**, +/// the opposite does not hold. Even if all `LiteralGuarantee`s are satisfied, +/// that does **not** guarantee that the predicate will actually evaluate to +/// `true`: it may still evaluate to `true`, `false` or `null`. +/// +/// # Creating `LiteralGuarantee`s +/// +/// Use [`LiteralGuarantee::analyze`] to extract literal guarantees from a +/// filter predicate. +/// +/// # Details +/// A guarantee can be one of two forms: +/// +/// 1. The column must be one the values for the predicate to be `true`. If the +/// column takes on any other value, the predicate can not evaluate to `true`. +/// For example, +/// `(a = 1)`, `(a = 1 OR a = 2) or `a IN (1, 2, 3)` +/// +/// 2. The column must NOT be one of the values for the predicate to be `true`. +/// If the column can ONLY take one of these values, the predicate can not +/// evaluate to `true`. For example, +/// `(a != 1)`, `(a != 1 AND a != 2)` or `a NOT IN (1, 2, 3)` +#[derive(Debug, Clone, PartialEq)] +pub struct LiteralGuarantee { + pub column: Column, + pub guarantee: Guarantee, + pub literals: HashSet, +} + +/// What is guaranteed about the values for a [`LiteralGuarantee`]? +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum Guarantee { + /// Guarantee that the expression is `true` if `column` is one of the values. If + /// `column` is not one of the values, the expression can not be `true`. + In, + /// Guarantee that the expression is `true` if `column` is not ANY of the + /// values. If `column` only takes one of these values, the expression can + /// not be `true`. + NotIn, +} + +impl LiteralGuarantee { + /// Create a new instance of the guarantee if the provided operator is + /// supported. Returns None otherwise. See [`LiteralGuarantee::analyze`] to + /// create these structures from an predicate (boolean expression). + fn try_new<'a>( + column_name: impl Into, + guarantee: Guarantee, + literals: impl IntoIterator, + ) -> Option { + let literals: HashSet<_> = literals.into_iter().cloned().collect(); + + Some(Self { + column: Column::from_name(column_name), + guarantee, + literals, + }) + } + + /// Return a list of [`LiteralGuarantee`]s that must be satisfied for `expr` + /// to evaluate to `true`. + /// + /// If more than one `LiteralGuarantee` is returned, they must **all** hold + /// for the expression to possibly be `true`. If any is not satisfied, the + /// expression is guaranteed to be `null` or `false`. + /// + /// # Notes: + /// 1. `expr` must be a boolean expression or inlist expression. + /// 2. `expr` is not simplified prior to analysis. + pub fn analyze(expr: &Arc) -> Vec { + // split conjunction: AND AND ... + split_conjunction(expr) + .into_iter() + // for an `AND` conjunction to be true, all terms individually must be true + .fold(GuaranteeBuilder::new(), |builder, expr| { + if let Some(cel) = ColOpLit::try_new(expr) { + return builder.aggregate_conjunct(cel); + } else if let Some(inlist) = expr + .as_any() + .downcast_ref::() + { + // Only support single-column inlist currently, multi-column inlist is not supported + let col = inlist + .expr() + .as_any() + .downcast_ref::(); + let Some(col) = col else { + return builder; + }; + + let literals = inlist + .list() + .iter() + .map(|e| e.as_any().downcast_ref::()) + .collect::>>(); + let Some(literals) = literals else { + return builder; + }; + + let guarantee = if inlist.negated() { + Guarantee::NotIn + } else { + Guarantee::In + }; + + builder.aggregate_multi_conjunct( + col, + guarantee, + literals.iter().map(|e| e.value()), + ) + } else { + // split disjunction: OR OR ... + let disjunctions = split_disjunction(expr); + + // We are trying to add a guarantee that a column must be + // in/not in a particular set of values for the expression + // to evaluate to true. + // + // A disjunction is true, if at least one of the terms is be + // true. + // + // Thus, we can infer a guarantee if all terms are of the + // form `(col literal) OR (col literal) OR ...`. + // + // For example, we can infer that `a = 1 OR a = 2 OR a = 3` + // is guaranteed to be true ONLY if a is in (`1`, `2` or `3`). + // + // However, for something like `a = 1 OR a = 2 OR a < 0` we + // **can't** guarantee that the predicate is only true if a + // is in (`1`, `2`), as it could also be true if `a` were less + // than zero. + let terms = disjunctions + .iter() + .filter_map(|expr| ColOpLit::try_new(expr)) + .collect::>(); + + if terms.is_empty() { + return builder; + } + + // if not all terms are of the form (col literal), + // can't infer any guarantees + if terms.len() != disjunctions.len() { + return builder; + } + + // if all terms are 'col literal' with the same column + // and operation we can infer any guarantees + // + // For those like (a != foo AND (a != bar OR a != baz)). + // We can't combine the (a != bar OR a != baz) part, but + // it also doesn't invalidate our knowledge that a != + // foo is required for the expression to be true. + // So we can only create a multi value guarantee for `=` + // (or a single value). (e.g. ignore `a != foo OR a != bar`) + let first_term = &terms[0]; + if terms.iter().all(|term| { + term.col.name() == first_term.col.name() + && term.guarantee == Guarantee::In + }) { + builder.aggregate_multi_conjunct( + first_term.col, + Guarantee::In, + terms.iter().map(|term| term.lit.value()), + ) + } else { + // can't infer anything + builder + } + } + }) + .build() + } +} + +/// Combines conjuncts (aka terms `AND`ed together) into [`LiteralGuarantee`]s, +/// preserving insert order +#[derive(Debug, Default)] +struct GuaranteeBuilder<'a> { + /// List of guarantees that have been created so far + /// if we have determined a subsequent conjunct invalidates a guarantee + /// e.g. `a = foo AND a = bar` then the relevant guarantee will be None + guarantees: Vec>, + + /// Key is the (column name, guarantee type) + /// Value is the index into `guarantees` + map: HashMap<(&'a crate::expressions::Column, Guarantee), usize>, +} + +impl<'a> GuaranteeBuilder<'a> { + fn new() -> Self { + Default::default() + } + + /// Aggregate a new single `AND col literal` term to this builder + /// combining with existing guarantees if possible. + /// + /// # Examples + /// * `AND (a = 1)`: `a` is guaranteed to be 1 + /// * `AND (a != 1)`: a is guaranteed to not be 1 + fn aggregate_conjunct(self, col_op_lit: ColOpLit<'a>) -> Self { + self.aggregate_multi_conjunct( + col_op_lit.col, + col_op_lit.guarantee, + [col_op_lit.lit.value()], + ) + } + + /// Aggregates a new single column, multi literal term to ths builder + /// combining with previously known guarantees if possible. + /// + /// # Examples + /// For the following examples, we can guarantee the expression is `true` if: + /// * `AND (a = 1 OR a = 2 OR a = 3)`: a is in (1, 2, or 3) + /// * `AND (a IN (1,2,3))`: a is in (1, 2, or 3) + /// * `AND (a != 1 OR a != 2 OR a != 3)`: a is not in (1, 2, or 3) + /// * `AND (a NOT IN (1,2,3))`: a is not in (1, 2, or 3) + fn aggregate_multi_conjunct( + mut self, + col: &'a crate::expressions::Column, + guarantee: Guarantee, + new_values: impl IntoIterator, + ) -> Self { + let key = (col, guarantee); + if let Some(index) = self.map.get(&key) { + // already have a guarantee for this column + let entry = &mut self.guarantees[*index]; + + let Some(existing) = entry else { + // determined the previous guarantee for this column has been + // invalidated, nothing to do + return self; + }; + + // Combine conjuncts if we have `a != foo AND a != bar`. `a = foo + // AND a = bar` doesn't make logical sense so we don't optimize this + // case + match existing.guarantee { + // knew that the column could not be a set of values + // + // For example, if we previously had `a != 5` and now we see + // another `AND a != 6` we know that a must not be either 5 or 6 + // for the expression to be true + Guarantee::NotIn => { + let new_values: HashSet<_> = new_values.into_iter().collect(); + existing.literals.extend(new_values.into_iter().cloned()); + } + Guarantee::In => { + let intersection = new_values + .into_iter() + .filter(|new_value| existing.literals.contains(*new_value)) + .collect::>(); + // for an In guarantee, if the intersection is not empty, we can extend the guarantee + // e.g. `a IN (1,2,3) AND a IN (2,3,4)` is `a IN (2,3)` + // otherwise, we invalidate the guarantee + // e.g. `a IN (1,2,3) AND a IN (4,5,6)` is `a IN ()`, which is invalid + if !intersection.is_empty() { + existing.literals = intersection.into_iter().cloned().collect(); + } else { + // at least one was not, so invalidate the guarantee + *entry = None; + } + } + } + } else { + // This is a new guarantee + let new_values: HashSet<_> = new_values.into_iter().collect(); + + if let Some(guarantee) = + LiteralGuarantee::try_new(col.name(), guarantee, new_values) + { + // add it to the list of guarantees + self.guarantees.push(Some(guarantee)); + self.map.insert(key, self.guarantees.len() - 1); + } + } + + self + } + + /// Return all guarantees that have been created so far + fn build(self) -> Vec { + // filter out any guarantees that have been invalidated + self.guarantees.into_iter().flatten().collect() + } +} + +/// Represents a single `col [not]in literal` expression +struct ColOpLit<'a> { + col: &'a crate::expressions::Column, + guarantee: Guarantee, + lit: &'a crate::expressions::Literal, +} + +impl<'a> ColOpLit<'a> { + /// Returns Some(ColEqLit) if the expression is either: + /// 1. `col literal` + /// 2. `literal col` + /// 3. operator is `=` or `!=` + /// Returns None otherwise + fn try_new(expr: &'a Arc) -> Option { + let binary_expr = expr + .as_any() + .downcast_ref::()?; + + let (left, op, right) = ( + binary_expr.left().as_any(), + binary_expr.op(), + binary_expr.right().as_any(), + ); + let guarantee = match op { + Operator::Eq => Guarantee::In, + Operator::NotEq => Guarantee::NotIn, + _ => return None, + }; + // col literal + if let (Some(col), Some(lit)) = ( + left.downcast_ref::(), + right.downcast_ref::(), + ) { + Some(Self { + col, + guarantee, + lit, + }) + } + // literal col + else if let (Some(lit), Some(col)) = ( + left.downcast_ref::(), + right.downcast_ref::(), + ) { + Some(Self { + col, + guarantee, + lit, + }) + } else { + None + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::create_physical_expr; + use crate::execution_props::ExecutionProps; + use arrow_schema::{DataType, Field, Schema, SchemaRef}; + use datafusion_common::ToDFSchema; + use datafusion_expr::expr_fn::*; + use datafusion_expr::{lit, Expr}; + use std::sync::OnceLock; + + #[test] + fn test_literal() { + // a single literal offers no guarantee + test_analyze(lit(true), vec![]) + } + + #[test] + fn test_single() { + // a = "foo" + test_analyze(col("a").eq(lit("foo")), vec![in_guarantee("a", ["foo"])]); + // "foo" = a + test_analyze(lit("foo").eq(col("a")), vec![in_guarantee("a", ["foo"])]); + // a != "foo" + test_analyze( + col("a").not_eq(lit("foo")), + vec![not_in_guarantee("a", ["foo"])], + ); + // "foo" != a + test_analyze( + lit("foo").not_eq(col("a")), + vec![not_in_guarantee("a", ["foo"])], + ); + } + + #[test] + fn test_conjunction_single_column() { + // b = 1 AND b = 2. This is impossible. Ideally this expression could be simplified to false + test_analyze(col("b").eq(lit(1)).and(col("b").eq(lit(2))), vec![]); + // b = 1 AND b != 2 . In theory, this could be simplified to `b = 1`. + test_analyze( + col("b").eq(lit(1)).and(col("b").not_eq(lit(2))), + vec![ + // can only be true of b is 1 and b is not 2 (even though it is redundant) + in_guarantee("b", [1]), + not_in_guarantee("b", [2]), + ], + ); + // b != 1 AND b = 2. In theory, this could be simplified to `b = 2`. + test_analyze( + col("b").not_eq(lit(1)).and(col("b").eq(lit(2))), + vec![ + // can only be true of b is not 1 and b is is 2 (even though it is redundant) + not_in_guarantee("b", [1]), + in_guarantee("b", [2]), + ], + ); + // b != 1 AND b != 2 + test_analyze( + col("b").not_eq(lit(1)).and(col("b").not_eq(lit(2))), + vec![not_in_guarantee("b", [1, 2])], + ); + // b != 1 AND b != 2 and b != 3 + test_analyze( + col("b") + .not_eq(lit(1)) + .and(col("b").not_eq(lit(2))) + .and(col("b").not_eq(lit(3))), + vec![not_in_guarantee("b", [1, 2, 3])], + ); + // b != 1 AND b = 2 and b != 3. Can only be true if b is 2 and b is not in (1, 3) + test_analyze( + col("b") + .not_eq(lit(1)) + .and(col("b").eq(lit(2))) + .and(col("b").not_eq(lit(3))), + vec![not_in_guarantee("b", [1, 3]), in_guarantee("b", [2])], + ); + // b != 1 AND b != 2 and b = 3 (in theory could determine b = 3) + test_analyze( + col("b") + .not_eq(lit(1)) + .and(col("b").not_eq(lit(2))) + .and(col("b").eq(lit(3))), + vec![not_in_guarantee("b", [1, 2]), in_guarantee("b", [3])], + ); + // b != 1 AND b != 2 and b > 3 (to be true, b can't be either 1 or 2 + test_analyze( + col("b") + .not_eq(lit(1)) + .and(col("b").not_eq(lit(2))) + .and(col("b").gt(lit(3))), + vec![not_in_guarantee("b", [1, 2])], + ); + } + + #[test] + fn test_conjunction_multi_column() { + // a = "foo" AND b = 1 + test_analyze( + col("a").eq(lit("foo")).and(col("b").eq(lit(1))), + vec![ + // should find both column guarantees + in_guarantee("a", ["foo"]), + in_guarantee("b", [1]), + ], + ); + // a != "foo" AND b != 1 + test_analyze( + col("a").not_eq(lit("foo")).and(col("b").not_eq(lit(1))), + // should find both column guarantees + vec![not_in_guarantee("a", ["foo"]), not_in_guarantee("b", [1])], + ); + // a = "foo" AND a = "bar" + test_analyze( + col("a").eq(lit("foo")).and(col("a").eq(lit("bar"))), + // this predicate is impossible ( can't be both foo and bar), + vec![], + ); + // a = "foo" AND b != "bar" + test_analyze( + col("a").eq(lit("foo")).and(col("a").not_eq(lit("bar"))), + vec![in_guarantee("a", ["foo"]), not_in_guarantee("a", ["bar"])], + ); + // a != "foo" AND a != "bar" + test_analyze( + col("a").not_eq(lit("foo")).and(col("a").not_eq(lit("bar"))), + // know it isn't "foo" or "bar" + vec![not_in_guarantee("a", ["foo", "bar"])], + ); + // a != "foo" AND a != "bar" and a != "baz" + test_analyze( + col("a") + .not_eq(lit("foo")) + .and(col("a").not_eq(lit("bar"))) + .and(col("a").not_eq(lit("baz"))), + // know it isn't "foo" or "bar" or "baz" + vec![not_in_guarantee("a", ["foo", "bar", "baz"])], + ); + // a = "foo" AND a = "foo" + let expr = col("a").eq(lit("foo")); + test_analyze(expr.clone().and(expr), vec![in_guarantee("a", ["foo"])]); + // b > 5 AND b = 10 (should get an b = 10 guarantee) + test_analyze( + col("b").gt(lit(5)).and(col("b").eq(lit(10))), + vec![in_guarantee("b", [10])], + ); + // b > 10 AND b = 10 (this is impossible) + test_analyze( + col("b").gt(lit(10)).and(col("b").eq(lit(10))), + vec![ + // if b isn't 10, it can not be true (though the expression actually can never be true) + in_guarantee("b", [10]), + ], + ); + // a != "foo" and (a != "bar" OR a != "baz") + test_analyze( + col("a") + .not_eq(lit("foo")) + .and(col("a").not_eq(lit("bar")).or(col("a").not_eq(lit("baz")))), + // a is not foo (we can't represent other knowledge about a) + vec![not_in_guarantee("a", ["foo"])], + ); + } + + #[test] + fn test_conjunction_and_disjunction_single_column() { + // b != 1 AND (b > 2) + test_analyze( + col("b").not_eq(lit(1)).and(col("b").gt(lit(2))), + vec![ + // for the expression to be true, b can not be one + not_in_guarantee("b", [1]), + ], + ); + + // b = 1 AND (b = 2 OR b = 3). Could be simplified to false. + test_analyze( + col("b") + .eq(lit(1)) + .and(col("b").eq(lit(2)).or(col("b").eq(lit(3)))), + vec![ + // in theory, b must be 1 and one of 2,3 for this expression to be true + // which is a logical contradiction + ], + ); + } + + #[test] + fn test_disjunction_single_column() { + // b = 1 OR b = 2 + test_analyze( + col("b").eq(lit(1)).or(col("b").eq(lit(2))), + vec![in_guarantee("b", [1, 2])], + ); + // b != 1 OR b = 2 + test_analyze(col("b").not_eq(lit(1)).or(col("b").eq(lit(2))), vec![]); + // b = 1 OR b != 2 + test_analyze(col("b").eq(lit(1)).or(col("b").not_eq(lit(2))), vec![]); + // b != 1 OR b != 2 + test_analyze(col("b").not_eq(lit(1)).or(col("b").not_eq(lit(2))), vec![]); + // b != 1 OR b != 2 OR b = 3 -- in theory could guarantee that b = 3 + test_analyze( + col("b") + .not_eq(lit(1)) + .or(col("b").not_eq(lit(2))) + .or(lit("b").eq(lit(3))), + vec![], + ); + // b = 1 OR b = 2 OR b = 3 + test_analyze( + col("b") + .eq(lit(1)) + .or(col("b").eq(lit(2))) + .or(col("b").eq(lit(3))), + vec![in_guarantee("b", [1, 2, 3])], + ); + // b = 1 OR b = 2 OR b > 3 -- can't guarantee that the expression is only true if a is in (1, 2) + test_analyze( + col("b") + .eq(lit(1)) + .or(col("b").eq(lit(2))) + .or(lit("b").eq(lit(3))), + vec![], + ); + } + + #[test] + fn test_disjunction_multi_column() { + // a = "foo" OR b = 1 + test_analyze( + col("a").eq(lit("foo")).or(col("b").eq(lit(1))), + // no can't have a single column guarantee (if a = "foo" then b != 1) etc + vec![], + ); + // a != "foo" OR b != 1 + test_analyze( + col("a").not_eq(lit("foo")).or(col("b").not_eq(lit(1))), + // No single column guarantee + vec![], + ); + // a = "foo" OR a = "bar" + test_analyze( + col("a").eq(lit("foo")).or(col("a").eq(lit("bar"))), + vec![in_guarantee("a", ["foo", "bar"])], + ); + // a = "foo" OR a = "foo" + test_analyze( + col("a").eq(lit("foo")).or(col("a").eq(lit("foo"))), + vec![in_guarantee("a", ["foo"])], + ); + // a != "foo" OR a != "bar" + test_analyze( + col("a").not_eq(lit("foo")).or(col("a").not_eq(lit("bar"))), + // can't represent knowledge about a in this case + vec![], + ); + // a = "foo" OR a = "bar" OR a = "baz" + test_analyze( + col("a") + .eq(lit("foo")) + .or(col("a").eq(lit("bar"))) + .or(col("a").eq(lit("baz"))), + vec![in_guarantee("a", ["foo", "bar", "baz"])], + ); + // (a = "foo" OR a = "bar") AND (a = "baz)" + test_analyze( + (col("a").eq(lit("foo")).or(col("a").eq(lit("bar")))) + .and(col("a").eq(lit("baz"))), + // this could potentially be represented as 2 constraints with a more + // sophisticated analysis + vec![], + ); + // (a = "foo" OR a = "bar") AND (b = 1) + test_analyze( + (col("a").eq(lit("foo")).or(col("a").eq(lit("bar")))) + .and(col("b").eq(lit(1))), + vec![in_guarantee("a", ["foo", "bar"]), in_guarantee("b", [1])], + ); + // (a = "foo" OR a = "bar") OR (b = 1) + test_analyze( + col("a") + .eq(lit("foo")) + .or(col("a").eq(lit("bar"))) + .or(col("b").eq(lit(1))), + // can't represent knowledge about a or b in this case + vec![], + ); + } + + #[test] + fn test_single_inlist() { + // b IN (1, 2, 3) + test_analyze( + col("b").in_list(vec![lit(1), lit(2), lit(3)], false), + vec![in_guarantee("b", [1, 2, 3])], + ); + // b NOT IN (1, 2, 3) + test_analyze( + col("b").in_list(vec![lit(1), lit(2), lit(3)], true), + vec![not_in_guarantee("b", [1, 2, 3])], + ); + } + + #[test] + fn test_inlist_conjunction() { + // b IN (1, 2, 3) AND b IN (2, 3, 4) + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], false) + .and(col("b").in_list(vec![lit(2), lit(3), lit(4)], false)), + vec![in_guarantee("b", [2, 3])], + ); + // b NOT IN (1, 2, 3) AND b IN (2, 3, 4) + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], true) + .and(col("b").in_list(vec![lit(2), lit(3), lit(4)], false)), + vec![ + not_in_guarantee("b", [1, 2, 3]), + in_guarantee("b", [2, 3, 4]), + ], + ); + // b NOT IN (1, 2, 3) AND b NOT IN (2, 3, 4) + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], true) + .and(col("b").in_list(vec![lit(2), lit(3), lit(4)], true)), + vec![not_in_guarantee("b", [1, 2, 3, 4])], + ); + // b IN (1, 2, 3) AND b = 4 + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], false) + .and(col("b").eq(lit(4))), + vec![], + ); + // b IN (1, 2, 3) AND b = 2 + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], false) + .and(col("b").eq(lit(2))), + vec![in_guarantee("b", [2])], + ); + // b IN (1, 2, 3) AND b != 2 + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], false) + .and(col("b").not_eq(lit(2))), + vec![in_guarantee("b", [1, 2, 3]), not_in_guarantee("b", [2])], + ); + // b NOT IN (1, 2, 3) AND b != 4 + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], true) + .and(col("b").not_eq(lit(4))), + vec![not_in_guarantee("b", [1, 2, 3, 4])], + ); + // b NOT IN (1, 2, 3) AND b != 2 + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], true) + .and(col("b").not_eq(lit(2))), + vec![not_in_guarantee("b", [1, 2, 3])], + ); + } + + #[test] + fn test_inlist_with_disjunction() { + // b IN (1, 2, 3) AND (b = 3 OR b = 4) + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], false) + .and(col("b").eq(lit(3)).or(col("b").eq(lit(4)))), + vec![in_guarantee("b", [3])], + ); + // b IN (1, 2, 3) AND (b = 4 OR b = 5) + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], false) + .and(col("b").eq(lit(4)).or(col("b").eq(lit(5)))), + vec![], + ); + // b NOT IN (1, 2, 3) AND (b = 3 OR b = 4) + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], true) + .and(col("b").eq(lit(3)).or(col("b").eq(lit(4)))), + vec![not_in_guarantee("b", [1, 2, 3]), in_guarantee("b", [3, 4])], + ); + // b IN (1, 2, 3) OR b = 2 + // TODO this should be in_guarantee("b", [1, 2, 3]) but currently we don't support to anylize this kind of disjunction. Only `ColOpLit OR ColOpLit` is supported. + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], false) + .or(col("b").eq(lit(2))), + vec![], + ); + // b IN (1, 2, 3) OR b != 3 + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], false) + .or(col("b").not_eq(lit(3))), + vec![], + ); + } + + /// Tests that analyzing expr results in the expected guarantees + fn test_analyze(expr: Expr, expected: Vec) { + println!("Begin analyze of {expr}"); + let schema = schema(); + let physical_expr = logical2physical(&expr, &schema); + + let actual = LiteralGuarantee::analyze(&physical_expr); + assert_eq!( + expected, actual, + "expr: {expr}\ + \n\nexpected: {expected:#?}\ + \n\nactual: {actual:#?}\ + \n\nexpr: {expr:#?}\ + \n\nphysical_expr: {physical_expr:#?}" + ); + } + + /// Guarantee that the expression is true if the column is one of the specified values + fn in_guarantee<'a, I, S>(column: &str, literals: I) -> LiteralGuarantee + where + I: IntoIterator, + S: Into + 'a, + { + let literals: Vec<_> = literals.into_iter().map(|s| s.into()).collect(); + LiteralGuarantee::try_new(column, Guarantee::In, literals.iter()).unwrap() + } + + /// Guarantee that the expression is true if the column is NOT any of the specified values + fn not_in_guarantee<'a, I, S>(column: &str, literals: I) -> LiteralGuarantee + where + I: IntoIterator, + S: Into + 'a, + { + let literals: Vec<_> = literals.into_iter().map(|s| s.into()).collect(); + LiteralGuarantee::try_new(column, Guarantee::NotIn, literals.iter()).unwrap() + } + + /// Convert a logical expression to a physical expression (without any simplification, etc) + fn logical2physical(expr: &Expr, schema: &Schema) -> Arc { + let df_schema = schema.clone().to_dfschema().unwrap(); + let execution_props = ExecutionProps::new(); + create_physical_expr(expr, &df_schema, schema, &execution_props).unwrap() + } + + // Schema for testing + fn schema() -> SchemaRef { + SCHEMA + .get_or_init(|| { + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Int32, false), + ])) + }) + .clone() + } + + static SCHEMA: OnceLock = OnceLock::new(); +} diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs new file mode 100644 index 000000000000..64a62dc7820d --- /dev/null +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -0,0 +1,626 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod guarantee; +pub use guarantee::{Guarantee, LiteralGuarantee}; + +use std::borrow::{Borrow, Cow}; +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; + +use crate::expressions::{BinaryExpr, Column}; +use crate::{PhysicalExpr, PhysicalSortExpr}; + +use arrow::array::{make_array, Array, ArrayRef, BooleanArray, MutableArrayData}; +use arrow::compute::{and_kleene, is_not_null, SlicesIterator}; +use arrow::datatypes::SchemaRef; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeRewriter, VisitRecursion, +}; +use datafusion_common::Result; +use datafusion_expr::Operator; + +use itertools::Itertools; +use petgraph::graph::NodeIndex; +use petgraph::stable_graph::StableGraph; + +/// Assume the predicate is in the form of CNF, split the predicate to a Vec of PhysicalExprs. +/// +/// For example, split "a1 = a2 AND b1 <= b2 AND c1 != c2" into ["a1 = a2", "b1 <= b2", "c1 != c2"] +pub fn split_conjunction( + predicate: &Arc, +) -> Vec<&Arc> { + split_impl(Operator::And, predicate, vec![]) +} + +/// Assume the predicate is in the form of DNF, split the predicate to a Vec of PhysicalExprs. +/// +/// For example, split "a1 = a2 OR b1 <= b2 OR c1 != c2" into ["a1 = a2", "b1 <= b2", "c1 != c2"] +pub fn split_disjunction( + predicate: &Arc, +) -> Vec<&Arc> { + split_impl(Operator::Or, predicate, vec![]) +} + +fn split_impl<'a>( + operator: Operator, + predicate: &'a Arc, + mut exprs: Vec<&'a Arc>, +) -> Vec<&'a Arc> { + match predicate.as_any().downcast_ref::() { + Some(binary) if binary.op() == &operator => { + let exprs = split_impl(operator, binary.left(), exprs); + split_impl(operator, binary.right(), exprs) + } + Some(_) | None => { + exprs.push(predicate); + exprs + } + } +} + +/// This function maps back requirement after ProjectionExec +/// to the Executor for its input. +// Specifically, `ProjectionExec` changes index of `Column`s in the schema of its input executor. +// This function changes requirement given according to ProjectionExec schema to the requirement +// according to schema of input executor to the ProjectionExec. +// For instance, Column{"a", 0} would turn to Column{"a", 1}. Please note that this function assumes that +// name of the Column is unique. If we have a requirement such that Column{"a", 0}, Column{"a", 1}. +// This function will produce incorrect result (It will only emit single Column as a result). +pub fn map_columns_before_projection( + parent_required: &[Arc], + proj_exprs: &[(Arc, String)], +) -> Vec> { + let column_mapping = proj_exprs + .iter() + .filter_map(|(expr, name)| { + expr.as_any() + .downcast_ref::() + .map(|column| (name.clone(), column.clone())) + }) + .collect::>(); + parent_required + .iter() + .filter_map(|r| { + r.as_any() + .downcast_ref::() + .and_then(|c| column_mapping.get(c.name())) + }) + .map(|e| Arc::new(e.clone()) as _) + .collect() +} + +/// This function returns all `Arc`s inside the given +/// `PhysicalSortExpr` sequence. +pub fn convert_to_expr>( + sequence: impl IntoIterator, +) -> Vec> { + sequence + .into_iter() + .map(|elem| elem.borrow().expr.clone()) + .collect() +} + +/// This function finds the indices of `targets` within `items` using strict +/// equality. +pub fn get_indices_of_exprs_strict>>( + targets: impl IntoIterator, + items: &[Arc], +) -> Vec { + targets + .into_iter() + .filter_map(|target| items.iter().position(|e| e.eq(target.borrow()))) + .collect() +} + +#[derive(Clone, Debug)] +pub struct ExprTreeNode { + expr: Arc, + data: Option, + child_nodes: Vec>, +} + +impl ExprTreeNode { + pub fn new(expr: Arc) -> Self { + let children = expr.children(); + ExprTreeNode { + expr, + data: None, + child_nodes: children.into_iter().map(Self::new).collect_vec(), + } + } + + pub fn expression(&self) -> &Arc { + &self.expr + } + + pub fn children(&self) -> &[ExprTreeNode] { + &self.child_nodes + } +} + +impl TreeNode for ExprTreeNode { + fn children_nodes(&self) -> Vec> { + self.children().iter().map(Cow::Borrowed).collect() + } + + fn map_children(mut self, transform: F) -> Result + where + F: FnMut(Self) -> Result, + { + self.child_nodes = self + .child_nodes + .into_iter() + .map(transform) + .collect::>>()?; + Ok(self) + } +} + +/// This struct facilitates the [TreeNodeRewriter] mechanism to convert a +/// [PhysicalExpr] tree into a DAEG (i.e. an expression DAG) by collecting +/// identical expressions in one node. Caller specifies the node type in the +/// DAEG via the `constructor` argument, which constructs nodes in the DAEG +/// from the [ExprTreeNode] ancillary object. +struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode) -> Result> { + // The resulting DAEG (expression DAG). + graph: StableGraph, + // A vector of visited expression nodes and their corresponding node indices. + visited_plans: Vec<(Arc, NodeIndex)>, + // A function to convert an input expression node to T. + constructor: &'a F, +} + +impl<'a, T, F: Fn(&ExprTreeNode) -> Result> TreeNodeRewriter + for PhysicalExprDAEGBuilder<'a, T, F> +{ + type N = ExprTreeNode; + // This method mutates an expression node by transforming it to a physical expression + // and adding it to the graph. The method returns the mutated expression node. + fn mutate( + &mut self, + mut node: ExprTreeNode, + ) -> Result> { + // Get the expression associated with the input expression node. + let expr = &node.expr; + + // Check if the expression has already been visited. + let node_idx = match self.visited_plans.iter().find(|(e, _)| expr.eq(e)) { + // If the expression has been visited, return the corresponding node index. + Some((_, idx)) => *idx, + // If the expression has not been visited, add a new node to the graph and + // add edges to its child nodes. Add the visited expression to the vector + // of visited expressions and return the newly created node index. + None => { + let node_idx = self.graph.add_node((self.constructor)(&node)?); + for expr_node in node.child_nodes.iter() { + self.graph.add_edge(node_idx, expr_node.data.unwrap(), 0); + } + self.visited_plans.push((expr.clone(), node_idx)); + node_idx + } + }; + // Set the data field of the input expression node to the corresponding node index. + node.data = Some(node_idx); + // Return the mutated expression node. + Ok(node) + } +} + +// A function that builds a directed acyclic graph of physical expression trees. +pub fn build_dag( + expr: Arc, + constructor: &F, +) -> Result<(NodeIndex, StableGraph)> +where + F: Fn(&ExprTreeNode) -> Result, +{ + // Create a new expression tree node from the input expression. + let init = ExprTreeNode::new(expr); + // Create a new `PhysicalExprDAEGBuilder` instance. + let mut builder = PhysicalExprDAEGBuilder { + graph: StableGraph::::new(), + visited_plans: Vec::<(Arc, NodeIndex)>::new(), + constructor, + }; + // Use the builder to transform the expression tree node into a DAG. + let root = init.rewrite(&mut builder)?; + // Return a tuple containing the root node index and the DAG. + Ok((root.data.unwrap(), builder.graph)) +} + +/// Recursively extract referenced [`Column`]s within a [`PhysicalExpr`]. +pub fn collect_columns(expr: &Arc) -> HashSet { + let mut columns = HashSet::::new(); + expr.apply(&mut |expr| { + if let Some(column) = expr.as_any().downcast_ref::() { + if !columns.iter().any(|c| c.eq(column)) { + columns.insert(column.clone()); + } + } + Ok(VisitRecursion::Continue) + }) + // pre_visit always returns OK, so this will always too + .expect("no way to return error during recursion"); + columns +} + +/// Re-assign column indices referenced in predicate according to given schema. +/// This may be helpful when dealing with projections. +pub fn reassign_predicate_columns( + pred: Arc, + schema: &SchemaRef, + ignore_not_found: bool, +) -> Result> { + pred.transform_down(&|expr| { + let expr_any = expr.as_any(); + + if let Some(column) = expr_any.downcast_ref::() { + let index = match schema.index_of(column.name()) { + Ok(idx) => idx, + Err(_) if ignore_not_found => usize::MAX, + Err(e) => return Err(e.into()), + }; + return Ok(Transformed::Yes(Arc::new(Column::new( + column.name(), + index, + )))); + } + Ok(Transformed::No(expr)) + }) +} + +/// Reverses the ORDER BY expression, which is useful during equivalent window +/// expression construction. For instance, 'ORDER BY a ASC, NULLS LAST' turns into +/// 'ORDER BY a DESC, NULLS FIRST'. +pub fn reverse_order_bys(order_bys: &[PhysicalSortExpr]) -> Vec { + order_bys + .iter() + .map(|e| PhysicalSortExpr { + expr: e.expr.clone(), + options: !e.options, + }) + .collect() +} + +/// Scatter `truthy` array by boolean mask. When the mask evaluates `true`, next values of `truthy` +/// are taken, when the mask evaluates `false` values null values are filled. +/// +/// # Arguments +/// * `mask` - Boolean values used to determine where to put the `truthy` values +/// * `truthy` - All values of this array are to scatter according to `mask` into final result. +pub fn scatter(mask: &BooleanArray, truthy: &dyn Array) -> Result { + let truthy = truthy.to_data(); + + // update the mask so that any null values become false + // (SlicesIterator doesn't respect nulls) + let mask = and_kleene(mask, &is_not_null(mask)?)?; + + let mut mutable = MutableArrayData::new(vec![&truthy], true, mask.len()); + + // the SlicesIterator slices only the true values. So the gaps left by this iterator we need to + // fill with falsy values + + // keep track of how much is filled + let mut filled = 0; + // keep track of current position we have in truthy array + let mut true_pos = 0; + + SlicesIterator::new(&mask).for_each(|(start, end)| { + // the gap needs to be filled with nulls + if start > filled { + mutable.extend_nulls(start - filled); + } + // fill with truthy values + let len = end - start; + mutable.extend(0, true_pos, true_pos + len); + true_pos += len; + filled = end; + }); + // the remaining part is falsy + if filled < mask.len() { + mutable.extend_nulls(mask.len() - filled); + } + + let data = mutable.freeze(); + Ok(make_array(data)) +} + +/// Merge left and right sort expressions, checking for duplicates. +pub fn merge_vectors( + left: &[PhysicalSortExpr], + right: &[PhysicalSortExpr], +) -> Vec { + left.iter() + .cloned() + .chain(right.iter().cloned()) + .unique() + .collect() +} + +#[cfg(test)] +mod tests { + use std::fmt::{Display, Formatter}; + use std::sync::Arc; + + use super::*; + use crate::expressions::{binary, cast, col, in_list, lit, Column, Literal}; + use crate::PhysicalSortExpr; + + use arrow_array::Int32Array; + use arrow_schema::{DataType, Field, Schema}; + use datafusion_common::cast::{as_boolean_array, as_int32_array}; + use datafusion_common::{Result, ScalarValue}; + + use petgraph::visit::Bfs; + + #[derive(Clone)] + struct DummyProperty { + expr_type: String, + } + + /// This is a dummy node in the DAEG; it stores a reference to the actual + /// [PhysicalExpr] as well as a dummy property. + #[derive(Clone)] + struct PhysicalExprDummyNode { + pub expr: Arc, + pub property: DummyProperty, + } + + impl Display for PhysicalExprDummyNode { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.expr) + } + } + + fn make_dummy_node(node: &ExprTreeNode) -> Result { + let expr = node.expression().clone(); + let dummy_property = if expr.as_any().is::() { + "Binary" + } else if expr.as_any().is::() { + "Column" + } else if expr.as_any().is::() { + "Literal" + } else { + "Other" + } + .to_owned(); + Ok(PhysicalExprDummyNode { + expr, + property: DummyProperty { + expr_type: dummy_property, + }, + }) + } + + #[test] + fn test_build_dag() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("0", DataType::Int32, true), + Field::new("1", DataType::Int32, true), + Field::new("2", DataType::Int32, true), + ]); + let expr = binary( + cast( + binary( + col("0", &schema)?, + Operator::Plus, + col("1", &schema)?, + &schema, + )?, + &schema, + DataType::Int64, + )?, + Operator::Gt, + binary( + cast(col("2", &schema)?, &schema, DataType::Int64)?, + Operator::Plus, + lit(ScalarValue::Int64(Some(10))), + &schema, + )?, + &schema, + )?; + let mut vector_dummy_props = vec![]; + let (root, graph) = build_dag(expr, &make_dummy_node)?; + let mut bfs = Bfs::new(&graph, root); + while let Some(node_index) = bfs.next(&graph) { + let node = &graph[node_index]; + vector_dummy_props.push(node.property.clone()); + } + + assert_eq!( + vector_dummy_props + .iter() + .filter(|property| property.expr_type == "Binary") + .count(), + 3 + ); + assert_eq!( + vector_dummy_props + .iter() + .filter(|property| property.expr_type == "Column") + .count(), + 3 + ); + assert_eq!( + vector_dummy_props + .iter() + .filter(|property| property.expr_type == "Literal") + .count(), + 1 + ); + assert_eq!( + vector_dummy_props + .iter() + .filter(|property| property.expr_type == "Other") + .count(), + 2 + ); + Ok(()) + } + + #[test] + fn test_convert_to_expr() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::UInt64, false)]); + let sort_expr = vec![PhysicalSortExpr { + expr: col("a", &schema)?, + options: Default::default(), + }]; + assert!(convert_to_expr(&sort_expr)[0].eq(&sort_expr[0].expr)); + Ok(()) + } + + #[test] + fn test_get_indices_of_exprs_strict() { + let list1: Vec> = vec![ + Arc::new(Column::new("a", 0)), + Arc::new(Column::new("b", 1)), + Arc::new(Column::new("c", 2)), + Arc::new(Column::new("d", 3)), + ]; + let list2: Vec> = vec![ + Arc::new(Column::new("b", 1)), + Arc::new(Column::new("c", 2)), + Arc::new(Column::new("a", 0)), + ]; + assert_eq!(get_indices_of_exprs_strict(&list1, &list2), vec![2, 0, 1]); + assert_eq!(get_indices_of_exprs_strict(&list2, &list1), vec![1, 2, 0]); + } + + #[test] + fn test_reassign_predicate_columns_in_list() { + let int_field = Field::new("should_not_matter", DataType::Int64, true); + let dict_field = Field::new( + "id", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + ); + let schema_small = Arc::new(Schema::new(vec![dict_field.clone()])); + let schema_big = Arc::new(Schema::new(vec![int_field, dict_field])); + let pred = in_list( + Arc::new(Column::new_with_schema("id", &schema_big).unwrap()), + vec![lit(ScalarValue::Dictionary( + Box::new(DataType::Int32), + Box::new(ScalarValue::from("2")), + ))], + &false, + &schema_big, + ) + .unwrap(); + + let actual = reassign_predicate_columns(pred, &schema_small, false).unwrap(); + + let expected = in_list( + Arc::new(Column::new_with_schema("id", &schema_small).unwrap()), + vec![lit(ScalarValue::Dictionary( + Box::new(DataType::Int32), + Box::new(ScalarValue::from("2")), + ))], + &false, + &schema_small, + ) + .unwrap(); + + assert_eq!(actual.as_ref(), expected.as_any()); + } + + #[test] + fn test_collect_columns() -> Result<()> { + let expr1 = Arc::new(Column::new("col1", 2)) as _; + let mut expected = HashSet::new(); + expected.insert(Column::new("col1", 2)); + assert_eq!(collect_columns(&expr1), expected); + + let expr2 = Arc::new(Column::new("col2", 5)) as _; + let mut expected = HashSet::new(); + expected.insert(Column::new("col2", 5)); + assert_eq!(collect_columns(&expr2), expected); + + let expr3 = Arc::new(BinaryExpr::new(expr1, Operator::Plus, expr2)) as _; + let mut expected = HashSet::new(); + expected.insert(Column::new("col1", 2)); + expected.insert(Column::new("col2", 5)); + assert_eq!(collect_columns(&expr3), expected); + Ok(()) + } + + #[test] + fn scatter_int() -> Result<()> { + let truthy = Arc::new(Int32Array::from(vec![1, 10, 11, 100])); + let mask = BooleanArray::from(vec![true, true, false, false, true]); + + // the output array is expected to be the same length as the mask array + let expected = + Int32Array::from_iter(vec![Some(1), Some(10), None, None, Some(11)]); + let result = scatter(&mask, truthy.as_ref())?; + let result = as_int32_array(&result)?; + + assert_eq!(&expected, result); + Ok(()) + } + + #[test] + fn scatter_int_end_with_false() -> Result<()> { + let truthy = Arc::new(Int32Array::from(vec![1, 10, 11, 100])); + let mask = BooleanArray::from(vec![true, false, true, false, false, false]); + + // output should be same length as mask + let expected = + Int32Array::from_iter(vec![Some(1), None, Some(10), None, None, None]); + let result = scatter(&mask, truthy.as_ref())?; + let result = as_int32_array(&result)?; + + assert_eq!(&expected, result); + Ok(()) + } + + #[test] + fn scatter_with_null_mask() -> Result<()> { + let truthy = Arc::new(Int32Array::from(vec![1, 10, 11])); + let mask: BooleanArray = vec![Some(false), None, Some(true), Some(true), None] + .into_iter() + .collect(); + + // output should treat nulls as though they are false + let expected = Int32Array::from_iter(vec![None, None, Some(1), Some(10), None]); + let result = scatter(&mask, truthy.as_ref())?; + let result = as_int32_array(&result)?; + + assert_eq!(&expected, result); + Ok(()) + } + + #[test] + fn scatter_boolean() -> Result<()> { + let truthy = Arc::new(BooleanArray::from(vec![false, false, false, true])); + let mask = BooleanArray::from(vec![true, true, false, false, true]); + + // the output array is expected to be the same length as the mask array + let expected = BooleanArray::from_iter(vec![ + Some(false), + Some(false), + None, + None, + Some(false), + ]); + let result = scatter(&mask, truthy.as_ref())?; + let result = as_boolean_array(&result)?; + + assert_eq!(&expected, result); + Ok(()) + } +} diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs index a00d32e201fb..665ceb70d658 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/built_in.rs @@ -21,22 +21,19 @@ use std::any::Any; use std::ops::Range; use std::sync::Arc; -use super::BuiltInWindowFunctionExpr; -use super::WindowExpr; -use crate::equivalence::OrderingEquivalenceBuilder; +use super::{BuiltInWindowFunctionExpr, WindowExpr}; use crate::expressions::PhysicalSortExpr; -use crate::utils::{convert_to_expr, get_indices_of_matching_exprs}; use crate::window::window_expr::{get_orderby_values, WindowFn}; use crate::window::{PartitionBatches, PartitionWindowAggStates, WindowState}; use crate::{reverse_order_bys, EquivalenceProperties, PhysicalExpr}; + use arrow::array::{new_empty_array, ArrayRef}; use arrow::compute::SortOptions; use arrow::datatypes::Field; use arrow::record_batch::RecordBatch; use datafusion_common::utils::evaluate_partition_ranges; use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::window_state::WindowAggState; -use datafusion_expr::window_state::WindowFrameContext; +use datafusion_expr::window_state::{WindowAggState, WindowFrameContext}; use datafusion_expr::WindowFrame; /// A window expr that takes the form of a [`BuiltInWindowFunctionExpr`]. @@ -75,16 +72,12 @@ impl BuiltInWindowExpr { /// If `self.expr` doesn't have an ordering, ordering equivalence properties /// are not updated. Otherwise, ordering equivalence properties are updated /// by the ordering of `self.expr`. - pub fn add_equal_orderings EquivalenceProperties>( - &self, - builder: &mut OrderingEquivalenceBuilder, - equal_properties: F, - ) { - let schema = builder.schema(); + pub fn add_equal_orderings(&self, eq_properties: &mut EquivalenceProperties) { + let schema = eq_properties.schema(); if let Some(fn_res_ordering) = self.expr.get_result_ordering(schema) { if self.partition_by.is_empty() { // In the absence of a PARTITION BY, ordering of `self.expr` is global: - builder.add_equal_conditions(vec![fn_res_ordering]); + eq_properties.add_new_orderings([vec![fn_res_ordering]]); } else { // If we have a PARTITION BY, built-in functions can not introduce // a global ordering unless the existing ordering is compatible @@ -92,23 +85,11 @@ impl BuiltInWindowExpr { // expressions and existing ordering expressions are equal (w.r.t. // set equality), we can prefix the ordering of `self.expr` with // the existing ordering. - let existing_ordering = builder.existing_ordering(); - let existing_ordering_exprs = convert_to_expr(existing_ordering); - // Get indices of the PARTITION BY expressions among input ordering expressions: - let pb_indices = get_indices_of_matching_exprs( - &self.partition_by, - &existing_ordering_exprs, - equal_properties, - ); - // Existing ordering should match exactly with PARTITION BY expressions. - // There should be no missing/extra entries in the existing ordering. - // Otherwise, prefixing wouldn't work. - if pb_indices.len() == self.partition_by.len() - && pb_indices.len() == existing_ordering.len() - { - let mut new_ordering = existing_ordering.to_vec(); - new_ordering.push(fn_res_ordering); - builder.add_equal_conditions(new_ordering); + let (mut ordering, _) = + eq_properties.find_longest_permutation(&self.partition_by); + if ordering.len() == self.partition_by.len() { + ordering.push(fn_res_ordering); + eq_properties.add_new_orderings([ordering]); } } } diff --git a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs index 1a060817d2e1..7aa4f6536a6e 100644 --- a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs +++ b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs @@ -37,7 +37,7 @@ use std::sync::Arc; /// `nth_value` need the value. #[allow(rustdoc::private_intra_doc_links)] pub trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug { - /// Returns the aggregate expression as [`Any`](std::any::Any) so that it can be + /// Returns the aggregate expression as [`Any`] so that it can be /// downcast to a specific implementation. fn as_any(&self) -> &dyn Any; @@ -60,8 +60,10 @@ pub trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug { fn evaluate_args(&self, batch: &RecordBatch) -> Result> { self.expressions() .iter() - .map(|e| e.evaluate(batch)) - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .map(|e| { + e.evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) .collect() } diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs index f55f1600b9ca..7ee736ce9caa 100644 --- a/datafusion/physical-expr/src/window/lead_lag.rs +++ b/datafusion/physical-expr/src/window/lead_lag.rs @@ -23,7 +23,7 @@ use crate::PhysicalExpr; use arrow::array::ArrayRef; use arrow::compute::cast; use arrow::datatypes::{DataType, Field}; -use datafusion_common::ScalarValue; +use datafusion_common::{arrow_datafusion_err, ScalarValue}; use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_expr::PartitionEvaluator; use std::any::Any; @@ -139,9 +139,10 @@ fn create_empty_array( let array = value .as_ref() .map(|scalar| scalar.to_array_of_size(size)) + .transpose()? .unwrap_or_else(|| new_null_array(data_type, size)); if array.data_type() != data_type { - cast(&array, data_type).map_err(DataFusionError::ArrowError) + cast(&array, data_type).map_err(|e| arrow_datafusion_err!(e)) } else { Ok(array) } @@ -171,10 +172,10 @@ fn shift_with_default_value( // Concatenate both arrays, add nulls after if shift > 0 else before if offset > 0 { concat(&[default_values.as_ref(), slice.as_ref()]) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) } else { concat(&[slice.as_ref(), default_values.as_ref()]) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) } } } diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs index 262a50969b82..b3c89122ebad 100644 --- a/datafusion/physical-expr/src/window/nth_value.rs +++ b/datafusion/physical-expr/src/window/nth_value.rs @@ -15,21 +15,24 @@ // specific language governing permissions and limitations // under the License. -//! Defines physical expressions for `first_value`, `last_value`, and `nth_value` -//! that can evaluated at runtime during query execution +//! Defines physical expressions for `FIRST_VALUE`, `LAST_VALUE`, and `NTH_VALUE` +//! functions that can be evaluated at run time during query execution. + +use std::any::Any; +use std::cmp::Ordering; +use std::ops::Range; +use std::sync::Arc; use crate::window::window_expr::{NthValueKind, NthValueState}; use crate::window::BuiltInWindowFunctionExpr; use crate::PhysicalExpr; + use arrow::array::{Array, ArrayRef}; use arrow::datatypes::{DataType, Field}; use datafusion_common::{exec_err, ScalarValue}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::window_state::WindowAggState; use datafusion_expr::PartitionEvaluator; -use std::any::Any; -use std::ops::Range; -use std::sync::Arc; /// nth_value expression #[derive(Debug)] @@ -77,17 +80,17 @@ impl NthValue { n: u32, ) -> Result { match n { - 0 => exec_err!("nth_value expect n to be > 0"), + 0 => exec_err!("NTH_VALUE expects n to be non-zero"), _ => Ok(Self { name: name.into(), expr, data_type, - kind: NthValueKind::Nth(n), + kind: NthValueKind::Nth(n as i64), }), } } - /// Get nth_value kind + /// Get the NTH_VALUE kind pub fn get_kind(&self) -> NthValueKind { self.kind } @@ -125,7 +128,7 @@ impl BuiltInWindowFunctionExpr for NthValue { let reversed_kind = match self.kind { NthValueKind::First => NthValueKind::Last, NthValueKind::Last => NthValueKind::First, - NthValueKind::Nth(_) => return None, + NthValueKind::Nth(idx) => NthValueKind::Nth(-idx), }; Some(Arc::new(Self { name: self.name.clone(), @@ -143,16 +146,17 @@ pub(crate) struct NthValueEvaluator { } impl PartitionEvaluator for NthValueEvaluator { - /// When the window frame has a fixed beginning (e.g UNBOUNDED - /// PRECEDING), for some functions such as FIRST_VALUE, LAST_VALUE and - /// NTH_VALUE we can memoize result. Once result is calculated it - /// will always stay same. Hence, we do not need to keep past data - /// as we process the entire dataset. This feature enables us to - /// prune rows from table. The default implementation does nothing + /// When the window frame has a fixed beginning (e.g UNBOUNDED PRECEDING), + /// for some functions such as FIRST_VALUE, LAST_VALUE and NTH_VALUE, we + /// can memoize the result. Once result is calculated, it will always stay + /// same. Hence, we do not need to keep past data as we process the entire + /// dataset. fn memoize(&mut self, state: &mut WindowAggState) -> Result<()> { let out = &state.out_col; let size = out.len(); - let (is_prunable, is_last) = match self.state.kind { + let mut buffer_size = 1; + // Decide if we arrived at a final result yet: + let (is_prunable, is_reverse_direction) = match self.state.kind { NthValueKind::First => { let n_range = state.window_frame_range.end - state.window_frame_range.start; @@ -162,16 +166,30 @@ impl PartitionEvaluator for NthValueEvaluator { NthValueKind::Nth(n) => { let n_range = state.window_frame_range.end - state.window_frame_range.start; - (n_range >= (n as usize) && size >= (n as usize), false) + match n.cmp(&0) { + Ordering::Greater => { + (n_range >= (n as usize) && size > (n as usize), false) + } + Ordering::Less => { + let reverse_index = (-n) as usize; + buffer_size = reverse_index; + // Negative index represents reverse direction. + (n_range >= reverse_index, true) + } + Ordering::Equal => { + // The case n = 0 is not valid for the NTH_VALUE function. + unreachable!(); + } + } } }; if is_prunable { - if self.state.finalized_result.is_none() && !is_last { + if self.state.finalized_result.is_none() && !is_reverse_direction { let result = ScalarValue::try_from_array(out, size - 1)?; self.state.finalized_result = Some(result); } state.window_frame_range.start = - state.window_frame_range.end.saturating_sub(1); + state.window_frame_range.end.saturating_sub(buffer_size); } Ok(()) } @@ -195,12 +213,33 @@ impl PartitionEvaluator for NthValueEvaluator { NthValueKind::First => ScalarValue::try_from_array(arr, range.start), NthValueKind::Last => ScalarValue::try_from_array(arr, range.end - 1), NthValueKind::Nth(n) => { - // We are certain that n > 0. - let index = (n as usize) - 1; - if index >= n_range { - ScalarValue::try_from(arr.data_type()) - } else { - ScalarValue::try_from_array(arr, range.start + index) + match n.cmp(&0) { + Ordering::Greater => { + // SQL indices are not 0-based. + let index = (n as usize) - 1; + if index >= n_range { + // Outside the range, return NULL: + ScalarValue::try_from(arr.data_type()) + } else { + ScalarValue::try_from_array(arr, range.start + index) + } + } + Ordering::Less => { + let reverse_index = (-n) as usize; + if n_range >= reverse_index { + ScalarValue::try_from_array( + arr, + range.start + n_range - reverse_index, + ) + } else { + // Outside the range, return NULL: + ScalarValue::try_from(arr.data_type()) + } + } + Ordering::Equal => { + // The case n = 0 is not valid for the NTH_VALUE function. + unreachable!(); + } } } } diff --git a/datafusion/physical-expr/src/window/ntile.rs b/datafusion/physical-expr/src/window/ntile.rs index 49aac0877ab3..f5442e1b0fee 100644 --- a/datafusion/physical-expr/src/window/ntile.rs +++ b/datafusion/physical-expr/src/window/ntile.rs @@ -96,8 +96,9 @@ impl PartitionEvaluator for NtileEvaluator { ) -> Result { let num_rows = num_rows as u64; let mut vec: Vec = Vec::new(); + let n = u64::min(self.n, num_rows); for i in 0..num_rows { - let res = i * self.n / num_rows; + let res = i * n / num_rows; vec.push(res + 1) } Ok(Arc::new(UInt64Array::from(vec))) diff --git a/datafusion/physical-expr/src/window/rank.rs b/datafusion/physical-expr/src/window/rank.rs index 9bc36728f46e..86af5b322133 100644 --- a/datafusion/physical-expr/src/window/rank.rs +++ b/datafusion/physical-expr/src/window/rank.rs @@ -141,9 +141,16 @@ impl PartitionEvaluator for RankEvaluator { // There is no argument, values are order by column values (where rank is calculated) let range_columns = values; let last_rank_data = get_row_at_idx(range_columns, row_idx)?; - let empty = self.state.last_rank_data.is_empty(); - if empty || self.state.last_rank_data != last_rank_data { - self.state.last_rank_data = last_rank_data; + let new_rank_encountered = + if let Some(state_last_rank_data) = &self.state.last_rank_data { + // if rank data changes, new rank is encountered + state_last_rank_data != &last_rank_data + } else { + // First rank seen + true + }; + if new_rank_encountered { + self.state.last_rank_data = Some(last_rank_data); self.state.last_rank_boundary += self.state.current_group_count; self.state.current_group_count = 1; self.state.n_rank += 1; diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index da67bcabee0b..548fae75bd97 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -15,7 +15,13 @@ // specific language governing permissions and limitations // under the License. +use std::any::Any; +use std::fmt::Debug; +use std::ops::Range; +use std::sync::Arc; + use crate::{PhysicalExpr, PhysicalSortExpr}; + use arrow::array::{new_empty_array, Array, ArrayRef}; use arrow::compute::kernels::sort::SortColumn; use arrow::compute::SortOptions; @@ -25,13 +31,9 @@ use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::window_state::{ PartitionBatchState, WindowAggState, WindowFrameContext, }; -use datafusion_expr::PartitionEvaluator; -use datafusion_expr::{Accumulator, WindowFrame}; +use datafusion_expr::{Accumulator, PartitionEvaluator, WindowFrame}; + use indexmap::IndexMap; -use std::any::Any; -use std::fmt::Debug; -use std::ops::Range; -use std::sync::Arc; /// Common trait for [window function] implementations /// @@ -59,7 +61,7 @@ use std::sync::Arc; /// [`PlainAggregateWindowExpr`]: crate::window::PlainAggregateWindowExpr /// [`SlidingAggregateWindowExpr`]: crate::window::SlidingAggregateWindowExpr pub trait WindowExpr: Send + Sync + Debug { - /// Returns the window expression as [`Any`](std::any::Any) so that it can be + /// Returns the window expression as [`Any`] so that it can be /// downcast to a specific implementation. fn as_any(&self) -> &dyn Any; @@ -82,8 +84,10 @@ pub trait WindowExpr: Send + Sync + Debug { fn evaluate_args(&self, batch: &RecordBatch) -> Result> { self.expressions() .iter() - .map(|e| e.evaluate(batch)) - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .map(|e| { + e.evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) .collect() } @@ -270,7 +274,7 @@ pub enum WindowFn { #[derive(Debug, Clone, Default)] pub struct RankState { /// The last values for rank as these values change, we increase n_rank - pub last_rank_data: Vec, + pub last_rank_data: Option>, /// The index where last_rank_boundary is started pub last_rank_boundary: usize, /// Keep the number of entries in current rank @@ -290,7 +294,7 @@ pub struct NumRowsState { pub enum NthValueKind { First, Last, - Nth(u32), + Nth(i64), } #[derive(Debug, Clone)] diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index ace6f5d95483..6c761fc9687c 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -19,9 +19,9 @@ name = "datafusion-physical-plan" description = "Physical (ExecutionPlan) implementations for DataFusion query engine" keywords = ["arrow", "query", "sql"] +readme = "README.md" version = { workspace = true } edition = { workspace = true } -readme = { workspace = true } homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } @@ -38,26 +38,26 @@ arrow = { workspace = true } arrow-array = { workspace = true } arrow-buffer = { workspace = true } arrow-schema = { workspace = true } -async-trait = "0.1.41" +async-trait = { workspace = true } chrono = { version = "0.4.23", default-features = false } -datafusion-common = { path = "../common", version = "31.0.0", default-features = false } -datafusion-execution = { path = "../execution", version = "31.0.0" } -datafusion-expr = { path = "../expr", version = "31.0.0" } -datafusion-physical-expr = { path = "../physical-expr", version = "31.0.0" } -futures = "0.3" +datafusion-common = { workspace = true } +datafusion-execution = { workspace = true } +datafusion-expr = { workspace = true } +datafusion-physical-expr = { workspace = true } +futures = { workspace = true } half = { version = "2.1", default-features = false } hashbrown = { version = "0.14", features = ["raw"] } -indexmap = "2.0.0" -itertools = { version = "0.11", features = ["use_std"] } -log = "^0.4" +indexmap = { workspace = true } +itertools = { version = "0.12", features = ["use_std"] } +log = { workspace = true } once_cell = "1.18.0" -parking_lot = "0.12" +parking_lot = { workspace = true } pin-project-lite = "^0.2.7" -rand = "0.8" +rand = { workspace = true } tokio = { version = "1.28", features = ["sync", "fs", "parking_lot"] } uuid = { version = "^1.2", features = ["v4"] } [dev-dependencies] -rstest = "0.18.0" +rstest = { workspace = true } termtree = "0.4.1" tokio = { version = "1.28", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] } diff --git a/datafusion/physical-plan/README.md b/datafusion/physical-plan/README.md new file mode 100644 index 000000000000..366a6b555150 --- /dev/null +++ b/datafusion/physical-plan/README.md @@ -0,0 +1,27 @@ + + +# DataFusion Common + +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. + +This crate is a submodule of DataFusion that contains the `ExecutionPlan` trait and the various implementations of that +trait for built in operators such as filters, projections, joins, aggregations, etc. + +[df]: https://crates.io/crates/datafusion diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index 746537557d46..10ff9edb8912 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -22,9 +22,9 @@ use arrow::record_batch::RecordBatch; use arrow::row::{RowConverter, Rows, SortField}; use arrow_array::{Array, ArrayRef}; use arrow_schema::{DataType, SchemaRef}; +use datafusion_common::hash_utils::create_hashes; use datafusion_common::{DataFusionError, Result}; use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt}; -use datafusion_physical_expr::hash_utils::create_hashes; use datafusion_physical_expr::EmitTo; use hashbrown::raw::RawTable; diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index b6e4b0a44dec..0b94dd01cfd4 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -17,35 +17,38 @@ //! Aggregates functionalities +use std::any::Any; +use std::sync::Arc; + +use super::DisplayAs; use crate::aggregates::{ no_grouping::AggregateStream, row_hash::GroupedHashAggregateStream, + topk_stream::GroupedTopKAggregateStream, }; + use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; +use crate::windows::get_ordered_partition_by_indices; use crate::{ - DisplayFormatType, Distribution, EquivalenceProperties, ExecutionPlan, Partitioning, + DisplayFormatType, Distribution, ExecutionPlan, InputOrderMode, Partitioning, SendableRecordBatchStream, Statistics, }; use arrow::array::ArrayRef; use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; -use datafusion_common::utils::longest_consecutive_prefix; +use datafusion_common::stats::Precision; use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result}; use datafusion_execution::TaskContext; use datafusion_expr::Accumulator; use datafusion_physical_expr::{ - equivalence::project_equivalence_properties, - expressions::Column, - normalize_out_expr_with_columns_map, physical_exprs_contains, reverse_order_bys, - utils::{convert_to_expr, get_indices_of_matching_exprs}, - AggregateExpr, LexOrdering, LexOrderingReq, OrderingEquivalenceProperties, - PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, + aggregate::is_order_sensitive, + equivalence::{collapse_lex_req, ProjectionMapping}, + expressions::{Column, FirstValue, LastValue, Max, Min, UnKnownColumn}, + physical_exprs_contains, reverse_order_bys, AggregateExpr, EquivalenceProperties, + LexOrdering, LexRequirement, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, }; use itertools::Itertools; -use std::any::Any; -use std::collections::HashMap; -use std::sync::Arc; mod group_values; mod no_grouping; @@ -54,16 +57,8 @@ mod row_hash; mod topk; mod topk_stream; -use crate::aggregates::topk_stream::GroupedTopKAggregateStream; pub use datafusion_expr::AggregateFunction; -use datafusion_physical_expr::aggregate::is_order_sensitive; pub use datafusion_physical_expr::expressions::create_aggregate_expr; -use datafusion_physical_expr::expressions::{Max, Min}; -use datafusion_physical_expr::utils::{ - get_finer_ordering, ordering_satisfy_requirement_concrete, -}; - -use super::DisplayAs; /// Hash aggregate modes #[derive(Debug, Copy, Clone, PartialEq, Eq)] @@ -105,34 +100,6 @@ impl AggregateMode { } } -/// Group By expression modes -/// -/// `PartiallyOrdered` and `FullyOrdered` are used to reason about -/// when certain group by keys will never again be seen (and thus can -/// be emitted by the grouping operator). -/// -/// Specifically, each distinct combination of the relevant columns -/// are contiguous in the input, and once a new combination is seen -/// previous combinations are guaranteed never to appear again -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum GroupByOrderMode { - /// The input is known to be ordered by a preset (prefix but - /// possibly reordered) of the expressions in the `GROUP BY` clause. - /// - /// For example, if the input is ordered by `a, b, c` and we group - /// by `b, a, d`, `PartiallyOrdered` means a subset of group `b, - /// a, d` defines a preset for the existing ordering, in this case - /// `a, b`. - PartiallyOrdered, - /// The input is known to be ordered by *all* the expressions in the - /// `GROUP BY` clause. - /// - /// For example, if the input is ordered by `a, b, c, d` and we group by b, a, - /// `Ordered` means that all of the of group by expressions appear - /// as a preset for the existing ordering, in this case `a, b`. - FullyOrdered, -} - /// Represents `GROUP BY` clause in the plan (including the more general GROUPING SET) /// In the case of a simple `GROUP BY a, b` clause, this will contain the expression [a, b] /// and a single group [false, false]. @@ -140,6 +107,7 @@ pub enum GroupByOrderMode { /// into multiple groups, using null expressions to align each group. /// For example, with a group by clause `GROUP BY GROUPING SET ((a,b),(a),(b))` the planner should /// create a `PhysicalGroupBy` like +/// ```text /// PhysicalGroupBy { /// expr: [(col(a), a), (col(b), b)], /// null_expr: [(NULL, a), (NULL, b)], @@ -149,6 +117,7 @@ pub enum GroupByOrderMode { /// [true, false] // (b) <=> (NULL, b) /// ] /// } +/// ``` #[derive(Clone, Debug, Default)] pub struct PhysicalGroupBy { /// Distinct (Physical Expr, Alias) in the grouping set @@ -216,6 +185,23 @@ impl PhysicalGroupBy { pub fn is_single(&self) -> bool { self.null_expr.is_empty() } + + /// Calculate GROUP BY expressions according to input schema. + pub fn input_exprs(&self) -> Vec> { + self.expr + .iter() + .map(|(expr, _alias)| expr.clone()) + .collect() + } + + /// Return grouping expressions as they occur in the output schema. + pub fn output_exprs(&self) -> Vec> { + self.expr + .iter() + .enumerate() + .map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _) + .collect() + } } impl PartialEq for PhysicalGroupBy { @@ -252,18 +238,6 @@ impl From for SendableRecordBatchStream { } } -/// This object encapsulates ordering-related information on GROUP BY columns. -#[derive(Debug, Clone)] -pub(crate) struct AggregationOrdering { - /// Specifies whether the GROUP BY columns are partially or fully ordered. - mode: GroupByOrderMode, - /// Stores indices such that when we iterate with these indices, GROUP BY - /// expressions match input ordering. - order_indices: Vec, - /// Actual ordering information of the GROUP BY columns. - ordering: LexOrdering, -} - /// Hash aggregate execution plan #[derive(Debug)] pub struct AggregateExec { @@ -275,8 +249,6 @@ pub struct AggregateExec { aggr_expr: Vec>, /// FILTER (WHERE clause) expression for each aggregate expression filter_expr: Vec>>, - /// (ORDER BY clause) expression for each aggregate expression - order_by_expr: Vec>, /// Set if the output of this aggregation is truncated by a upstream sort/limit clause limit: Option, /// Input plan, could be a partial aggregate or the input to the aggregate @@ -285,321 +257,20 @@ pub struct AggregateExec { schema: SchemaRef, /// Input schema before any aggregation is applied. For partial aggregate this will be the /// same as input.schema() but for the final aggregate it will be the same as the input - /// to the partial aggregate + /// to the partial aggregate, i.e., partial and final aggregates have same `input_schema`. + /// We need the input schema of partial aggregate to be able to deserialize aggregate + /// expressions from protobuf for final aggregate. pub input_schema: SchemaRef, - /// The columns map used to normalize out expressions like Partitioning and PhysicalSortExpr - /// The key is the column from the input schema and the values are the columns from the output schema - columns_map: HashMap>, - /// Execution Metrics + /// The mapping used to normalize expressions like Partitioning and + /// PhysicalSortExpr that maps input to output + projection_mapping: ProjectionMapping, + /// Execution metrics metrics: ExecutionPlanMetricsSet, - /// Stores mode and output ordering information for the `AggregateExec`. - aggregation_ordering: Option, - required_input_ordering: Option, -} - -/// Calculates the working mode for `GROUP BY` queries. -/// - If no GROUP BY expression has an ordering, returns `None`. -/// - If some GROUP BY expressions have an ordering, returns `Some(GroupByOrderMode::PartiallyOrdered)`. -/// - If all GROUP BY expressions have orderings, returns `Some(GroupByOrderMode::Ordered)`. -fn get_working_mode( - input: &Arc, - group_by: &PhysicalGroupBy, -) -> Option<(GroupByOrderMode, Vec)> { - if !group_by.is_single() { - // We do not currently support streaming execution if we have more - // than one group (e.g. we have grouping sets). - return None; - }; - - let output_ordering = input.output_ordering().unwrap_or(&[]); - // Since direction of the ordering is not important for GROUP BY columns, - // we convert PhysicalSortExpr to PhysicalExpr in the existing ordering. - let ordering_exprs = convert_to_expr(output_ordering); - let groupby_exprs = group_by - .expr - .iter() - .map(|(item, _)| item.clone()) - .collect::>(); - // Find where each expression of the GROUP BY clause occurs in the existing - // ordering (if it occurs): - let mut ordered_indices = - get_indices_of_matching_exprs(&groupby_exprs, &ordering_exprs, || { - input.equivalence_properties() - }); - ordered_indices.sort(); - // Find out how many expressions of the existing ordering define ordering - // for expressions in the GROUP BY clause. For example, if the input is - // ordered by a, b, c, d and we group by b, a, d; the result below would be. - // 2, meaning 2 elements (a, b) among the GROUP BY columns define ordering. - let first_n = longest_consecutive_prefix(ordered_indices); - if first_n == 0 { - // No GROUP by columns are ordered, we can not do streaming execution. - return None; - } - let ordered_exprs = ordering_exprs[0..first_n].to_vec(); - // Find indices for the GROUP BY expressions such that when we iterate with - // these indices, we would match existing ordering. For the example above, - // this would produce 1, 0; meaning 1st and 0th entries (a, b) among the - // GROUP BY expressions b, a, d match input ordering. - let ordered_group_by_indices = - get_indices_of_matching_exprs(&ordered_exprs, &groupby_exprs, || { - input.equivalence_properties() - }); - Some(if first_n == group_by.expr.len() { - (GroupByOrderMode::FullyOrdered, ordered_group_by_indices) - } else { - (GroupByOrderMode::PartiallyOrdered, ordered_group_by_indices) - }) -} - -/// This function gathers the ordering information for the GROUP BY columns. -fn calc_aggregation_ordering( - input: &Arc, - group_by: &PhysicalGroupBy, -) -> Option { - get_working_mode(input, group_by).map(|(mode, order_indices)| { - let existing_ordering = input.output_ordering().unwrap_or(&[]); - let out_group_expr = output_group_expr_helper(group_by); - // Calculate output ordering information for the operator: - let out_ordering = order_indices - .iter() - .zip(existing_ordering) - .map(|(idx, input_col)| PhysicalSortExpr { - expr: out_group_expr[*idx].clone(), - options: input_col.options, - }) - .collect::>(); - AggregationOrdering { - mode, - order_indices, - ordering: out_ordering, - } - }) -} - -/// This function returns grouping expressions as they occur in the output schema. -fn output_group_expr_helper(group_by: &PhysicalGroupBy) -> Vec> { - // Update column indices. Since the group by columns come first in the output schema, their - // indices are simply 0..self.group_expr(len). - group_by - .expr() - .iter() - .enumerate() - .map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _) - .collect() -} - -/// This function returns the ordering requirement of the first non-reversible -/// order-sensitive aggregate function such as ARRAY_AGG. This requirement serves -/// as the initial requirement while calculating the finest requirement among all -/// aggregate functions. If this function returns `None`, it means there is no -/// hard ordering requirement for the aggregate functions (in terms of direction). -/// Then, we can generate two alternative requirements with opposite directions. -fn get_init_req( - aggr_expr: &[Arc], - order_by_expr: &[Option], -) -> Option { - for (aggr_expr, fn_reqs) in aggr_expr.iter().zip(order_by_expr.iter()) { - // If the aggregation function is a non-reversible order-sensitive function - // and there is a hard requirement, choose first such requirement: - if is_order_sensitive(aggr_expr) - && aggr_expr.reverse_expr().is_none() - && fn_reqs.is_some() - { - return fn_reqs.clone(); - } - } - None -} - -/// This function gets the finest ordering requirement among all the aggregation -/// functions. If requirements are conflicting, (i.e. we can not compute the -/// aggregations in a single [`AggregateExec`]), the function returns an error. -fn get_finest_requirement< - F: Fn() -> EquivalenceProperties, - F2: Fn() -> OrderingEquivalenceProperties, ->( - aggr_expr: &mut [Arc], - order_by_expr: &mut [Option], - eq_properties: F, - ordering_eq_properties: F2, -) -> Result> { - let mut finest_req = get_init_req(aggr_expr, order_by_expr); - for (aggr_expr, fn_req) in aggr_expr.iter_mut().zip(order_by_expr.iter_mut()) { - let fn_req = if let Some(fn_req) = fn_req { - fn_req - } else { - continue; - }; - if let Some(finest_req) = &mut finest_req { - if let Some(finer) = get_finer_ordering( - finest_req, - fn_req, - &eq_properties, - &ordering_eq_properties, - ) { - *finest_req = finer.to_vec(); - continue; - } - // If an aggregate function is reversible, analyze whether its reverse - // direction is compatible with existing requirements: - if let Some(reverse) = aggr_expr.reverse_expr() { - let fn_req_reverse = reverse_order_bys(fn_req); - if let Some(finer) = get_finer_ordering( - finest_req, - &fn_req_reverse, - &eq_properties, - &ordering_eq_properties, - ) { - // We need to update `aggr_expr` with its reverse, since only its - // reverse requirement is compatible with existing requirements: - *aggr_expr = reverse; - *finest_req = finer.to_vec(); - *fn_req = fn_req_reverse; - continue; - } - } - // If neither of the requirements satisfy the other, this means - // requirements are conflicting. Currently, we do not support - // conflicting requirements. - return not_impl_err!( - "Conflicting ordering requirements in aggregate functions is not supported" - ); - } else { - finest_req = Some(fn_req.clone()); - } - } - Ok(finest_req) -} - -/// Calculate the required input ordering for the [`AggregateExec`] by considering -/// ordering requirements of order-sensitive aggregation functions. -fn calc_required_input_ordering( - input: &Arc, - aggr_exprs: &mut [Arc], - order_by_exprs: &mut [Option], - aggregator_reqs: LexOrderingReq, - aggregator_reverse_reqs: Option, - aggregation_ordering: &mut Option, - mode: &AggregateMode, -) -> Result> { - let mut required_input_ordering = vec![]; - // Boolean shows that whether `required_input_ordering` stored comes from - // `aggregator_reqs` or `aggregator_reverse_reqs` - let mut reverse_req = false; - // If reverse aggregator is None, there is no way to run aggregators in reverse mode. Hence ignore it during analysis - let aggregator_requirements = - if let Some(aggregator_reverse_reqs) = aggregator_reverse_reqs { - // If existing ordering doesn't satisfy requirement, we should do calculations - // on naive requirement (by convention, otherwise the final plan will be unintuitive), - // even if reverse ordering is possible. - // Hence, while iterating consider naive requirement last, by this way - // we prioritize naive requirement over reverse requirement, when - // reverse requirement is not helpful with removing SortExec from the plan. - vec![(true, aggregator_reverse_reqs), (false, aggregator_reqs)] - } else { - vec![(false, aggregator_reqs)] - }; - for (is_reverse, aggregator_requirement) in aggregator_requirements.into_iter() { - if let Some(AggregationOrdering { - // If the mode is FullyOrdered or PartiallyOrdered (i.e. we are - // running with bounded memory, without breaking the pipeline), - // then we append the aggregator ordering requirement to the existing - // ordering. This way, we can still run with bounded memory. - mode: GroupByOrderMode::FullyOrdered | GroupByOrderMode::PartiallyOrdered, - order_indices, - .. - }) = aggregation_ordering - { - // Get the section of the input ordering that enables us to run in - // FullyOrdered or PartiallyOrdered modes: - let requirement_prefix = - if let Some(existing_ordering) = input.output_ordering() { - &existing_ordering[0..order_indices.len()] - } else { - &[] - }; - let mut requirement = - PhysicalSortRequirement::from_sort_exprs(requirement_prefix.iter()); - for req in aggregator_requirement { - // Final and FinalPartitioned modes don't enforce ordering - // requirements since order-sensitive aggregators handle such - // requirements during merging. - if mode.is_first_stage() - && requirement.iter().all(|item| req.expr.ne(&item.expr)) - { - requirement.push(req); - } - } - required_input_ordering = requirement; - } else if mode.is_first_stage() { - required_input_ordering = aggregator_requirement; - } - // Keep track of the direction from which required_input_ordering is constructed: - reverse_req = is_reverse; - // If all the order-sensitive aggregate functions are reversible (e.g. all the - // order-sensitive aggregators are either FIRST_VALUE or LAST_VALUE), then we can - // run aggregate expressions either in the given required ordering, (i.e. finest - // requirement that satisfies every aggregate function requirement) or its reverse - // (opposite) direction. We analyze these two possibilities, and use the version that - // satisfies existing ordering. This enables us to avoid an extra sort step in the final - // plan. If neither version satisfies the existing ordering, we use the given ordering - // requirement. In short, if running aggregators in reverse order help us to avoid a - // sorting step, we do so. Otherwise, we use the aggregators as is. - let existing_ordering = input.output_ordering().unwrap_or(&[]); - if ordering_satisfy_requirement_concrete( - existing_ordering, - &required_input_ordering, - || input.equivalence_properties(), - || input.ordering_equivalence_properties(), - ) { - break; - } - } - // If `required_input_ordering` is constructed using the reverse requirement, we - // should reverse each `aggr_expr` in order to correctly calculate their results - // in reverse order. - if reverse_req { - aggr_exprs - .iter_mut() - .zip(order_by_exprs.iter_mut()) - .map(|(aggr_expr, ob_expr)| { - if is_order_sensitive(aggr_expr) { - if let Some(reverse) = aggr_expr.reverse_expr() { - *aggr_expr = reverse; - *ob_expr = ob_expr.as_ref().map(|obs| reverse_order_bys(obs)); - } else { - return plan_err!( - "Aggregate expression should have a reverse expression" - ); - } - } - Ok(()) - }) - .collect::>>()?; - } - Ok((!required_input_ordering.is_empty()).then_some(required_input_ordering)) -} - -/// Check whether group by expression contains all of the expression inside `requirement` -// As an example Group By (c,b,a) contains all of the expressions in the `requirement`: (a ASC, b DESC) -fn group_by_contains_all_requirements( - group_by: &PhysicalGroupBy, - requirement: &LexOrdering, -) -> bool { - let physical_exprs = group_by - .expr() - .iter() - .map(|(expr, _alias)| expr.clone()) - .collect::>(); - // When we have multiple groups (grouping set) - // since group by may be calculated on the subset of the group_by.expr() - // it is not guaranteed to have all of the requirements among group by expressions. - // Hence do the analysis: whether group by contains all requirements in the single group case. - group_by.is_single() - && requirement - .iter() - .all(|req| physical_exprs_contains(&physical_exprs, &req.expr)) + required_input_ordering: Option, + /// Describes how the input is ordered relative to the group by columns + input_order_mode: InputOrderMode, + /// Describe how the output is ordered + output_ordering: Option, } impl AggregateExec { @@ -607,10 +278,8 @@ impl AggregateExec { pub fn try_new( mode: AggregateMode, group_by: PhysicalGroupBy, - mut aggr_expr: Vec>, + aggr_expr: Vec>, filter_expr: Vec>>, - // Ordering requirement of each aggregate expression - mut order_by_expr: Vec>, input: Arc, input_schema: SchemaRef, ) -> Result { @@ -623,94 +292,94 @@ impl AggregateExec { )?; let schema = Arc::new(schema); - // Reset ordering requirement to `None` if aggregator is not order-sensitive - order_by_expr = aggr_expr + AggregateExec::try_new_with_schema( + mode, + group_by, + aggr_expr, + filter_expr, + input, + input_schema, + schema, + ) + } + + /// Create a new hash aggregate execution plan with the given schema. + /// This constructor isn't part of the public API, it is used internally + /// by Datafusion to enforce schema consistency during when re-creating + /// `AggregateExec`s inside optimization rules. Schema field names of an + /// `AggregateExec` depends on the names of aggregate expressions. Since + /// a rule may re-write aggregate expressions (e.g. reverse them) during + /// initialization, field names may change inadvertently if one re-creates + /// the schema in such cases. + #[allow(clippy::too_many_arguments)] + fn try_new_with_schema( + mode: AggregateMode, + group_by: PhysicalGroupBy, + mut aggr_expr: Vec>, + filter_expr: Vec>>, + input: Arc, + input_schema: SchemaRef, + schema: SchemaRef, + ) -> Result { + let input_eq_properties = input.equivalence_properties(); + // Get GROUP BY expressions: + let groupby_exprs = group_by.input_exprs(); + // If existing ordering satisfies a prefix of the GROUP BY expressions, + // prefix requirements with this section. In this case, aggregation will + // work more efficiently. + let indices = get_ordered_partition_by_indices(&groupby_exprs, &input); + let mut new_requirement = indices .iter() - .zip(order_by_expr) - .map(|(aggr_expr, fn_reqs)| { - // If - // - aggregation function is order-sensitive and - // - aggregation is performing a "first stage" calculation, and - // - at least one of the aggregate function requirement is not inside group by expression - // keep the ordering requirement as is; otherwise ignore the ordering requirement. - // In non-first stage modes, we accumulate data (using `merge_batch`) - // from different partitions (i.e. merge partial results). During - // this merge, we consider the ordering of each partial result. - // Hence, we do not need to use the ordering requirement in such - // modes as long as partial results are generated with the - // correct ordering. - fn_reqs.filter(|req| { - is_order_sensitive(aggr_expr) - && mode.is_first_stage() - && !group_by_contains_all_requirements(&group_by, req) - }) + .map(|&idx| PhysicalSortRequirement { + expr: groupby_exprs[idx].clone(), + options: None, }) .collect::>(); - let mut aggregator_reverse_reqs = None; - // Currently we support order-sensitive aggregation only in `Single` mode. - // For `Final` and `FinalPartitioned` modes, we cannot guarantee they will receive - // data according to ordering requirements. As long as we cannot produce correct result - // in `Final` mode, it is not important to produce correct result in `Partial` mode. - // We only support `Single` mode, where we are sure that output produced is final, and it - // is produced in a single step. - - let requirement = get_finest_requirement( - &mut aggr_expr, - &mut order_by_expr, - || input.equivalence_properties(), - || input.ordering_equivalence_properties(), - )?; - let aggregator_requirement = requirement - .as_ref() - .map(|exprs| PhysicalSortRequirement::from_sort_exprs(exprs.iter())); - let aggregator_reqs = aggregator_requirement.unwrap_or(vec![]); - // If all aggregate expressions are reversible, also consider reverse - // requirement(s). The reason is that existing ordering may satisfy the - // given requirement or its reverse. By considering both, we can generate better plans. - if aggr_expr - .iter() - .all(|expr| !is_order_sensitive(expr) || expr.reverse_expr().is_some()) - { - aggregator_reverse_reqs = requirement.map(|reqs| { - PhysicalSortRequirement::from_sort_exprs(reverse_order_bys(&reqs).iter()) - }); - } - // construct a map from the input columns to the output columns of the Aggregation - let mut columns_map: HashMap> = HashMap::new(); - for (expression, name) in group_by.expr.iter() { - if let Some(column) = expression.as_any().downcast_ref::() { - let new_col_idx = schema.index_of(name)?; - let entry = columns_map.entry(column.clone()).or_insert_with(Vec::new); - entry.push(Column::new(name, new_col_idx)); - }; - } - - let mut aggregation_ordering = calc_aggregation_ordering(&input, &group_by); - let required_input_ordering = calc_required_input_ordering( - &input, + let req = get_aggregate_exprs_requirement( + &new_requirement, &mut aggr_expr, - &mut order_by_expr, - aggregator_reqs, - aggregator_reverse_reqs, - &mut aggregation_ordering, + &group_by, + &input_eq_properties, &mode, )?; + new_requirement.extend(req); + new_requirement = collapse_lex_req(new_requirement); + + let input_order_mode = + if indices.len() == groupby_exprs.len() && !indices.is_empty() { + InputOrderMode::Sorted + } else if !indices.is_empty() { + InputOrderMode::PartiallySorted(indices) + } else { + InputOrderMode::Linear + }; + + // construct a map from the input expression to the output expression of the Aggregation group by + let projection_mapping = + ProjectionMapping::try_new(&group_by.expr, &input.schema())?; + + let required_input_ordering = + (!new_requirement.is_empty()).then_some(new_requirement); + + let aggregate_eqs = + input_eq_properties.project(&projection_mapping, schema.clone()); + let output_ordering = aggregate_eqs.oeq_class().output_ordering(); Ok(AggregateExec { mode, group_by, aggr_expr, filter_expr, - order_by_expr, input, schema, input_schema, - columns_map, + projection_mapping, metrics: ExecutionPlanMetricsSet::new(), - aggregation_ordering, required_input_ordering, limit: None, + input_order_mode, + output_ordering, }) } @@ -731,7 +400,7 @@ impl AggregateExec { /// Grouping expressions as they occur in the output schema pub fn output_group_expr(&self) -> Vec> { - output_group_expr_helper(&self.group_by) + self.group_by.output_exprs() } /// Aggregate expressions @@ -744,11 +413,6 @@ impl AggregateExec { &self.filter_expr } - /// ORDER BY clause expression for each aggregate expression - pub fn order_by_expr(&self) -> &[Option] { - &self.order_by_expr - } - /// Input plan pub fn input(&self) -> &Arc { &self.input @@ -759,6 +423,11 @@ impl AggregateExec { self.input_schema.clone() } + /// number of rows soft limit of the AggregateExec + pub fn limit(&self) -> Option { + self.limit + } + fn execute_typed( &self, partition: usize, @@ -773,9 +442,11 @@ impl AggregateExec { // grouping by an expression that has a sort/limit upstream if let Some(limit) = self.limit { - return Ok(StreamType::GroupedPriorityQueue( - GroupedTopKAggregateStream::new(self, context, partition, limit)?, - )); + if !self.is_unordered_unfiltered_group_by_distinct() { + return Ok(StreamType::GroupedPriorityQueue( + GroupedTopKAggregateStream::new(self, context, partition, limit)?, + )); + } } // grouping by something else and we need to just materialize all results @@ -799,6 +470,39 @@ impl AggregateExec { pub fn group_by(&self) -> &PhysicalGroupBy { &self.group_by } + + /// true, if this Aggregate has a group-by with no required or explicit ordering, + /// no filtering and no aggregate expressions + /// This method qualifies the use of the LimitedDistinctAggregation rewrite rule + /// on an AggregateExec. + pub fn is_unordered_unfiltered_group_by_distinct(&self) -> bool { + // ensure there is a group by + if self.group_by().is_empty() { + return false; + } + // ensure there are no aggregate expressions + if !self.aggr_expr().is_empty() { + return false; + } + // ensure there are no filters on aggregate expressions; the above check + // may preclude this case + if self.filter_expr().iter().any(|e| e.is_some()) { + return false; + } + // ensure there are no order by expressions + if self.aggr_expr().iter().any(|e| e.order_bys().is_some()) { + return false; + } + // ensure there is no output ordering; can this rule be relaxed? + if self.output_ordering().is_some() { + return false; + } + // ensure no ordering is required on the input + if self.required_input_ordering()[0].is_some() { + return false; + } + true + } } impl DisplayAs for AggregateExec { @@ -869,8 +573,8 @@ impl DisplayAs for AggregateExec { write!(f, ", lim=[{limit}]")?; } - if let Some(aggregation_ordering) = &self.aggregation_ordering { - write!(f, ", ordering_mode={:?}", aggregation_ordering.mode)?; + if self.input_order_mode != InputOrderMode::Linear { + write!(f, ", ordering_mode={:?}", self.input_order_mode)?; } } } @@ -890,29 +594,30 @@ impl ExecutionPlan for AggregateExec { /// Get the output partitioning of this plan fn output_partitioning(&self) -> Partitioning { - match &self.mode { - AggregateMode::Partial | AggregateMode::Single => { - // Partial and Single Aggregation will not change the output partitioning but need to respect the Alias - let input_partition = self.input.output_partitioning(); - match input_partition { - Partitioning::Hash(exprs, part) => { - let normalized_exprs = exprs - .into_iter() - .map(|expr| { - normalize_out_expr_with_columns_map( - expr, - &self.columns_map, - ) + let input_partition = self.input.output_partitioning(); + if self.mode.is_first_stage() { + // First stage aggregation will not change the output partitioning, + // but needs to respect aliases (e.g. mapping in the GROUP BY + // expression). + let input_eq_properties = self.input.equivalence_properties(); + // First stage Aggregation will not change the output partitioning but need to respect the Alias + let input_partition = self.input.output_partitioning(); + if let Partitioning::Hash(exprs, part) = input_partition { + let normalized_exprs = exprs + .into_iter() + .map(|expr| { + input_eq_properties + .project_expr(&expr, &self.projection_mapping) + .unwrap_or_else(|| { + Arc::new(UnKnownColumn::new(&expr.to_string())) }) - .collect::>(); - Partitioning::Hash(normalized_exprs, part) - } - _ => input_partition, - } + }) + .collect(); + return Partitioning::Hash(normalized_exprs, part); } - // Final Aggregation's output partitioning is the same as its real input - _ => self.input.output_partitioning(), } + // Final Aggregation's output partitioning is the same as its real input + input_partition } /// Specifies whether this plan generates an infinite stream of records. @@ -920,7 +625,7 @@ impl ExecutionPlan for AggregateExec { /// infinite, returns an error to indicate this. fn unbounded_output(&self, children: &[bool]) -> Result { if children[0] { - if self.aggregation_ordering.is_none() { + if self.input_order_mode == InputOrderMode::Linear { // Cannot run without breaking pipeline. plan_err!( "Aggregate Error: `GROUP BY` clauses with columns without ordering and GROUPING SETS are not supported for unbounded inputs." @@ -934,9 +639,7 @@ impl ExecutionPlan for AggregateExec { } fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { - self.aggregation_ordering - .as_ref() - .map(|item: &AggregationOrdering| item.ordering.as_slice()) + self.output_ordering.as_deref() } fn required_input_distribution(&self) -> Vec { @@ -953,18 +656,14 @@ impl ExecutionPlan for AggregateExec { } } - fn required_input_ordering(&self) -> Vec> { + fn required_input_ordering(&self) -> Vec> { vec![self.required_input_ordering.clone()] } fn equivalence_properties(&self) -> EquivalenceProperties { - let mut new_properties = EquivalenceProperties::new(self.schema()); - project_equivalence_properties( - self.input.equivalence_properties(), - &self.columns_map, - &mut new_properties, - ); - new_properties + self.input + .equivalence_properties() + .project(&self.projection_mapping, self.schema()) } fn children(&self) -> Vec> { @@ -975,14 +674,15 @@ impl ExecutionPlan for AggregateExec { self: Arc, children: Vec>, ) -> Result> { - let mut me = AggregateExec::try_new( + let mut me = AggregateExec::try_new_with_schema( self.mode, self.group_by.clone(), self.aggr_expr.clone(), self.filter_expr.clone(), - self.order_by_expr.clone(), children[0].clone(), self.input_schema.clone(), + self.schema.clone(), + //self.original_schema.clone(), )?; me.limit = self.limit; Ok(Arc::new(me)) @@ -1001,28 +701,50 @@ impl ExecutionPlan for AggregateExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Statistics { + fn statistics(&self) -> Result { // TODO stats: group expressions: // - once expressions will be able to compute their own stats, use it here // - case where we group by on a column for which with have the `distinct` stat // TODO stats: aggr expression: // - aggregations somtimes also preserve invariants such as min, max... + let column_statistics = Statistics::unknown_column(&self.schema()); match self.mode { AggregateMode::Final | AggregateMode::FinalPartitioned if self.group_by.expr.is_empty() => { - Statistics { - num_rows: Some(1), - is_exact: true, - ..Default::default() - } + Ok(Statistics { + num_rows: Precision::Exact(1), + column_statistics, + total_byte_size: Precision::Absent, + }) + } + _ => { + // When the input row count is 0 or 1, we can adopt that statistic keeping its reliability. + // When it is larger than 1, we degrade the precision since it may decrease after aggregation. + let num_rows = if let Some(value) = + self.input().statistics()?.num_rows.get_value() + { + if *value > 1 { + self.input().statistics()?.num_rows.to_inexact() + } else if *value == 0 { + // Aggregation on an empty table creates a null row. + self.input() + .statistics()? + .num_rows + .add(&Precision::Exact(1)) + } else { + // num_rows = 1 case + self.input().statistics()?.num_rows + } + } else { + Precision::Absent + }; + Ok(Statistics { + num_rows, + column_statistics, + total_byte_size: Precision::Absent, + }) } - _ => Statistics { - // the output row count is surely not larger than its input row count - num_rows: self.input.statistics().num_rows, - is_exact: false, - ..Default::default() - }, } } } @@ -1072,6 +794,176 @@ fn group_schema(schema: &Schema, group_count: usize) -> SchemaRef { Arc::new(Schema::new(group_fields)) } +/// Determines the lexical ordering requirement for an aggregate expression. +/// +/// # Parameters +/// +/// - `aggr_expr`: A reference to an `Arc` representing the +/// aggregate expression. +/// - `group_by`: A reference to a `PhysicalGroupBy` instance representing the +/// physical GROUP BY expression. +/// - `agg_mode`: A reference to an `AggregateMode` instance representing the +/// mode of aggregation. +/// +/// # Returns +/// +/// A `LexOrdering` instance indicating the lexical ordering requirement for +/// the aggregate expression. +fn get_aggregate_expr_req( + aggr_expr: &Arc, + group_by: &PhysicalGroupBy, + agg_mode: &AggregateMode, +) -> LexOrdering { + // If the aggregation function is not order sensitive, or the aggregation + // is performing a "second stage" calculation, or all aggregate function + // requirements are inside the GROUP BY expression, then ignore the ordering + // requirement. + if !is_order_sensitive(aggr_expr) || !agg_mode.is_first_stage() { + return vec![]; + } + + let mut req = aggr_expr.order_bys().unwrap_or_default().to_vec(); + + // In non-first stage modes, we accumulate data (using `merge_batch`) from + // different partitions (i.e. merge partial results). During this merge, we + // consider the ordering of each partial result. Hence, we do not need to + // use the ordering requirement in such modes as long as partial results are + // generated with the correct ordering. + if group_by.is_single() { + // Remove all orderings that occur in the group by. These requirements + // will definitely be satisfied -- Each group by expression will have + // distinct values per group, hence all requirements are satisfied. + let physical_exprs = group_by.input_exprs(); + req.retain(|sort_expr| { + !physical_exprs_contains(&physical_exprs, &sort_expr.expr) + }); + } + req +} + +/// Computes the finer ordering for between given existing ordering requirement +/// of aggregate expression. +/// +/// # Parameters +/// +/// * `existing_req` - The existing lexical ordering that needs refinement. +/// * `aggr_expr` - A reference to an aggregate expression trait object. +/// * `group_by` - Information about the physical grouping (e.g group by expression). +/// * `eq_properties` - Equivalence properties relevant to the computation. +/// * `agg_mode` - The mode of aggregation (e.g., Partial, Final, etc.). +/// +/// # Returns +/// +/// An `Option` representing the computed finer lexical ordering, +/// or `None` if there is no finer ordering; e.g. the existing requirement and +/// the aggregator requirement is incompatible. +fn finer_ordering( + existing_req: &LexOrdering, + aggr_expr: &Arc, + group_by: &PhysicalGroupBy, + eq_properties: &EquivalenceProperties, + agg_mode: &AggregateMode, +) -> Option { + let aggr_req = get_aggregate_expr_req(aggr_expr, group_by, agg_mode); + eq_properties.get_finer_ordering(existing_req, &aggr_req) +} + +/// Concatenates the given slices. +fn concat_slices(lhs: &[T], rhs: &[T]) -> Vec { + [lhs, rhs].concat() +} + +/// Get the common requirement that satisfies all the aggregate expressions. +/// +/// # Parameters +/// +/// - `aggr_exprs`: A slice of `Arc` containing all the +/// aggregate expressions. +/// - `group_by`: A reference to a `PhysicalGroupBy` instance representing the +/// physical GROUP BY expression. +/// - `eq_properties`: A reference to an `EquivalenceProperties` instance +/// representing equivalence properties for ordering. +/// - `agg_mode`: A reference to an `AggregateMode` instance representing the +/// mode of aggregation. +/// +/// # Returns +/// +/// A `LexRequirement` instance, which is the requirement that satisfies all the +/// aggregate requirements. Returns an error in case of conflicting requirements. +fn get_aggregate_exprs_requirement( + prefix_requirement: &[PhysicalSortRequirement], + aggr_exprs: &mut [Arc], + group_by: &PhysicalGroupBy, + eq_properties: &EquivalenceProperties, + agg_mode: &AggregateMode, +) -> Result { + let mut requirement = vec![]; + for aggr_expr in aggr_exprs.iter_mut() { + let aggr_req = aggr_expr.order_bys().unwrap_or(&[]); + let reverse_aggr_req = reverse_order_bys(aggr_req); + let aggr_req = PhysicalSortRequirement::from_sort_exprs(aggr_req); + let reverse_aggr_req = + PhysicalSortRequirement::from_sort_exprs(&reverse_aggr_req); + if let Some(first_value) = aggr_expr.as_any().downcast_ref::() { + let mut first_value = first_value.clone(); + if eq_properties.ordering_satisfy_requirement(&concat_slices( + prefix_requirement, + &aggr_req, + )) { + first_value = first_value.with_requirement_satisfied(true); + *aggr_expr = Arc::new(first_value) as _; + } else if eq_properties.ordering_satisfy_requirement(&concat_slices( + prefix_requirement, + &reverse_aggr_req, + )) { + // Converting to LAST_VALUE enables more efficient execution + // given the existing ordering: + let mut last_value = first_value.convert_to_last(); + last_value = last_value.with_requirement_satisfied(true); + *aggr_expr = Arc::new(last_value) as _; + } else { + // Requirement is not satisfied with existing ordering. + first_value = first_value.with_requirement_satisfied(false); + *aggr_expr = Arc::new(first_value) as _; + } + } else if let Some(last_value) = aggr_expr.as_any().downcast_ref::() { + let mut last_value = last_value.clone(); + if eq_properties.ordering_satisfy_requirement(&concat_slices( + prefix_requirement, + &aggr_req, + )) { + last_value = last_value.with_requirement_satisfied(true); + *aggr_expr = Arc::new(last_value) as _; + } else if eq_properties.ordering_satisfy_requirement(&concat_slices( + prefix_requirement, + &reverse_aggr_req, + )) { + // Converting to FIRST_VALUE enables more efficient execution + // given the existing ordering: + let mut first_value = last_value.convert_to_first(); + first_value = first_value.with_requirement_satisfied(true); + *aggr_expr = Arc::new(first_value) as _; + } else { + // Requirement is not satisfied with existing ordering. + last_value = last_value.with_requirement_satisfied(false); + *aggr_expr = Arc::new(last_value) as _; + } + } else if let Some(finer_ordering) = + finer_ordering(&requirement, aggr_expr, group_by, eq_properties, agg_mode) + { + requirement = finer_ordering; + } else { + // If neither of the requirements satisfy the other, this means + // requirements are conflicting. Currently, we do not support + // conflicting requirements. + return not_impl_err!( + "Conflicting ordering requirements in aggregate functions is not supported" + ); + } + } + Ok(PhysicalSortRequirement::from_sort_exprs(&requirement)) +} + /// returns physical expressions for arguments to evaluate against a batch /// The expressions are different depending on `mode`: /// * Partial: AggregateExpr::expressions @@ -1087,33 +979,27 @@ fn aggregate_expressions( | AggregateMode::SinglePartitioned => Ok(aggr_expr .iter() .map(|agg| { - let mut result = agg.expressions().clone(); - // In partial mode, append ordering requirements to expressions' results. - // Ordering requirements are used by subsequent executors to satisfy the required - // ordering for `AggregateMode::FinalPartitioned`/`AggregateMode::Final` modes. - if matches!(mode, AggregateMode::Partial) { - if let Some(ordering_req) = agg.order_bys() { - let ordering_exprs = ordering_req - .iter() - .map(|item| item.expr.clone()) - .collect::>(); - result.extend(ordering_exprs); - } + let mut result = agg.expressions(); + // Append ordering requirements to expressions' results. This + // way order sensitive aggregators can satisfy requirement + // themselves. + if let Some(ordering_req) = agg.order_bys() { + result.extend(ordering_req.iter().map(|item| item.expr.clone())); } result }) .collect()), - // in this mode, we build the merge expressions of the aggregation + // In this mode, we build the merge expressions of the aggregation. AggregateMode::Final | AggregateMode::FinalPartitioned => { let mut col_idx_base = col_idx_base; - Ok(aggr_expr + aggr_expr .iter() .map(|agg| { let exprs = merge_expressions(col_idx_base, agg)?; col_idx_base += exprs.len(); Ok(exprs) }) - .collect::>>()?) + .collect() } } } @@ -1126,14 +1012,13 @@ fn merge_expressions( index_base: usize, expr: &Arc, ) -> Result>> { - Ok(expr - .state_fields()? - .iter() - .enumerate() - .map(|(idx, f)| { - Arc::new(Column::new(f.name(), index_base + idx)) as Arc - }) - .collect::>()) + expr.state_fields().map(|fields| { + fields + .iter() + .enumerate() + .map(|(idx, f)| Arc::new(Column::new(f.name(), index_base + idx)) as _) + .collect() + }) } pub(crate) type AccumulatorItem = Box; @@ -1144,7 +1029,7 @@ fn create_accumulators( aggr_expr .iter() .map(|expr| expr.create_accumulator()) - .collect::>>() + .collect() } /// returns a vector of ArrayRefs, where each entry corresponds to either the @@ -1155,27 +1040,28 @@ fn finalize_aggregation( ) -> Result> { match mode { AggregateMode::Partial => { - // build the vector of states - let a = accumulators + // Build the vector of states + accumulators .iter() - .map(|accumulator| accumulator.state()) - .map(|value| { - value.map(|e| { - e.iter().map(|v| v.to_array()).collect::>() + .map(|accumulator| { + accumulator.state().and_then(|e| { + e.iter() + .map(|v| v.to_array()) + .collect::>>() }) }) - .collect::>>()?; - Ok(a.iter().flatten().cloned().collect::>()) + .flatten_ok() + .collect() } AggregateMode::Final | AggregateMode::FinalPartitioned | AggregateMode::Single | AggregateMode::SinglePartitioned => { - // merge the state to the final value + // Merge the state to the final value accumulators .iter() - .map(|accumulator| accumulator.evaluate().map(|v| v.to_array())) - .collect::>>() + .map(|accumulator| accumulator.evaluate().and_then(|v| v.to_array())) + .collect() } } } @@ -1186,9 +1072,11 @@ fn evaluate( batch: &RecordBatch, ) -> Result> { expr.iter() - .map(|expr| expr.evaluate(batch)) - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) - .collect::>>() + .map(|expr| { + expr.evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) + .collect() } /// Evaluates expressions against a record batch. @@ -1196,9 +1084,7 @@ pub(crate) fn evaluate_many( expr: &[Vec>], batch: &RecordBatch, ) -> Result>> { - expr.iter() - .map(|expr| evaluate(expr, batch)) - .collect::>>() + expr.iter().map(|expr| evaluate(expr, batch)).collect() } fn evaluate_optional( @@ -1208,11 +1094,13 @@ fn evaluate_optional( expr.iter() .map(|expr| { expr.as_ref() - .map(|expr| expr.evaluate(batch)) + .map(|expr| { + expr.evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) .transpose() - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) }) - .collect::>>() + .collect() } /// Evaluate a group by expression against a `RecordBatch` @@ -1234,7 +1122,7 @@ pub(crate) fn evaluate_group_by( .iter() .map(|(expr, _)| { let value = expr.evaluate(batch)?; - Ok(value.into_array(batch.num_rows())) + value.into_array(batch.num_rows()) }) .collect::>>()?; @@ -1243,7 +1131,7 @@ pub(crate) fn evaluate_group_by( .iter() .map(|(expr, _)| { let value = expr.evaluate(batch)?; - Ok(value.into_array(batch.num_rows())) + value.into_array(batch.num_rows()) }) .collect::>>()?; @@ -1268,19 +1156,19 @@ pub(crate) fn evaluate_group_by( #[cfg(test)] mod tests { + use std::any::Any; + use std::sync::Arc; + use std::task::{Context, Poll}; + use super::*; - use crate::aggregates::GroupByOrderMode::{FullyOrdered, PartiallyOrdered}; - use crate::aggregates::{ - get_finest_requirement, get_working_mode, AggregateExec, AggregateMode, - PhysicalGroupBy, - }; + use crate::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use crate::coalesce_batches::CoalesceBatchesExec; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::common; use crate::expressions::{col, Avg}; use crate::memory::MemoryExec; + use crate::test::assert_is_pending; use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; - use crate::test::{assert_is_pending, mem_exec}; use crate::{ DisplayAs, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, Statistics, @@ -1294,20 +1182,17 @@ mod tests { assert_batches_eq, assert_batches_sorted_eq, internal_err, DataFusionError, Result, ScalarValue, }; + use datafusion_execution::config::SessionConfig; + use datafusion_execution::memory_pool::FairSpillPool; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_physical_expr::expressions::{ - lit, ApproxDistinct, Column, Count, FirstValue, LastValue, Median, + lit, ApproxDistinct, Count, FirstValue, LastValue, Median, OrderSensitiveArrayAgg, }; use datafusion_physical_expr::{ - AggregateExpr, EquivalenceProperties, OrderingEquivalenceProperties, - PhysicalExpr, PhysicalSortExpr, + reverse_order_bys, AggregateExpr, EquivalenceProperties, PhysicalExpr, + PhysicalSortExpr, }; - use std::any::Any; - use std::sync::Arc; - use std::task::{Context, Poll}; - - use datafusion_execution::config::SessionConfig; use futures::{FutureExt, Stream}; // Generate a schema which consists of 5 columns (a, b, c, d, e) @@ -1322,80 +1207,6 @@ mod tests { Ok(schema) } - /// make PhysicalSortExpr with default options - fn sort_expr(name: &str, schema: &Schema) -> PhysicalSortExpr { - sort_expr_options(name, schema, SortOptions::default()) - } - - /// PhysicalSortExpr with specified options - fn sort_expr_options( - name: &str, - schema: &Schema, - options: SortOptions, - ) -> PhysicalSortExpr { - PhysicalSortExpr { - expr: col(name, schema).unwrap(), - options, - } - } - - #[tokio::test] - async fn test_get_working_mode() -> Result<()> { - let test_schema = create_test_schema()?; - // Source is sorted by a ASC NULLS FIRST, b ASC NULLS FIRST, c ASC NULLS FIRST - // Column d, e is not ordered. - let sort_exprs = vec![ - sort_expr("a", &test_schema), - sort_expr("b", &test_schema), - sort_expr("c", &test_schema), - ]; - let input = mem_exec(1).with_sort_information(vec![sort_exprs]); - let input = Arc::new(input) as _; - - // test cases consists of vector of tuples. Where each tuple represents a single test case. - // First field in the tuple is Vec where each element in the vector represents GROUP BY columns - // For instance `vec!["a", "b"]` corresponds to GROUP BY a, b - // Second field in the tuple is Option, which corresponds to expected algorithm mode. - // None represents that existing ordering is not sufficient to run executor with any one of the algorithms - // (We need to add SortExec to be able to run it). - // Some(GroupByOrderMode) represents, we can run algorithm with existing ordering; and algorithm should work in - // GroupByOrderMode. - let test_cases = vec![ - (vec!["a"], Some((FullyOrdered, vec![0]))), - (vec!["b"], None), - (vec!["c"], None), - (vec!["b", "a"], Some((FullyOrdered, vec![1, 0]))), - (vec!["c", "b"], None), - (vec!["c", "a"], Some((PartiallyOrdered, vec![1]))), - (vec!["c", "b", "a"], Some((FullyOrdered, vec![2, 1, 0]))), - (vec!["d", "a"], Some((PartiallyOrdered, vec![1]))), - (vec!["d", "b"], None), - (vec!["d", "c"], None), - (vec!["d", "b", "a"], Some((PartiallyOrdered, vec![2, 1]))), - (vec!["d", "c", "b"], None), - (vec!["d", "c", "a"], Some((PartiallyOrdered, vec![2]))), - ( - vec!["d", "c", "b", "a"], - Some((PartiallyOrdered, vec![3, 2, 1])), - ), - ]; - for (case_idx, test_case) in test_cases.iter().enumerate() { - let (group_by_columns, expected) = &test_case; - let mut group_by_exprs = vec![]; - for col_name in group_by_columns { - group_by_exprs.push((col(col_name, &test_schema)?, col_name.to_string())); - } - let group_bys = PhysicalGroupBy::new_single(group_by_exprs); - let res = get_working_mode(&input, &group_bys); - assert_eq!( - res, *expected, - "Unexpected result for in unbounded test case#: {case_idx:?}, case: {test_case:?}" - ); - } - - Ok(()) - } - /// some mock data to aggregates fn some_data() -> (Arc, Vec) { // define a schema. @@ -1482,8 +1293,11 @@ mod tests { fn new_spill_ctx(batch_size: usize, max_memory: usize) -> Arc { let session_config = SessionConfig::new().with_batch_size(batch_size); let runtime = Arc::new( - RuntimeEnv::new(RuntimeConfig::default().with_memory_limit(max_memory, 1.0)) - .unwrap(), + RuntimeEnv::new( + RuntimeConfig::default() + .with_memory_pool(Arc::new(FairSpillPool::new(max_memory))), + ) + .unwrap(), ); let task_ctx = TaskContext::default() .with_session_config(session_config) @@ -1530,7 +1344,6 @@ mod tests { grouping_set.clone(), aggregates.clone(), vec![None], - vec![None], input, input_schema.clone(), )?); @@ -1609,7 +1422,6 @@ mod tests { final_grouping_set, aggregates, vec![None], - vec![None], merge, input_schema, )?); @@ -1675,7 +1487,6 @@ mod tests { grouping_set.clone(), aggregates.clone(), vec![None], - vec![None], input, input_schema.clone(), )?); @@ -1723,7 +1534,6 @@ mod tests { final_grouping_set, aggregates, vec![None], - vec![None], merge, input_schema, )?); @@ -1822,9 +1632,13 @@ mod tests { Ok(Box::pin(stream)) } - fn statistics(&self) -> Statistics { + fn statistics(&self) -> Result { let (_, batches) = some_data(); - common::compute_record_batch_statistics(&[batches], &self.schema(), None) + Ok(common::compute_record_batch_statistics( + &[batches], + &self.schema(), + None, + )) } } @@ -1868,7 +1682,7 @@ mod tests { } } - //// Tests //// + //--- Tests ---// #[tokio::test] async fn aggregate_source_not_yielding() -> Result<()> { @@ -1986,7 +1800,6 @@ mod tests { groups, aggregates, vec![None; 3], - vec![None; 3], input.clone(), input_schema.clone(), )?); @@ -2042,7 +1855,6 @@ mod tests { groups.clone(), aggregates.clone(), vec![None], - vec![None], blocking_exec, schema, )?); @@ -2081,7 +1893,6 @@ mod tests { groups, aggregates.clone(), vec![None], - vec![None], blocking_exec, schema, )?); @@ -2132,7 +1943,7 @@ mod tests { spill: bool, ) -> Result<()> { let task_ctx = if spill { - new_spill_ctx(2, 2812) + new_spill_ctx(2, 2886) } else { Arc::new(TaskContext::default()) }; @@ -2183,7 +1994,6 @@ mod tests { groups.clone(), aggregates.clone(), vec![None], - vec![Some(ordering_req.clone())], memory_exec, schema.clone(), )?); @@ -2199,7 +2009,6 @@ mod tests { groups, aggregates.clone(), vec![None], - vec![Some(ordering_req)], coalesce, schema, )?) as Arc; @@ -2240,74 +2049,130 @@ mod tests { descending: false, nulls_first: false, }; - // This is the reverse requirement of options1 - let options2 = SortOptions { - descending: true, - nulls_first: true, - }; - let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); - let col_a = Column::new("a", 0); - let col_b = Column::new("b", 1); - let col_c = Column::new("c", 2); - let col_d = Column::new("d", 3); - eq_properties.add_equal_conditions((&col_a, &col_b)); - let mut ordering_eq_properties = OrderingEquivalenceProperties::new(test_schema); - ordering_eq_properties.add_equal_conditions(( - &vec![PhysicalSortExpr { - expr: Arc::new(col_a.clone()) as _, - options: options1, - }], - &vec![PhysicalSortExpr { - expr: Arc::new(col_c.clone()) as _, - options: options2, - }], - )); - let mut order_by_exprs = vec![ + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let mut eq_properties = EquivalenceProperties::new(test_schema); + // Columns a and b are equal. + eq_properties.add_equal_conditions(col_a, col_b); + // Aggregate requirements are + // [None], [a ASC], [a ASC, b ASC, c ASC], [a ASC, b ASC] respectively + let order_by_exprs = vec![ None, Some(vec![PhysicalSortExpr { - expr: Arc::new(col_a.clone()), + expr: col_a.clone(), options: options1, }]), - Some(vec![PhysicalSortExpr { - expr: Arc::new(col_b.clone()), - options: options1, - }]), - Some(vec![PhysicalSortExpr { - expr: Arc::new(col_c), - options: options2, - }]), Some(vec![ PhysicalSortExpr { - expr: Arc::new(col_a.clone()), + expr: col_a.clone(), options: options1, }, PhysicalSortExpr { - expr: Arc::new(col_d), + expr: col_b.clone(), + options: options1, + }, + PhysicalSortExpr { + expr: col_c.clone(), options: options1, }, ]), - // Since aggregate expression is reversible (FirstValue), we should be able to resolve below - // contradictory requirement by reversing it. - Some(vec![PhysicalSortExpr { - expr: Arc::new(col_b.clone()), - options: options2, - }]), + Some(vec![ + PhysicalSortExpr { + expr: col_a.clone(), + options: options1, + }, + PhysicalSortExpr { + expr: col_b.clone(), + options: options1, + }, + ]), + ]; + let common_requirement = vec![ + PhysicalSortExpr { + expr: col_a.clone(), + options: options1, + }, + PhysicalSortExpr { + expr: col_c.clone(), + options: options1, + }, ]; - let aggr_expr = Arc::new(FirstValue::new( - Arc::new(col_a.clone()), - "first1", - DataType::Int32, - vec![], - vec![], - )) as _; - let mut aggr_exprs = vec![aggr_expr; order_by_exprs.len()]; - let res = get_finest_requirement( + let mut aggr_exprs = order_by_exprs + .into_iter() + .map(|order_by_expr| { + Arc::new(OrderSensitiveArrayAgg::new( + col_a.clone(), + "array_agg", + DataType::Int32, + false, + vec![], + order_by_expr.unwrap_or_default(), + )) as _ + }) + .collect::>(); + let group_by = PhysicalGroupBy::new_single(vec![]); + let res = get_aggregate_exprs_requirement( + &[], &mut aggr_exprs, - &mut order_by_exprs, - || eq_properties.clone(), - || ordering_eq_properties.clone(), + &group_by, + &eq_properties, + &AggregateMode::Partial, )?; - assert_eq!(res, order_by_exprs[4]); + let res = PhysicalSortRequirement::to_sort_exprs(res); + assert_eq!(res, common_requirement); + Ok(()) + } + + #[test] + fn test_agg_exec_same_schema() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Float32, true), + Field::new("b", DataType::Float32, true), + ])); + + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + let sort_expr = vec![PhysicalSortExpr { + expr: col_b.clone(), + options: option_desc, + }]; + let sort_expr_reverse = reverse_order_bys(&sort_expr); + let groups = PhysicalGroupBy::new_single(vec![(col_a, "a".to_string())]); + + let aggregates: Vec> = vec![ + Arc::new(FirstValue::new( + col_b.clone(), + "FIRST_VALUE(b)".to_string(), + DataType::Float64, + sort_expr_reverse.clone(), + vec![DataType::Float64], + )), + Arc::new(LastValue::new( + col_b.clone(), + "LAST_VALUE(b)".to_string(), + DataType::Float64, + sort_expr.clone(), + vec![DataType::Float64], + )), + ]; + let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); + let aggregate_exec = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + groups, + aggregates.clone(), + vec![None, None], + blocking_exec.clone(), + schema, + )?); + let new_agg = aggregate_exec + .clone() + .with_new_children(vec![blocking_exec])?; + assert_eq!(new_agg.schema(), aggregate_exec.schema()); Ok(()) } } diff --git a/datafusion/physical-plan/src/aggregates/no_grouping.rs b/datafusion/physical-plan/src/aggregates/no_grouping.rs index 32c0bbc78a5d..90eb488a2ead 100644 --- a/datafusion/physical-plan/src/aggregates/no_grouping.rs +++ b/datafusion/physical-plan/src/aggregates/no_grouping.rs @@ -217,8 +217,10 @@ fn aggregate_batch( // 1.3 let values = &expr .iter() - .map(|e| e.evaluate(&batch)) - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .map(|e| { + e.evaluate(&batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) .collect::>>()?; // 1.4 diff --git a/datafusion/physical-plan/src/aggregates/order/mod.rs b/datafusion/physical-plan/src/aggregates/order/mod.rs index f0b49872b1c5..b258b97a9e84 100644 --- a/datafusion/physical-plan/src/aggregates/order/mod.rs +++ b/datafusion/physical-plan/src/aggregates/order/mod.rs @@ -18,13 +18,12 @@ use arrow_array::ArrayRef; use arrow_schema::Schema; use datafusion_common::Result; -use datafusion_physical_expr::EmitTo; - -use super::{AggregationOrdering, GroupByOrderMode}; +use datafusion_physical_expr::{EmitTo, PhysicalSortExpr}; mod full; mod partial; +use crate::InputOrderMode; pub(crate) use full::GroupOrderingFull; pub(crate) use partial::GroupOrderingPartial; @@ -43,24 +42,17 @@ impl GroupOrdering { /// Create a `GroupOrdering` for the the specified ordering pub fn try_new( input_schema: &Schema, - ordering: &AggregationOrdering, + mode: &InputOrderMode, + ordering: &[PhysicalSortExpr], ) -> Result { - let AggregationOrdering { - mode, - order_indices, - ordering, - } = ordering; - - Ok(match mode { - GroupByOrderMode::PartiallyOrdered => { - let partial = - GroupOrderingPartial::try_new(input_schema, order_indices, ordering)?; - GroupOrdering::Partial(partial) + match mode { + InputOrderMode::Linear => Ok(GroupOrdering::None), + InputOrderMode::PartiallySorted(order_indices) => { + GroupOrderingPartial::try_new(input_schema, order_indices, ordering) + .map(GroupOrdering::Partial) } - GroupByOrderMode::FullyOrdered => { - GroupOrdering::Full(GroupOrderingFull::new()) - } - }) + InputOrderMode::Sorted => Ok(GroupOrdering::Full(GroupOrderingFull::new())), + } } // How many groups be emitted, or None if no data can be emitted diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index d773533ad6a3..6a0c02f5caf3 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -17,17 +17,10 @@ //! Hash aggregation -use datafusion_physical_expr::{ - AggregateExpr, EmitTo, GroupsAccumulator, GroupsAccumulatorAdapter, PhysicalSortExpr, -}; -use log::debug; use std::sync::Arc; use std::task::{Context, Poll}; use std::vec; -use futures::ready; -use futures::stream::{Stream, StreamExt}; - use crate::aggregates::group_values::{new_group_values, GroupValues}; use crate::aggregates::order::GroupOrderingFull; use crate::aggregates::{ @@ -39,8 +32,9 @@ use crate::metrics::{BaselineMetrics, RecordOutput}; use crate::sorts::sort::{read_spill_as_stream, sort_batch}; use crate::sorts::streaming_merge; use crate::stream::RecordBatchStreamAdapter; -use crate::{aggregates, PhysicalExpr}; +use crate::{aggregates, ExecutionPlan, PhysicalExpr}; use crate::{RecordBatchStream, SendableRecordBatchStream}; + use arrow::array::*; use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; use arrow_schema::SortOptions; @@ -50,7 +44,14 @@ use datafusion_execution::memory_pool::proxy::VecAllocExt; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; -use datafusion_physical_expr::expressions::col; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::{ + AggregateExpr, EmitTo, GroupsAccumulator, GroupsAccumulatorAdapter, PhysicalSortExpr, +}; + +use futures::ready; +use futures::stream::{Stream, StreamExt}; +use log::debug; #[derive(Debug, Clone)] /// This object tracks the aggregation phase (input/output) @@ -266,6 +267,12 @@ pub(crate) struct GroupedHashAggregateStream { /// The spill state object spill_state: SpillState, + + /// Optional soft limit on the number of `group_values` in a batch + /// If the number of `group_values` in a single batch exceeds this value, + /// the `GroupedHashAggregateStream` operation immediately switches to + /// output mode and emits all groups. + group_values_soft_limit: Option, } impl GroupedHashAggregateStream { @@ -321,24 +328,25 @@ impl GroupedHashAggregateStream { let spill_expr = group_schema .fields .into_iter() - .map(|field| PhysicalSortExpr { - expr: col(field.name(), &group_schema).unwrap(), + .enumerate() + .map(|(idx, field)| PhysicalSortExpr { + expr: Arc::new(Column::new(field.name().as_str(), idx)) as _, options: SortOptions::default(), }) .collect(); let name = format!("GroupedHashAggregateStream[{partition}]"); - let reservation = MemoryConsumer::new(name).register(context.memory_pool()); - - let group_ordering = agg - .aggregation_ordering - .as_ref() - .map(|aggregation_ordering| { - GroupOrdering::try_new(&group_schema, aggregation_ordering) - }) - // return error if any - .transpose()? - .unwrap_or(GroupOrdering::None); + let reservation = MemoryConsumer::new(name) + .with_can_spill(true) + .register(context.memory_pool()); + let (ordering, _) = agg + .equivalence_properties() + .find_longest_permutation(&agg_group_by.output_exprs()); + let group_ordering = GroupOrdering::try_new( + &group_schema, + &agg.input_order_mode, + ordering.as_slice(), + )?; let group_values = new_group_values(group_schema)?; timer.done(); @@ -372,6 +380,7 @@ impl GroupedHashAggregateStream { input_done: false, runtime: context.runtime_env(), spill_state, + group_values_soft_limit: agg.limit, }) } } @@ -416,9 +425,8 @@ impl Stream for GroupedHashAggregateStream { let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); loop { - let exec_state = self.exec_state.clone(); - match exec_state { - ExecutionState::ReadingInput => { + match &self.exec_state { + ExecutionState::ReadingInput => 'reading_input: { match ready!(self.input.poll_next_unpin(cx)) { // new batch to aggregate Some(Ok(batch)) => { @@ -433,9 +441,21 @@ impl Stream for GroupedHashAggregateStream { // otherwise keep consuming input assert!(!self.input_done); + // If the number of group values equals or exceeds the soft limit, + // emit all groups and switch to producing output + if self.hit_soft_group_limit() { + timer.done(); + extract_ok!(self.set_input_done_and_produce_output()); + // make sure the exec_state just set is not overwritten below + break 'reading_input; + } + if let Some(to_emit) = self.group_ordering.emit_to() { let batch = extract_ok!(self.emit(to_emit, false)); self.exec_state = ExecutionState::ProducingOutput(batch); + timer.done(); + // make sure the exec_state just set is not overwritten below + break 'reading_input; } extract_ok!(self.emit_early_if_necessary()); @@ -448,37 +468,31 @@ impl Stream for GroupedHashAggregateStream { } None => { // inner is done, emit all rows and switch to producing output - self.input_done = true; - self.group_ordering.input_done(); - let timer = elapsed_compute.timer(); - if self.spill_state.spills.is_empty() { - let batch = extract_ok!(self.emit(EmitTo::All, false)); - self.exec_state = ExecutionState::ProducingOutput(batch); - } else { - // If spill files exist, stream-merge them. - extract_ok!(self.update_merged_stream()); - self.exec_state = ExecutionState::ReadingInput; - } - timer.done(); + extract_ok!(self.set_input_done_and_produce_output()); } } } ExecutionState::ProducingOutput(batch) => { // slice off a part of the batch, if needed - let output_batch = if batch.num_rows() <= self.batch_size { - if self.input_done { - self.exec_state = ExecutionState::Done; - } else { - self.exec_state = ExecutionState::ReadingInput - } - batch + let output_batch; + let size = self.batch_size; + (self.exec_state, output_batch) = if batch.num_rows() <= size { + ( + if self.input_done { + ExecutionState::Done + } else { + ExecutionState::ReadingInput + }, + batch.clone(), + ) } else { // output first batch_size rows - let num_remaining = batch.num_rows() - self.batch_size; - let remaining = batch.slice(self.batch_size, num_remaining); - self.exec_state = ExecutionState::ProducingOutput(remaining); - batch.slice(0, self.batch_size) + let size = self.batch_size; + let num_remaining = batch.num_rows() - size; + let remaining = batch.slice(size, num_remaining); + let output = batch.slice(0, size); + (ExecutionState::ProducingOutput(remaining), output) }; return Poll::Ready(Some(Ok( output_batch.record_output(&self.baseline_metrics) @@ -673,7 +687,16 @@ impl GroupedHashAggregateStream { let spillfile = self.runtime.disk_manager.create_tmp_file("HashAggSpill")?; let mut writer = IPCWriter::new(spillfile.path(), &emit.schema())?; // TODO: slice large `sorted` and write to multiple files in parallel - writer.write(&sorted)?; + let mut offset = 0; + let total_rows = sorted.num_rows(); + + while offset < total_rows { + let length = std::cmp::min(total_rows - offset, self.batch_size); + let batch = sorted.slice(offset, length); + offset += batch.num_rows(); + writer.write(&batch)?; + } + writer.finish()?; self.spill_state.spills.push(spillfile); Ok(()) @@ -744,4 +767,31 @@ impl GroupedHashAggregateStream { self.group_ordering = GroupOrdering::Full(GroupOrderingFull::new()); Ok(()) } + + /// returns true if there is a soft groups limit and the number of distinct + /// groups we have seen is over that limit + fn hit_soft_group_limit(&self) -> bool { + let Some(group_values_soft_limit) = self.group_values_soft_limit else { + return false; + }; + group_values_soft_limit <= self.group_values.len() + } + + /// common function for signalling end of processing of the input stream + fn set_input_done_and_produce_output(&mut self) -> Result<()> { + self.input_done = true; + self.group_ordering.input_done(); + let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); + let timer = elapsed_compute.timer(); + self.exec_state = if self.spill_state.spills.is_empty() { + let batch = self.emit(EmitTo::All, false)?; + ExecutionState::ProducingOutput(batch) + } else { + // If spill files exist, stream-merge them. + self.update_merged_stream()?; + ExecutionState::ReadingInput + }; + timer.done(); + Ok(()) + } } diff --git a/datafusion/physical-plan/src/analyze.rs b/datafusion/physical-plan/src/analyze.rs index 3a2cac59cfdf..4f1578e220dd 100644 --- a/datafusion/physical-plan/src/analyze.rs +++ b/datafusion/physical-plan/src/analyze.rs @@ -20,19 +20,19 @@ use std::sync::Arc; use std::{any::Any, time::Instant}; -use crate::{ - display::DisplayableExecutionPlan, DisplayFormatType, ExecutionPlan, Partitioning, - Statistics, -}; -use arrow::{array::StringBuilder, datatypes::SchemaRef, record_batch::RecordBatch}; -use datafusion_common::{internal_err, DataFusionError, Result}; -use futures::StreamExt; - use super::expressions::PhysicalSortExpr; use super::stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter}; use super::{DisplayAs, Distribution, SendableRecordBatchStream}; + +use crate::display::DisplayableExecutionPlan; +use crate::{DisplayFormatType, ExecutionPlan, Partitioning}; + +use arrow::{array::StringBuilder, datatypes::SchemaRef, record_batch::RecordBatch}; +use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_execution::TaskContext; +use futures::StreamExt; + /// `EXPLAIN ANALYZE` execution plan operator. This operator runs its input, /// discards the results, and then prints out an annotated plan with metrics #[derive(Debug, Clone)] @@ -115,8 +115,12 @@ impl ExecutionPlan for AnalyzeExec { /// Specifies whether this plan generates an infinite stream of records. /// If the plan does not support pipelining, but its input(s) are /// infinite, returns an error to indicate this. - fn unbounded_output(&self, _children: &[bool]) -> Result { - internal_err!("Optimization not supported for ANALYZE") + fn unbounded_output(&self, children: &[bool]) -> Result { + if children[0] { + internal_err!("Streaming execution of AnalyzeExec is not possible") + } else { + Ok(false) + } } /// Get the output partitioning of this plan @@ -195,11 +199,6 @@ impl ExecutionPlan for AnalyzeExec { futures::stream::once(output), ))) } - - fn statistics(&self) -> Statistics { - // Statistics an an ANALYZE plan are not relevant - Statistics::default() - } } /// Creates the ouput of AnalyzeExec as a RecordBatch diff --git a/datafusion/physical-plan/src/coalesce_batches.rs b/datafusion/physical-plan/src/coalesce_batches.rs index f46a228064fe..09d1ea87ca37 100644 --- a/datafusion/physical-plan/src/coalesce_batches.rs +++ b/datafusion/physical-plan/src/coalesce_batches.rs @@ -23,9 +23,12 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; +use super::expressions::PhysicalSortExpr; +use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; +use super::{DisplayAs, Statistics}; use crate::{ - DisplayFormatType, EquivalenceProperties, ExecutionPlan, Partitioning, - RecordBatchStream, SendableRecordBatchStream, + DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, + SendableRecordBatchStream, }; use arrow::datatypes::SchemaRef; @@ -33,12 +36,7 @@ use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; use datafusion_common::Result; use datafusion_execution::TaskContext; -use datafusion_physical_expr::OrderingEquivalenceProperties; - -use super::expressions::PhysicalSortExpr; -use super::metrics::{BaselineMetrics, MetricsSet}; -use super::DisplayAs; -use super::{metrics::ExecutionPlanMetricsSet, Statistics}; +use datafusion_physical_expr::EquivalenceProperties; use futures::stream::{Stream, StreamExt}; use log::trace; @@ -140,10 +138,6 @@ impl ExecutionPlan for CoalesceBatchesExec { self.input.equivalence_properties() } - fn ordering_equivalence_properties(&self) -> OrderingEquivalenceProperties { - self.input.ordering_equivalence_properties() - } - fn with_new_children( self: Arc, children: Vec>, @@ -174,7 +168,7 @@ impl ExecutionPlan for CoalesceBatchesExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Statistics { + fn statistics(&self) -> Result { self.input.statistics() } } @@ -230,17 +224,17 @@ impl CoalesceBatchesStream { let _timer = cloned_time.timer(); match input_batch { Poll::Ready(x) => match x { - Some(Ok(ref batch)) => { + Some(Ok(batch)) => { if batch.num_rows() >= self.target_batch_size && self.buffer.is_empty() { - return Poll::Ready(Some(Ok(batch.clone()))); + return Poll::Ready(Some(Ok(batch))); } else if batch.num_rows() == 0 { // discard empty batches } else { // add to the buffered batches - self.buffer.push(batch.clone()); self.buffered_rows += batch.num_rows(); + self.buffer.push(batch); // check to see if we have enough batches yet if self.buffered_rows >= self.target_batch_size { // combine the batches and return @@ -302,14 +296,14 @@ pub fn concat_batches( batches.len(), row_count ); - let b = arrow::compute::concat_batches(schema, batches)?; - Ok(b) + arrow::compute::concat_batches(schema, batches) } #[cfg(test)] mod tests { use super::*; use crate::{memory::MemoryExec, repartition::RepartitionExec}; + use arrow::datatypes::{DataType, Field, Schema}; use arrow_array::UInt32Array; diff --git a/datafusion/physical-plan/src/coalesce_partitions.rs b/datafusion/physical-plan/src/coalesce_partitions.rs index 8eddf57ae551..bfcff2853538 100644 --- a/datafusion/physical-plan/src/coalesce_partitions.rs +++ b/datafusion/physical-plan/src/coalesce_partitions.rs @@ -26,11 +26,12 @@ use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use super::stream::{ObservedStream, RecordBatchReceiverStream}; use super::{DisplayAs, SendableRecordBatchStream, Statistics}; -use crate::{DisplayFormatType, EquivalenceProperties, ExecutionPlan, Partitioning}; +use crate::{DisplayFormatType, ExecutionPlan, Partitioning}; use arrow::datatypes::SchemaRef; use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_execution::TaskContext; +use datafusion_physical_expr::EquivalenceProperties; /// Merge execution plan executes partitions in parallel and combines them into a single /// partition. No guarantees are made about the order of the resulting partition. @@ -101,7 +102,14 @@ impl ExecutionPlan for CoalescePartitionsExec { } fn equivalence_properties(&self) -> EquivalenceProperties { - self.input.equivalence_properties() + let mut output_eq = self.input.equivalence_properties(); + // Coalesce partitions loses existing orderings. + output_eq.clear_orderings(); + output_eq + } + + fn benefits_from_input_partitioning(&self) -> Vec { + vec![false] } fn with_new_children( @@ -159,7 +167,7 @@ impl ExecutionPlan for CoalescePartitionsExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Statistics { + fn statistics(&self) -> Result { self.input.statistics() } } diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index c6cfbbfbbac7..e83dc2525b9f 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -17,24 +17,29 @@ //! Defines common code used in execution plans +use std::fs; +use std::fs::{metadata, File}; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use std::task::{Context, Poll}; + use super::SendableRecordBatchStream; use crate::stream::RecordBatchReceiverStream; use crate::{ColumnStatistics, ExecutionPlan, Statistics}; + use arrow::datatypes::Schema; use arrow::ipc::writer::{FileWriter, IpcWriteOptions}; use arrow::record_batch::RecordBatch; +use arrow_array::Array; +use datafusion_common::stats::Precision; use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_execution::memory_pool::MemoryReservation; use datafusion_physical_expr::expressions::{BinaryExpr, Column}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; + use futures::{Future, StreamExt, TryStreamExt}; use parking_lot::Mutex; use pin_project_lite::pin_project; -use std::fs; -use std::fs::{metadata, File}; -use std::path::{Path, PathBuf}; -use std::sync::Arc; -use std::task::{Context, Poll}; use tokio::task::JoinHandle; /// [`MemoryReservation`] used across query execution streams @@ -135,33 +140,37 @@ pub fn compute_record_batch_statistics( ) -> Statistics { let nb_rows = batches.iter().flatten().map(RecordBatch::num_rows).sum(); - let total_byte_size = batches - .iter() - .flatten() - .map(|b| b.get_array_memory_size()) - .sum(); - let projection = match projection { Some(p) => p, None => (0..schema.fields().len()).collect(), }; - let mut column_statistics = vec![ColumnStatistics::default(); projection.len()]; + let total_byte_size = batches + .iter() + .flatten() + .map(|b| { + projection + .iter() + .map(|index| b.column(*index).get_array_memory_size()) + .sum::() + }) + .sum(); + + let mut column_statistics = vec![ColumnStatistics::new_unknown(); projection.len()]; for partition in batches.iter() { for batch in partition { for (stat_index, col_index) in projection.iter().enumerate() { - *column_statistics[stat_index].null_count.get_or_insert(0) += - batch.column(*col_index).null_count(); + column_statistics[stat_index].null_count = + Precision::Exact(batch.column(*col_index).null_count()); } } } Statistics { - num_rows: Some(nb_rows), - total_byte_size: Some(total_byte_size), - column_statistics: Some(column_statistics), - is_exact: true, + num_rows: Precision::Exact(nb_rows), + total_byte_size: Precision::Exact(total_byte_size), + column_statistics, } } @@ -378,12 +387,14 @@ mod tests { use crate::memory::MemoryExec; use crate::sorts::sort::SortExec; use crate::union::UnionExec; + use arrow::compute::SortOptions; use arrow::{ array::{Float32Array, Float64Array}, datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, }; + use arrow_array::UInt64Array; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{col, Column}; @@ -671,9 +682,8 @@ mod tests { ])); let stats = compute_record_batch_statistics(&[], &schema, Some(vec![0, 1])); - assert_eq!(stats.num_rows, Some(0)); - assert!(stats.is_exact); - assert_eq!(stats.total_byte_size, Some(0)); + assert_eq!(stats.num_rows, Precision::Exact(0)); + assert_eq!(stats.total_byte_size, Precision::Exact(0)); Ok(()) } @@ -682,40 +692,46 @@ mod tests { let schema = Arc::new(Schema::new(vec![ Field::new("f32", DataType::Float32, false), Field::new("f64", DataType::Float64, false), + Field::new("u64", DataType::UInt64, false), ])); let batch = RecordBatch::try_new( Arc::clone(&schema), vec![ Arc::new(Float32Array::from(vec![1., 2., 3.])), Arc::new(Float64Array::from(vec![9., 8., 7.])), + Arc::new(UInt64Array::from(vec![4, 5, 6])), ], )?; + + // just select f32,f64 + let select_projection = Some(vec![0, 1]); + let byte_size = batch + .project(&select_projection.clone().unwrap()) + .unwrap() + .get_array_memory_size(); + let actual = - compute_record_batch_statistics(&[vec![batch]], &schema, Some(vec![0, 1])); + compute_record_batch_statistics(&[vec![batch]], &schema, select_projection); - let mut expected = Statistics { - is_exact: true, - num_rows: Some(3), - total_byte_size: Some(464), // this might change a bit if the way we compute the size changes - column_statistics: Some(vec![ + let expected = Statistics { + num_rows: Precision::Exact(3), + total_byte_size: Precision::Exact(byte_size), + column_statistics: vec![ ColumnStatistics { - distinct_count: None, - max_value: None, - min_value: None, - null_count: Some(0), + distinct_count: Precision::Absent, + max_value: Precision::Absent, + min_value: Precision::Absent, + null_count: Precision::Exact(0), }, ColumnStatistics { - distinct_count: None, - max_value: None, - min_value: None, - null_count: Some(0), + distinct_count: Precision::Absent, + max_value: Precision::Absent, + min_value: Precision::Absent, + null_count: Precision::Exact(0), }, - ]), + ], }; - // Prevent test flakiness due to undefined / changing implementation details - expected.total_byte_size = actual.total_byte_size; - assert_eq!(actual, expected); Ok(()) } diff --git a/datafusion/physical-plan/src/display.rs b/datafusion/physical-plan/src/display.rs index e4a4e113eb07..19c2847b09dc 100644 --- a/datafusion/physical-plan/src/display.rs +++ b/datafusion/physical-plan/src/display.rs @@ -20,13 +20,12 @@ use std::fmt; +use super::{accept, ExecutionPlan, ExecutionPlanVisitor}; + use arrow_schema::SchemaRef; -use datafusion_common::display::StringifiedPlan; +use datafusion_common::display::{GraphvizBuilder, PlanType, StringifiedPlan}; use datafusion_physical_expr::PhysicalSortExpr; -use super::{accept, ExecutionPlan, ExecutionPlanVisitor}; -use datafusion_common::display::{GraphvizBuilder, PlanType}; - /// Options for controlling how each [`ExecutionPlan`] should format itself #[derive(Debug, Clone, Copy)] pub enum DisplayFormatType { @@ -133,7 +132,7 @@ impl<'a> DisplayableExecutionPlan<'a> { /// ```dot /// strict digraph dot_plan { // 0[label="ProjectionExec: expr=[id@0 + 2 as employee.id + Int32(2)]",tooltip=""] - // 1[label="EmptyExec: produce_one_row=false",tooltip=""] + // 1[label="EmptyExec",tooltip=""] // 0 -> 1 // } /// ``` @@ -262,7 +261,8 @@ impl<'a, 'b> ExecutionPlanVisitor for IndentVisitor<'a, 'b> { } } if self.show_statistics { - write!(self.f, ", statistics=[{}]", plan.statistics())?; + let stats = plan.statistics().map_err(|_e| fmt::Error)?; + write!(self.f, ", statistics=[{}]", stats)?; } writeln!(self.f)?; self.indent += 1; @@ -342,7 +342,8 @@ impl ExecutionPlanVisitor for GraphvizVisitor<'_, '_> { }; let statistics = if self.show_statistics { - format!("statistics=[{}]", plan.statistics()) + let stats = plan.statistics().map_err(|_e| fmt::Error)?; + format!("statistics=[{}]", stats) } else { "".to_string() }; @@ -435,3 +436,126 @@ impl<'a> fmt::Display for OutputOrderingDisplay<'a> { write!(f, "]") } } + +#[cfg(test)] +mod tests { + use std::fmt::Write; + use std::sync::Arc; + + use datafusion_common::DataFusionError; + + use crate::{DisplayAs, ExecutionPlan}; + + use super::DisplayableExecutionPlan; + + #[derive(Debug, Clone, Copy)] + enum TestStatsExecPlan { + Panic, + Error, + Ok, + } + + impl DisplayAs for TestStatsExecPlan { + fn fmt_as( + &self, + _t: crate::DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + write!(f, "TestStatsExecPlan") + } + } + + impl ExecutionPlan for TestStatsExecPlan { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> arrow_schema::SchemaRef { + Arc::new(arrow_schema::Schema::empty()) + } + + fn output_partitioning(&self) -> datafusion_physical_expr::Partitioning { + datafusion_physical_expr::Partitioning::UnknownPartitioning(1) + } + + fn output_ordering( + &self, + ) -> Option<&[datafusion_physical_expr::PhysicalSortExpr]> { + None + } + + fn children(&self) -> Vec> { + vec![] + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> datafusion_common::Result> { + unimplemented!() + } + + fn execute( + &self, + _: usize, + _: Arc, + ) -> datafusion_common::Result + { + todo!() + } + + fn statistics(&self) -> datafusion_common::Result { + match self { + Self::Panic => panic!("expected panic"), + Self::Error => { + Err(DataFusionError::Internal("expected error".to_string())) + } + Self::Ok => Ok(datafusion_common::Statistics::new_unknown( + self.schema().as_ref(), + )), + } + } + } + + fn test_stats_display(exec: TestStatsExecPlan, show_stats: bool) { + let display = + DisplayableExecutionPlan::new(&exec).set_show_statistics(show_stats); + + let mut buf = String::new(); + write!(&mut buf, "{}", display.one_line()).unwrap(); + let buf = buf.trim(); + assert_eq!(buf, "TestStatsExecPlan"); + } + + #[test] + fn test_display_when_stats_panic_with_no_show_stats() { + test_stats_display(TestStatsExecPlan::Panic, false); + } + + #[test] + fn test_display_when_stats_error_with_no_show_stats() { + test_stats_display(TestStatsExecPlan::Error, false); + } + + #[test] + fn test_display_when_stats_ok_with_no_show_stats() { + test_stats_display(TestStatsExecPlan::Ok, false); + } + + #[test] + #[should_panic(expected = "expected panic")] + fn test_display_when_stats_panic_with_show_stats() { + test_stats_display(TestStatsExecPlan::Panic, true); + } + + #[test] + #[should_panic(expected = "Error")] // fmt::Error + fn test_display_when_stats_error_with_show_stats() { + test_stats_display(TestStatsExecPlan::Error, true); + } + + #[test] + fn test_display_when_stats_ok_with_show_stats() { + test_stats_display(TestStatsExecPlan::Ok, false); + } +} diff --git a/datafusion/physical-plan/src/empty.rs b/datafusion/physical-plan/src/empty.rs index 675dac9ad265..41c8dbed1453 100644 --- a/datafusion/physical-plan/src/empty.rs +++ b/datafusion/physical-plan/src/empty.rs @@ -15,28 +15,25 @@ // specific language governing permissions and limitations // under the License. -//! EmptyRelation execution plan +//! EmptyRelation with produce_one_row=false execution plan use std::any::Any; use std::sync::Arc; -use crate::{memory::MemoryStream, DisplayFormatType, ExecutionPlan, Partitioning}; -use arrow::array::{ArrayRef, NullArray}; -use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; -use arrow::record_batch::RecordBatch; -use datafusion_common::{internal_err, DataFusionError, Result}; -use log::trace; - use super::expressions::PhysicalSortExpr; use super::{common, DisplayAs, SendableRecordBatchStream, Statistics}; +use crate::{memory::MemoryStream, DisplayFormatType, ExecutionPlan, Partitioning}; +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; +use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_execution::TaskContext; -/// Execution plan for empty relation (produces no rows) +use log::trace; + +/// Execution plan for empty relation with produce_one_row=false #[derive(Debug)] pub struct EmptyExec { - /// Specifies whether this exec produces a row or not - produce_one_row: bool, /// The schema for the produced row schema: SchemaRef, /// Number of partitions @@ -45,9 +42,8 @@ pub struct EmptyExec { impl EmptyExec { /// Create a new EmptyExec - pub fn new(produce_one_row: bool, schema: SchemaRef) -> Self { + pub fn new(schema: SchemaRef) -> Self { EmptyExec { - produce_one_row, schema, partitions: 1, } @@ -59,36 +55,8 @@ impl EmptyExec { self } - /// Specifies whether this exec produces a row or not - pub fn produce_one_row(&self) -> bool { - self.produce_one_row - } - fn data(&self) -> Result> { - let batch = if self.produce_one_row { - let n_field = self.schema.fields.len(); - // hack for https://github.com/apache/arrow-datafusion/pull/3242 - let n_field = if n_field == 0 { 1 } else { n_field }; - vec![RecordBatch::try_new( - Arc::new(Schema::new( - (0..n_field) - .map(|i| { - Field::new(format!("placeholder_{i}"), DataType::Null, true) - }) - .collect::(), - )), - (0..n_field) - .map(|_i| { - let ret: ArrayRef = Arc::new(NullArray::new(1)); - ret - }) - .collect(), - )?] - } else { - vec![] - }; - - Ok(batch) + Ok(vec![]) } } @@ -100,7 +68,7 @@ impl DisplayAs for EmptyExec { ) -> std::fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!(f, "EmptyExec: produce_one_row={}", self.produce_one_row) + write!(f, "EmptyExec") } } } @@ -133,10 +101,7 @@ impl ExecutionPlan for EmptyExec { self: Arc, _: Vec>, ) -> Result> { - Ok(Arc::new(EmptyExec::new( - self.produce_one_row, - self.schema.clone(), - ))) + Ok(Arc::new(EmptyExec::new(self.schema.clone()))) } fn execute( @@ -161,11 +126,15 @@ impl ExecutionPlan for EmptyExec { )?)) } - fn statistics(&self) -> Statistics { + fn statistics(&self) -> Result { let batch = self .data() .expect("Create empty RecordBatch should not fail"); - common::compute_record_batch_statistics(&[batch], &self.schema, None) + Ok(common::compute_record_batch_statistics( + &[batch], + &self.schema, + None, + )) } } @@ -180,7 +149,7 @@ mod tests { let task_ctx = Arc::new(TaskContext::default()); let schema = test::aggr_test_schema(); - let empty = EmptyExec::new(false, schema.clone()); + let empty = EmptyExec::new(schema.clone()); assert_eq!(empty.schema(), schema); // we should have no results @@ -194,16 +163,11 @@ mod tests { #[test] fn with_new_children() -> Result<()> { let schema = test::aggr_test_schema(); - let empty = Arc::new(EmptyExec::new(false, schema.clone())); - let empty_with_row = Arc::new(EmptyExec::new(true, schema)); + let empty = Arc::new(EmptyExec::new(schema.clone())); let empty2 = with_new_children_if_necessary(empty.clone(), vec![])?.into(); assert_eq!(empty.schema(), empty2.schema()); - let empty_with_row_2 = - with_new_children_if_necessary(empty_with_row.clone(), vec![])?.into(); - assert_eq!(empty_with_row.schema(), empty_with_row_2.schema()); - let too_many_kids = vec![empty2]; assert!( with_new_children_if_necessary(empty, too_many_kids).is_err(), @@ -216,44 +180,11 @@ mod tests { async fn invalid_execute() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); let schema = test::aggr_test_schema(); - let empty = EmptyExec::new(false, schema); + let empty = EmptyExec::new(schema); // ask for the wrong partition assert!(empty.execute(1, task_ctx.clone()).is_err()); assert!(empty.execute(20, task_ctx).is_err()); Ok(()) } - - #[tokio::test] - async fn produce_one_row() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); - let schema = test::aggr_test_schema(); - let empty = EmptyExec::new(true, schema); - - let iter = empty.execute(0, task_ctx)?; - let batches = common::collect(iter).await?; - - // should have one item - assert_eq!(batches.len(), 1); - - Ok(()) - } - - #[tokio::test] - async fn produce_one_row_multiple_partition() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); - let schema = test::aggr_test_schema(); - let partitions = 3; - let empty = EmptyExec::new(true, schema).with_partitions(partitions); - - for n in 0..partitions { - let iter = empty.execute(n, task_ctx.clone())?; - let batches = common::collect(iter).await?; - - // should have one item - assert_eq!(batches.len(), 1); - } - - Ok(()) - } } diff --git a/datafusion/physical-plan/src/explain.rs b/datafusion/physical-plan/src/explain.rs index 8d6bf4105f6a..e4904ddd3410 100644 --- a/datafusion/physical-plan/src/explain.rs +++ b/datafusion/physical-plan/src/explain.rs @@ -20,19 +20,18 @@ use std::any::Any; use std::sync::Arc; -use datafusion_common::display::StringifiedPlan; +use super::expressions::PhysicalSortExpr; +use super::{DisplayAs, SendableRecordBatchStream}; +use crate::stream::RecordBatchStreamAdapter; +use crate::{DisplayFormatType, ExecutionPlan, Partitioning}; +use arrow::{array::StringBuilder, datatypes::SchemaRef, record_batch::RecordBatch}; +use datafusion_common::display::StringifiedPlan; use datafusion_common::{internal_err, DataFusionError, Result}; +use datafusion_execution::TaskContext; -use crate::{DisplayFormatType, ExecutionPlan, Partitioning, Statistics}; -use arrow::{array::StringBuilder, datatypes::SchemaRef, record_batch::RecordBatch}; use log::trace; -use super::DisplayAs; -use super::{expressions::PhysicalSortExpr, SendableRecordBatchStream}; -use crate::stream::RecordBatchStreamAdapter; -use datafusion_execution::TaskContext; - /// Explain execution plan operator. This operator contains the string /// values of the various plans it has when it is created, and passes /// them to its output. @@ -168,11 +167,6 @@ impl ExecutionPlan for ExplainExec { futures::stream::iter(vec![Ok(record_batch)]), ))) } - - fn statistics(&self) -> Statistics { - // Statistics an EXPLAIN plan are not relevant - Statistics::default() - } } /// If this plan should be shown, given the previous plan that was diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index 4a8b18914411..56a1b4e17821 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -27,27 +27,27 @@ use super::expressions::PhysicalSortExpr; use super::{ ColumnStatistics, DisplayAs, RecordBatchStream, SendableRecordBatchStream, Statistics, }; - use crate::{ metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, - Column, DisplayFormatType, EquivalenceProperties, ExecutionPlan, Partitioning, + Column, DisplayFormatType, ExecutionPlan, Partitioning, }; use arrow::compute::filter_record_batch; use arrow::datatypes::{DataType, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::cast::as_boolean_array; +use datafusion_common::stats::Precision; use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_execution::TaskContext; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::BinaryExpr; +use datafusion_physical_expr::intervals::utils::check_support; +use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{ - analyze, split_conjunction, AnalysisContext, ExprBoundaries, - OrderingEquivalenceProperties, PhysicalExpr, + analyze, split_conjunction, AnalysisContext, EquivalenceProperties, ExprBoundaries, + PhysicalExpr, }; -use datafusion_physical_expr::intervals::utils::check_support; -use datafusion_physical_expr::utils::collect_columns; use futures::stream::{Stream, StreamExt}; use log::trace; @@ -61,6 +61,8 @@ pub struct FilterExec { input: Arc, /// Execution metrics metrics: ExecutionPlanMetricsSet, + /// Selectivity for statistics. 0 = no rows, 100 all rows + default_selectivity: u8, } impl FilterExec { @@ -74,6 +76,7 @@ impl FilterExec { predicate, input: input.clone(), metrics: ExecutionPlanMetricsSet::new(), + default_selectivity: 20, }), other => { plan_err!("Filter predicate must return boolean values, not {other:?}") @@ -81,6 +84,17 @@ impl FilterExec { } } + pub fn with_default_selectivity( + mut self, + default_selectivity: u8, + ) -> Result { + if default_selectivity > 100 { + return plan_err!("Default flter selectivity needs to be less than 100"); + } + self.default_selectivity = default_selectivity; + Ok(self) + } + /// The expression to filter on. This expression must evaluate to a boolean value. pub fn predicate(&self) -> &Arc { &self.predicate @@ -90,6 +104,11 @@ impl FilterExec { pub fn input(&self) -> &Arc { &self.input } + + /// The default selectivity + pub fn default_selectivity(&self) -> u8 { + self.default_selectivity + } } impl DisplayAs for FilterExec { @@ -144,39 +163,33 @@ impl ExecutionPlan for FilterExec { } fn equivalence_properties(&self) -> EquivalenceProperties { + let stats = self.statistics().unwrap(); // Combine the equal predicates with the input equivalence properties - let mut input_properties = self.input.equivalence_properties(); - let (equal_pairs, _ne_pairs) = collect_columns_from_predicate(&self.predicate); - for new_condition in equal_pairs { - input_properties.add_equal_conditions(new_condition) + let mut result = self.input.equivalence_properties(); + let (equal_pairs, _) = collect_columns_from_predicate(&self.predicate); + for (lhs, rhs) in equal_pairs { + let lhs_expr = Arc::new(lhs.clone()) as _; + let rhs_expr = Arc::new(rhs.clone()) as _; + result.add_equal_conditions(&lhs_expr, &rhs_expr) } - input_properties - } - - fn ordering_equivalence_properties(&self) -> OrderingEquivalenceProperties { - let stats = self.statistics(); // Add the columns that have only one value (singleton) after filtering to constants. - if let Some(col_stats) = stats.column_statistics { - let constants = collect_columns(self.predicate()) - .into_iter() - .filter(|column| col_stats[column.index()].is_singleton()) - .map(|column| Arc::new(column) as Arc) - .collect::>(); - let filter_oeq = self.input.ordering_equivalence_properties(); - filter_oeq.with_constants(constants) - } else { - self.input.ordering_equivalence_properties() - } + let constants = collect_columns(self.predicate()) + .into_iter() + .filter(|column| stats.column_statistics[column.index()].is_singleton()) + .map(|column| Arc::new(column) as _); + result.add_constants(constants) } fn with_new_children( self: Arc, - children: Vec>, + mut children: Vec>, ) -> Result> { - Ok(Arc::new(FilterExec::try_new( - self.predicate.clone(), - children[0].clone(), - )?)) + FilterExec::try_new(self.predicate.clone(), children.swap_remove(0)) + .and_then(|e| { + let selectivity = e.default_selectivity(); + e.with_default_selectivity(selectivity) + }) + .map(|e| Arc::new(e) as _) } fn execute( @@ -200,56 +213,44 @@ impl ExecutionPlan for FilterExec { /// The output statistics of a filtering operation can be estimated if the /// predicate's selectivity value can be determined for the incoming data. - fn statistics(&self) -> Statistics { + fn statistics(&self) -> Result { let predicate = self.predicate(); - if !check_support(predicate) { - return Statistics::default(); + let input_stats = self.input.statistics()?; + let schema = self.schema(); + if !check_support(predicate, &schema) { + let selectivity = self.default_selectivity as f64 / 100.0; + let mut stats = input_stats.into_inexact(); + stats.num_rows = stats.num_rows.with_estimated_selectivity(selectivity); + stats.total_byte_size = stats + .total_byte_size + .with_estimated_selectivity(selectivity); + return Ok(stats); } - let input_stats = self.input.statistics(); - let input_column_stats = match input_stats.column_statistics { - Some(stats) => stats, - None => self - .schema() - .fields - .iter() - .map(|field| { - ColumnStatistics::new_with_unbounded_column(field.data_type()) - }) - .collect::>(), - }; - - let starter_ctx = - AnalysisContext::from_statistics(&self.input.schema(), &input_column_stats); + let num_rows = input_stats.num_rows; + let total_byte_size = input_stats.total_byte_size; + let input_analysis_ctx = AnalysisContext::try_from_statistics( + &self.input.schema(), + &input_stats.column_statistics, + )?; - let analysis_ctx = match analyze(predicate, starter_ctx) { - Ok(ctx) => ctx, - Err(_) => return Statistics::default(), - }; + let analysis_ctx = analyze(predicate, input_analysis_ctx, &self.schema())?; + // Estimate (inexact) selectivity of predicate let selectivity = analysis_ctx.selectivity.unwrap_or(1.0); + let num_rows = num_rows.with_estimated_selectivity(selectivity); + let total_byte_size = total_byte_size.with_estimated_selectivity(selectivity); - let num_rows = input_stats - .num_rows - .map(|num| (num as f64 * selectivity).ceil() as usize); - let total_byte_size = input_stats - .total_byte_size - .map(|size| (size as f64 * selectivity).ceil() as usize); - - let column_statistics = if let Some(analysis_boundaries) = analysis_ctx.boundaries - { - collect_new_statistics(input_column_stats, selectivity, analysis_boundaries) - } else { - input_column_stats - }; - - Statistics { + let column_statistics = collect_new_statistics( + &input_stats.column_statistics, + analysis_ctx.boundaries, + ); + Ok(Statistics { num_rows, total_byte_size, - column_statistics: Some(column_statistics), - is_exact: Default::default(), - } + column_statistics, + }) } } @@ -258,11 +259,9 @@ impl ExecutionPlan for FilterExec { /// is adjusted by using the next/previous value for its data type to convert /// it into a closed bound. fn collect_new_statistics( - input_column_stats: Vec, - selectivity: f64, + input_column_stats: &[ColumnStatistics], analysis_boundaries: Vec, ) -> Vec { - let nonempty_columns = selectivity > 0.0; analysis_boundaries .into_iter() .enumerate() @@ -275,12 +274,17 @@ fn collect_new_statistics( .. }, )| { - let closed_interval = interval.close_bounds(); + let (lower, upper) = interval.into_bounds(); + let (min_value, max_value) = if lower.eq(&upper) { + (Precision::Exact(lower), Precision::Exact(upper)) + } else { + (Precision::Inexact(lower), Precision::Inexact(upper)) + }; ColumnStatistics { - null_count: input_column_stats[idx].null_count, - max_value: nonempty_columns.then_some(closed_interval.upper.value), - min_value: nonempty_columns.then_some(closed_interval.lower.value), - distinct_count, + null_count: input_column_stats[idx].null_count.clone().to_inexact(), + max_value, + min_value, + distinct_count: distinct_count.to_inexact(), } }, ) @@ -306,7 +310,7 @@ pub(crate) fn batch_filter( ) -> Result { predicate .evaluate(batch) - .map(|v| v.into_array(batch.num_rows())) + .and_then(|v| v.into_array(batch.num_rows())) .and_then(|array| { Ok(as_boolean_array(&array)?) // apply filter array to record batch @@ -364,17 +368,16 @@ impl RecordBatchStream for FilterExecStream { /// Return the equals Column-Pairs and Non-equals Column-Pairs fn collect_columns_from_predicate(predicate: &Arc) -> EqualAndNonEqual { - let mut eq_predicate_columns: Vec<(&Column, &Column)> = Vec::new(); - let mut ne_predicate_columns: Vec<(&Column, &Column)> = Vec::new(); + let mut eq_predicate_columns = Vec::<(&Column, &Column)>::new(); + let mut ne_predicate_columns = Vec::<(&Column, &Column)>::new(); let predicates = split_conjunction(predicate); predicates.into_iter().for_each(|p| { if let Some(binary) = p.as_any().downcast_ref::() { - let left = binary.left(); - let right = binary.right(); - if left.as_any().is::() && right.as_any().is::() { - let left_column = left.as_any().downcast_ref::().unwrap(); - let right_column = right.as_any().downcast_ref::().unwrap(); + if let (Some(left_column), Some(right_column)) = ( + binary.left().as_any().downcast_ref::(), + binary.right().as_any().downcast_ref::(), + ) { match binary.op() { Operator::Eq => { eq_predicate_columns.push((left_column, right_column)) @@ -396,18 +399,18 @@ pub type EqualAndNonEqual<'a> = #[cfg(test)] mod tests { + use std::iter::Iterator; + use std::sync::Arc; use super::*; use crate::expressions::*; use crate::test; use crate::test::exec::StatisticsExec; use crate::ExecutionPlan; + use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::ColumnStatistics; - use datafusion_common::ScalarValue; + use datafusion_common::{ColumnStatistics, ScalarValue}; use datafusion_expr::Operator; - use std::iter::Iterator; - use std::sync::Arc; #[tokio::test] async fn collect_columns_predicates() -> Result<()> { @@ -460,14 +463,13 @@ mod tests { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); let input = Arc::new(StatisticsExec::new( Statistics { - num_rows: Some(100), - total_byte_size: Some(100 * bytes_per_row), - column_statistics: Some(vec![ColumnStatistics { - min_value: Some(ScalarValue::Int32(Some(1))), - max_value: Some(ScalarValue::Int32(Some(100))), + num_rows: Precision::Inexact(100), + total_byte_size: Precision::Inexact(100 * bytes_per_row), + column_statistics: vec![ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), ..Default::default() - }]), - ..Default::default() + }], }, schema.clone(), )); @@ -480,16 +482,19 @@ mod tests { let filter: Arc = Arc::new(FilterExec::try_new(predicate, input)?); - let statistics = filter.statistics(); - assert_eq!(statistics.num_rows, Some(25)); - assert_eq!(statistics.total_byte_size, Some(25 * bytes_per_row)); + let statistics = filter.statistics()?; + assert_eq!(statistics.num_rows, Precision::Inexact(25)); + assert_eq!( + statistics.total_byte_size, + Precision::Inexact(25 * bytes_per_row) + ); assert_eq!( statistics.column_statistics, - Some(vec![ColumnStatistics { - min_value: Some(ScalarValue::Int32(Some(1))), - max_value: Some(ScalarValue::Int32(Some(25))), + vec![ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(25))), ..Default::default() - }]) + }] ); Ok(()) @@ -502,13 +507,13 @@ mod tests { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); let input = Arc::new(StatisticsExec::new( Statistics { - num_rows: Some(100), - column_statistics: Some(vec![ColumnStatistics { - min_value: Some(ScalarValue::Int32(Some(1))), - max_value: Some(ScalarValue::Int32(Some(100))), + num_rows: Precision::Inexact(100), + column_statistics: vec![ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), ..Default::default() - }]), - ..Default::default() + }], + total_byte_size: Precision::Absent, }, schema.clone(), )); @@ -527,15 +532,15 @@ mod tests { sub_filter, )?); - let statistics = filter.statistics(); - assert_eq!(statistics.num_rows, Some(16)); + let statistics = filter.statistics()?; + assert_eq!(statistics.num_rows, Precision::Inexact(16)); assert_eq!( statistics.column_statistics, - Some(vec![ColumnStatistics { - min_value: Some(ScalarValue::Int32(Some(10))), - max_value: Some(ScalarValue::Int32(Some(25))), + vec![ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(10))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(25))), ..Default::default() - }]) + }] ); Ok(()) @@ -552,20 +557,20 @@ mod tests { ]); let input = Arc::new(StatisticsExec::new( Statistics { - num_rows: Some(100), - column_statistics: Some(vec![ + num_rows: Precision::Inexact(100), + column_statistics: vec![ ColumnStatistics { - min_value: Some(ScalarValue::Int32(Some(1))), - max_value: Some(ScalarValue::Int32(Some(100))), + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), ..Default::default() }, ColumnStatistics { - min_value: Some(ScalarValue::Int32(Some(1))), - max_value: Some(ScalarValue::Int32(Some(50))), + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(50))), ..Default::default() }, - ]), - ..Default::default() + ], + total_byte_size: Precision::Absent, }, schema.clone(), )); @@ -587,28 +592,28 @@ mod tests { binary(col("a", &schema)?, Operator::GtEq, lit(10i32), &schema)?, b_gt_5, )?); - let statistics = filter.statistics(); + let statistics = filter.statistics()?; // On a uniform distribution, only fifteen rows will satisfy the // filter that 'a' proposed (a >= 10 AND a <= 25) (15/100) and only // 5 rows will satisfy the filter that 'b' proposed (b > 45) (5/50). // // Which would result with a selectivity of '15/100 * 5/50' or 0.015 // and that means about %1.5 of the all rows (rounded up to 2 rows). - assert_eq!(statistics.num_rows, Some(2)); + assert_eq!(statistics.num_rows, Precision::Inexact(2)); assert_eq!( statistics.column_statistics, - Some(vec![ + vec![ ColumnStatistics { - min_value: Some(ScalarValue::Int32(Some(10))), - max_value: Some(ScalarValue::Int32(Some(25))), + min_value: Precision::Inexact(ScalarValue::Int32(Some(10))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(25))), ..Default::default() }, ColumnStatistics { - min_value: Some(ScalarValue::Int32(Some(46))), - max_value: Some(ScalarValue::Int32(Some(50))), + min_value: Precision::Inexact(ScalarValue::Int32(Some(46))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(50))), ..Default::default() } - ]) + ] ); Ok(()) @@ -620,12 +625,7 @@ mod tests { // a: min=???, max=??? (missing) let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); let input = Arc::new(StatisticsExec::new( - Statistics { - column_statistics: Some(vec![ColumnStatistics { - ..Default::default() - }]), - ..Default::default() - }, + Statistics::new_unknown(&schema), schema.clone(), )); @@ -637,8 +637,8 @@ mod tests { let filter: Arc = Arc::new(FilterExec::try_new(predicate, input)?); - let statistics = filter.statistics(); - assert_eq!(statistics.num_rows, None); + let statistics = filter.statistics()?; + assert_eq!(statistics.num_rows, Precision::Absent); Ok(()) } @@ -656,26 +656,25 @@ mod tests { ]); let input = Arc::new(StatisticsExec::new( Statistics { - num_rows: Some(1000), - total_byte_size: Some(4000), - column_statistics: Some(vec![ + num_rows: Precision::Inexact(1000), + total_byte_size: Precision::Inexact(4000), + column_statistics: vec![ ColumnStatistics { - min_value: Some(ScalarValue::Int32(Some(1))), - max_value: Some(ScalarValue::Int32(Some(100))), + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), ..Default::default() }, ColumnStatistics { - min_value: Some(ScalarValue::Int32(Some(1))), - max_value: Some(ScalarValue::Int32(Some(3))), + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(3))), ..Default::default() }, ColumnStatistics { - min_value: Some(ScalarValue::Float32(Some(1000.0))), - max_value: Some(ScalarValue::Float32(Some(1100.0))), + min_value: Precision::Inexact(ScalarValue::Float32(Some(1000.0))), + max_value: Precision::Inexact(ScalarValue::Float32(Some(1100.0))), ..Default::default() }, - ]), - ..Default::default() + ], }, schema, )); @@ -711,47 +710,51 @@ mod tests { )); let filter: Arc = Arc::new(FilterExec::try_new(predicate, input)?); - let statistics = filter.statistics(); + let statistics = filter.statistics()?; // 0.5 (from a) * 0.333333... (from b) * 0.798387... (from c) ≈ 0.1330... // num_rows after ceil => 133.0... => 134 // total_byte_size after ceil => 532.0... => 533 - assert_eq!(statistics.num_rows, Some(134)); - assert_eq!(statistics.total_byte_size, Some(533)); + assert_eq!(statistics.num_rows, Precision::Inexact(134)); + assert_eq!(statistics.total_byte_size, Precision::Inexact(533)); let exp_col_stats = vec![ ColumnStatistics { - min_value: Some(ScalarValue::Int32(Some(4))), - max_value: Some(ScalarValue::Int32(Some(53))), + min_value: Precision::Inexact(ScalarValue::Int32(Some(4))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(53))), ..Default::default() }, ColumnStatistics { - min_value: Some(ScalarValue::Int32(Some(3))), - max_value: Some(ScalarValue::Int32(Some(3))), + min_value: Precision::Inexact(ScalarValue::Int32(Some(3))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(3))), ..Default::default() }, ColumnStatistics { - min_value: Some(ScalarValue::Float32(Some(1000.0))), - max_value: Some(ScalarValue::Float32(Some(1075.0))), + min_value: Precision::Inexact(ScalarValue::Float32(Some(1000.0))), + max_value: Precision::Inexact(ScalarValue::Float32(Some(1075.0))), ..Default::default() }, ]; let _ = exp_col_stats .into_iter() - .zip(statistics.column_statistics.unwrap()) + .zip(statistics.column_statistics) .map(|(expected, actual)| { - if actual.min_value.clone().unwrap().data_type().is_floating() { - // Windows rounds arithmetic operation results differently for floating point numbers. - // Therefore, we check if the actual values are in an epsilon range. - let actual_min = actual.min_value.unwrap(); - let actual_max = actual.max_value.unwrap(); - let expected_min = expected.min_value.unwrap(); - let expected_max = expected.max_value.unwrap(); - let eps = ScalarValue::Float32(Some(1e-6)); - - assert!(actual_min.sub(&expected_min).unwrap() < eps); - assert!(actual_min.sub(&expected_min).unwrap() < eps); - - assert!(actual_max.sub(&expected_max).unwrap() < eps); - assert!(actual_max.sub(&expected_max).unwrap() < eps); + if let Some(val) = actual.min_value.get_value() { + if val.data_type().is_floating() { + // Windows rounds arithmetic operation results differently for floating point numbers. + // Therefore, we check if the actual values are in an epsilon range. + let actual_min = actual.min_value.get_value().unwrap(); + let actual_max = actual.max_value.get_value().unwrap(); + let expected_min = expected.min_value.get_value().unwrap(); + let expected_max = expected.max_value.get_value().unwrap(); + let eps = ScalarValue::Float32(Some(1e-6)); + + assert!(actual_min.sub(expected_min).unwrap() < eps); + assert!(actual_min.sub(expected_min).unwrap() < eps); + + assert!(actual_max.sub(expected_max).unwrap() < eps); + assert!(actual_max.sub(expected_max).unwrap() < eps); + } else { + assert_eq!(actual, expected); + } } else { assert_eq!(actual, expected); } @@ -771,21 +774,20 @@ mod tests { ]); let input = Arc::new(StatisticsExec::new( Statistics { - num_rows: Some(1000), - total_byte_size: Some(4000), - column_statistics: Some(vec![ + num_rows: Precision::Inexact(1000), + total_byte_size: Precision::Inexact(4000), + column_statistics: vec![ ColumnStatistics { - min_value: Some(ScalarValue::Int32(Some(1))), - max_value: Some(ScalarValue::Int32(Some(100))), + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), ..Default::default() }, ColumnStatistics { - min_value: Some(ScalarValue::Int32(Some(1))), - max_value: Some(ScalarValue::Int32(Some(3))), + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(3))), ..Default::default() }, - ]), - ..Default::default() + ], }, schema, )); @@ -804,13 +806,13 @@ mod tests { )), )); // Since filter predicate passes all entries, statistics after filter shouldn't change. - let expected = input.statistics().column_statistics; + let expected = input.statistics()?.column_statistics; let filter: Arc = Arc::new(FilterExec::try_new(predicate, input)?); - let statistics = filter.statistics(); + let statistics = filter.statistics()?; - assert_eq!(statistics.num_rows, Some(1000)); - assert_eq!(statistics.total_byte_size, Some(4000)); + assert_eq!(statistics.num_rows, Precision::Inexact(1000)); + assert_eq!(statistics.total_byte_size, Precision::Inexact(4000)); assert_eq!(statistics.column_statistics, expected); Ok(()) @@ -827,21 +829,20 @@ mod tests { ]); let input = Arc::new(StatisticsExec::new( Statistics { - num_rows: Some(1000), - total_byte_size: Some(4000), - column_statistics: Some(vec![ + num_rows: Precision::Inexact(1000), + total_byte_size: Precision::Inexact(4000), + column_statistics: vec![ ColumnStatistics { - min_value: Some(ScalarValue::Int32(Some(1))), - max_value: Some(ScalarValue::Int32(Some(100))), + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), ..Default::default() }, ColumnStatistics { - min_value: Some(ScalarValue::Int32(Some(1))), - max_value: Some(ScalarValue::Int32(Some(3))), + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(3))), ..Default::default() }, - ]), - ..Default::default() + ], }, schema, )); @@ -861,24 +862,24 @@ mod tests { )); let filter: Arc = Arc::new(FilterExec::try_new(predicate, input)?); - let statistics = filter.statistics(); + let statistics = filter.statistics()?; - assert_eq!(statistics.num_rows, Some(0)); - assert_eq!(statistics.total_byte_size, Some(0)); + assert_eq!(statistics.num_rows, Precision::Inexact(0)); + assert_eq!(statistics.total_byte_size, Precision::Inexact(0)); assert_eq!( statistics.column_statistics, - Some(vec![ + vec![ ColumnStatistics { - min_value: None, - max_value: None, + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), ..Default::default() }, ColumnStatistics { - min_value: None, - max_value: None, + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(3))), ..Default::default() }, - ]) + ] ); Ok(()) @@ -892,21 +893,20 @@ mod tests { ]); let input = Arc::new(StatisticsExec::new( Statistics { - num_rows: Some(1000), - total_byte_size: Some(4000), - column_statistics: Some(vec![ + num_rows: Precision::Inexact(1000), + total_byte_size: Precision::Inexact(4000), + column_statistics: vec![ ColumnStatistics { - min_value: Some(ScalarValue::Int32(Some(1))), - max_value: Some(ScalarValue::Int32(Some(100))), + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), ..Default::default() }, ColumnStatistics { - min_value: Some(ScalarValue::Int32(Some(1))), - max_value: Some(ScalarValue::Int32(Some(100))), + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), ..Default::default() }, - ]), - ..Default::default() + ], }, schema, )); @@ -918,26 +918,143 @@ mod tests { )); let filter: Arc = Arc::new(FilterExec::try_new(predicate, input)?); - let statistics = filter.statistics(); + let statistics = filter.statistics()?; - assert_eq!(statistics.num_rows, Some(490)); - assert_eq!(statistics.total_byte_size, Some(1960)); + assert_eq!(statistics.num_rows, Precision::Inexact(490)); + assert_eq!(statistics.total_byte_size, Precision::Inexact(1960)); assert_eq!( statistics.column_statistics, - Some(vec![ + vec![ ColumnStatistics { - min_value: Some(ScalarValue::Int32(Some(1))), - max_value: Some(ScalarValue::Int32(Some(49))), + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(49))), ..Default::default() }, ColumnStatistics { - min_value: Some(ScalarValue::Int32(Some(1))), - max_value: Some(ScalarValue::Int32(Some(100))), + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), ..Default::default() }, - ]) + ] ); Ok(()) } + + #[tokio::test] + async fn test_empty_input_statistics() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let input = Arc::new(StatisticsExec::new( + Statistics::new_unknown(&schema), + schema, + )); + // WHERE a <= 10 AND 0 <= a - 5 + let predicate = Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::LtEq, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )), + Operator::And, + Arc::new(BinaryExpr::new( + Arc::new(Literal::new(ScalarValue::Int32(Some(0)))), + Operator::LtEq, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Minus, + Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + )), + )), + )); + let filter: Arc = + Arc::new(FilterExec::try_new(predicate, input)?); + let filter_statistics = filter.statistics()?; + + let expected_filter_statistics = Statistics { + num_rows: Precision::Absent, + total_byte_size: Precision::Absent, + column_statistics: vec![ColumnStatistics { + null_count: Precision::Absent, + min_value: Precision::Inexact(ScalarValue::Int32(Some(5))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(10))), + distinct_count: Precision::Absent, + }], + }; + + assert_eq!(filter_statistics, expected_filter_statistics); + + Ok(()) + } + + #[tokio::test] + async fn test_statistics_with_constant_column() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let input = Arc::new(StatisticsExec::new( + Statistics::new_unknown(&schema), + schema, + )); + // WHERE a = 10 + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )); + let filter: Arc = + Arc::new(FilterExec::try_new(predicate, input)?); + let filter_statistics = filter.statistics()?; + // First column is "a", and it is a column with only one value after the filter. + assert!(filter_statistics.column_statistics[0].is_singleton()); + + Ok(()) + } + + #[tokio::test] + async fn test_validation_filter_selectivity() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let input = Arc::new(StatisticsExec::new( + Statistics::new_unknown(&schema), + schema, + )); + // WHERE a = 10 + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )); + let filter = FilterExec::try_new(predicate, input)?; + assert!(filter.with_default_selectivity(120).is_err()); + Ok(()) + } + + #[tokio::test] + async fn test_custom_filter_selectivity() -> Result<()> { + // Need a decimal to trigger inexact selectivity + let schema = + Schema::new(vec![Field::new("a", DataType::Decimal128(2, 3), false)]); + let input = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Inexact(1000), + total_byte_size: Precision::Inexact(4000), + column_statistics: vec![ColumnStatistics { + ..Default::default() + }], + }, + schema, + )); + // WHERE a = 10 + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Decimal128(Some(10), 10, 10))), + )); + let filter = FilterExec::try_new(predicate, input)?; + let statistics = filter.statistics()?; + assert_eq!(statistics.num_rows, Precision::Inexact(200)); + assert_eq!(statistics.total_byte_size, Precision::Inexact(800)); + let filter = filter.with_default_selectivity(40)?; + let statistics = filter.statistics()?; + assert_eq!(statistics.num_rows, Precision::Inexact(400)); + assert_eq!(statistics.total_byte_size, Precision::Inexact(1600)); + Ok(()) + } } diff --git a/datafusion/physical-plan/src/insert.rs b/datafusion/physical-plan/src/insert.rs index 8b467461ddad..81cdfd753fe6 100644 --- a/datafusion/physical-plan/src/insert.rs +++ b/datafusion/physical-plan/src/insert.rs @@ -17,27 +17,28 @@ //! Execution plan for writing data to [`DataSink`]s +use std::any::Any; +use std::fmt; +use std::fmt::Debug; +use std::sync::Arc; + use super::expressions::PhysicalSortExpr; use super::{ DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, - Statistics, }; +use crate::metrics::MetricsSet; +use crate::stream::RecordBatchStreamAdapter; + use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use arrow_array::{ArrayRef, UInt64Array}; use arrow_schema::{DataType, Field, Schema}; +use datafusion_common::{exec_err, internal_err, DataFusionError, Result}; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::{Distribution, PhysicalSortRequirement}; + use async_trait::async_trait; -use core::fmt; -use datafusion_common::Result; -use datafusion_physical_expr::PhysicalSortRequirement; use futures::StreamExt; -use std::any::Any; -use std::fmt::Debug; -use std::sync::Arc; - -use crate::stream::RecordBatchStreamAdapter; -use datafusion_common::{exec_err, internal_err, DataFusionError}; -use datafusion_execution::TaskContext; /// `DataSink` implements writing streams of [`RecordBatch`]es to /// user defined destinations. @@ -46,6 +47,16 @@ use datafusion_execution::TaskContext; /// output. #[async_trait] pub trait DataSink: DisplayAs + Debug + Send + Sync { + /// Returns the data sink as [`Any`](std::any::Any) so that it can be + /// downcast to a specific implementation. + fn as_any(&self) -> &dyn Any; + + /// Return a snapshot of the [MetricsSet] for this + /// [DataSink]. + /// + /// See [ExecutionPlan::metrics()] for more details + fn metrics(&self) -> Option; + // TODO add desired input ordering // How does this sink want its input ordered? @@ -56,7 +67,7 @@ pub trait DataSink: DisplayAs + Debug + Send + Sync { /// or rollback required. async fn write_all( &self, - data: Vec, + data: SendableRecordBatchStream, context: &Arc, ) -> Result; } @@ -73,6 +84,8 @@ pub struct FileSinkExec { sink_schema: SchemaRef, /// Schema describing the structure of the output data. count_schema: SchemaRef, + /// Optional required sort order for output data. + sort_order: Option>, } impl fmt::Debug for FileSinkExec { @@ -87,12 +100,14 @@ impl FileSinkExec { input: Arc, sink: Arc, sink_schema: SchemaRef, + sort_order: Option>, ) -> Self { Self { input, sink, sink_schema, count_schema: make_count_schema(), + sort_order, } } @@ -136,16 +151,24 @@ impl FileSinkExec { } } - fn execute_all_input_streams( - &self, - context: Arc, - ) -> Result> { - let n_input_parts = self.input.output_partitioning().partition_count(); - let mut streams = Vec::with_capacity(n_input_parts); - for part in 0..n_input_parts { - streams.push(self.execute_input_stream(part, context.clone())?); - } - Ok(streams) + /// Input execution plan + pub fn input(&self) -> &Arc { + &self.input + } + + /// Returns insert sink + pub fn sink(&self) -> &dyn DataSink { + self.sink.as_ref() + } + + /// Optional sort order for output data + pub fn sort_order(&self) -> &Option> { + &self.sort_order + } + + /// Returns the metrics of the underlying [DataSink] + pub fn metrics(&self) -> Option { + self.sink.metrics() } } @@ -157,7 +180,7 @@ impl DisplayAs for FileSinkExec { ) -> std::fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!(f, "InsertExec: sink=")?; + write!(f, "FileSinkExec: sink=")?; self.sink.fmt_as(t, f) } } @@ -184,28 +207,29 @@ impl ExecutionPlan for FileSinkExec { } fn benefits_from_input_partitioning(&self) -> Vec { - // Incoming number of partitions is taken to be the - // number of files the query is required to write out. - // The optimizer should not change this number. - // Parrallelism is handled within the appropriate DataSink + // DataSink is responsible for dynamically partitioning its + // own input at execution time. vec![false] } + fn required_input_distribution(&self) -> Vec { + // DataSink is responsible for dynamically partitioning its + // own input at execution time, and so requires a single input partition. + vec![Distribution::SinglePartition; self.children().len()] + } + fn required_input_ordering(&self) -> Vec>> { - // Require that the InsertExec gets the data in the order the - // input produced it (otherwise the optimizer may chose to reorder - // the input which could result in unintended / poor UX) - // - // More rationale: - // https://github.com/apache/arrow-datafusion/pull/6354#discussion_r1195284178 - vec![self - .input - .output_ordering() - .map(PhysicalSortRequirement::from_sort_exprs)] + // The required input ordering is set externally (e.g. by a `ListingTable`). + // Otherwise, there is no specific requirement (i.e. `sort_expr` is `None`). + vec![self.sort_order.as_ref().cloned()] } fn maintains_input_order(&self) -> Vec { - vec![false] + // Maintains ordering in the sense that the written file will reflect + // the ordering of the input. For more context, see: + // + // https://github.com/apache/arrow-datafusion/pull/6354#discussion_r1195284178 + vec![true] } fn children(&self) -> Vec> { @@ -221,6 +245,7 @@ impl ExecutionPlan for FileSinkExec { sink: self.sink.clone(), sink_schema: self.sink_schema.clone(), count_schema: self.count_schema.clone(), + sort_order: self.sort_order.clone(), })) } @@ -238,7 +263,7 @@ impl ExecutionPlan for FileSinkExec { if partition != 0 { return internal_err!("FileSinkExec can only be called on partition 0!"); } - let data = self.execute_all_input_streams(context.clone())?; + let data = self.execute_input_stream(0, context.clone())?; let count_schema = self.count_schema.clone(); let sink = self.sink.clone(); @@ -253,10 +278,6 @@ impl ExecutionPlan for FileSinkExec { stream, ))) } - - fn statistics(&self) -> Statistics { - Statistics::default() - } } /// Create a output record batch with a count diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index 4ba29524b3e2..938c9e4d343d 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -18,31 +18,31 @@ //! Defines the cross join plan for loading the left side of the cross join //! and producing batches in parallel for the right partitions -use futures::{ready, StreamExt}; -use futures::{Stream, TryStreamExt}; use std::{any::Any, sync::Arc, task::Poll}; -use arrow::datatypes::{Fields, Schema, SchemaRef}; -use arrow::record_batch::RecordBatch; - +use super::utils::{ + adjust_right_output_partitioning, BuildProbeJoinMetrics, OnceAsync, OnceFut, +}; use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::DisplayAs; use crate::{ coalesce_batches::concat_batches, coalesce_partitions::CoalescePartitionsExec, - ColumnStatistics, DisplayFormatType, Distribution, EquivalenceProperties, - ExecutionPlan, Partitioning, PhysicalSortExpr, RecordBatchStream, - SendableRecordBatchStream, Statistics, + ColumnStatistics, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, + PhysicalSortExpr, RecordBatchStream, SendableRecordBatchStream, Statistics, }; -use async_trait::async_trait; -use datafusion_common::{plan_err, DataFusionError}; -use datafusion_common::{Result, ScalarValue}; + +use arrow::datatypes::{Fields, Schema, SchemaRef}; +use arrow::record_batch::RecordBatch; +use arrow_array::RecordBatchOptions; +use datafusion_common::stats::Precision; +use datafusion_common::{plan_err, DataFusionError, JoinType, Result, ScalarValue}; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; +use datafusion_physical_expr::equivalence::join_equivalence_properties; +use datafusion_physical_expr::EquivalenceProperties; -use super::utils::{ - adjust_right_output_partitioning, cross_join_equivalence_properties, - BuildProbeJoinMetrics, OnceAsync, OnceFut, -}; +use async_trait::async_trait; +use futures::{ready, Stream, StreamExt, TryStreamExt}; /// Data of the left side type JoinLeftData = (RecordBatch, MemoryReservation); @@ -105,12 +105,11 @@ async fn load_left_input( reservation: MemoryReservation, ) -> Result { // merge all left parts into a single stream - let merge = { - if left.output_partitioning().partition_count() != 1 { - Arc::new(CoalescePartitionsExec::new(left.clone())) - } else { - left.clone() - } + let left_schema = left.schema(); + let merge = if left.output_partitioning().partition_count() != 1 { + Arc::new(CoalescePartitionsExec::new(left)) + } else { + left }; let stream = merge.execute(0, context)?; @@ -135,7 +134,7 @@ async fn load_left_input( ) .await?; - let merged_batch = concat_batches(&left.schema(), &batches, num_rows)?; + let merged_batch = concat_batches(&left_schema, &batches, num_rows)?; Ok((merged_batch, reservation)) } @@ -216,12 +215,14 @@ impl ExecutionPlan for CrossJoinExec { } fn equivalence_properties(&self) -> EquivalenceProperties { - let left_columns_len = self.left.schema().fields.len(); - cross_join_equivalence_properties( + join_equivalence_properties( self.left.equivalence_properties(), self.right.equivalence_properties(), - left_columns_len, + &JoinType::Full, self.schema(), + &[false, false], + None, + &[], ) } @@ -257,77 +258,55 @@ impl ExecutionPlan for CrossJoinExec { })) } - fn statistics(&self) -> Statistics { - stats_cartesian_product( - self.left.statistics(), - self.left.schema().fields().len(), - self.right.statistics(), - self.right.schema().fields().len(), - ) + fn statistics(&self) -> Result { + Ok(stats_cartesian_product( + self.left.statistics()?, + self.right.statistics()?, + )) } } /// [left/right]_col_count are required in case the column statistics are None fn stats_cartesian_product( left_stats: Statistics, - left_col_count: usize, right_stats: Statistics, - right_col_count: usize, ) -> Statistics { let left_row_count = left_stats.num_rows; let right_row_count = right_stats.num_rows; // calculate global stats - let is_exact = left_stats.is_exact && right_stats.is_exact; - let num_rows = left_stats - .num_rows - .zip(right_stats.num_rows) - .map(|(a, b)| a * b); + let num_rows = left_row_count.multiply(&right_row_count); // the result size is two times a*b because you have the columns of both left and right let total_byte_size = left_stats .total_byte_size - .zip(right_stats.total_byte_size) - .map(|(a, b)| 2 * a * b); - - // calculate column stats - let column_statistics = - // complete the column statistics if they are missing only on one side - match (left_stats.column_statistics, right_stats.column_statistics) { - (None, None) => None, - (None, Some(right_col_stat)) => Some(( - vec![ColumnStatistics::default(); left_col_count], - right_col_stat, - )), - (Some(left_col_stat), None) => Some(( - left_col_stat, - vec![ColumnStatistics::default(); right_col_count], - )), - (Some(left_col_stat), Some(right_col_stat)) => { - Some((left_col_stat, right_col_stat)) - } - } - .map(|(left_col_stats, right_col_stats)| { - // the null counts must be multiplied by the row counts of the other side (if defined) - // Min, max and distinct_count on the other hand are invariants. - left_col_stats.into_iter().map(|s| ColumnStatistics{ - null_count: s.null_count.zip(right_row_count).map(|(a, b)| a * b), - distinct_count: s.distinct_count, - min_value: s.min_value, - max_value: s.max_value, - }).chain( - right_col_stats.into_iter().map(|s| ColumnStatistics{ - null_count: s.null_count.zip(left_row_count).map(|(a, b)| a * b), - distinct_count: s.distinct_count, - min_value: s.min_value, - max_value: s.max_value, - })).collect() - }); + .multiply(&right_stats.total_byte_size) + .multiply(&Precision::Exact(2)); + + let left_col_stats = left_stats.column_statistics; + let right_col_stats = right_stats.column_statistics; + + // the null counts must be multiplied by the row counts of the other side (if defined) + // Min, max and distinct_count on the other hand are invariants. + let cross_join_stats = left_col_stats + .into_iter() + .map(|s| ColumnStatistics { + null_count: s.null_count.multiply(&right_row_count), + distinct_count: s.distinct_count, + min_value: s.min_value, + max_value: s.max_value, + }) + .chain(right_col_stats.into_iter().map(|s| ColumnStatistics { + null_count: s.null_count.multiply(&left_row_count), + distinct_count: s.distinct_count, + min_value: s.min_value, + max_value: s.max_value, + })) + .collect(); Statistics { - is_exact, num_rows, total_byte_size, - column_statistics, + column_statistics: cross_join_stats, } } @@ -365,17 +344,18 @@ fn build_batch( .iter() .map(|arr| { let scalar = ScalarValue::try_from_array(arr, left_index)?; - Ok(scalar.to_array_of_size(batch.num_rows())) + scalar.to_array_of_size(batch.num_rows()) }) .collect::>>()?; - RecordBatch::try_new( + RecordBatch::try_new_with_options( Arc::new(schema.clone()), arrays .iter() .chain(batch.columns().iter()) .cloned() .collect(), + &RecordBatchOptions::new().with_row_count(Some(batch.num_rows())), ) .map_err(Into::into) } @@ -459,6 +439,7 @@ mod tests { use super::*; use crate::common; use crate::test::build_table_scan_i32; + use datafusion_common::{assert_batches_sorted_eq, assert_contains}; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; @@ -484,63 +465,60 @@ mod tests { let right_bytes = 27; let left = Statistics { - is_exact: true, - num_rows: Some(left_row_count), - total_byte_size: Some(left_bytes), - column_statistics: Some(vec![ + num_rows: Precision::Exact(left_row_count), + total_byte_size: Precision::Exact(left_bytes), + column_statistics: vec![ ColumnStatistics { - distinct_count: Some(5), - max_value: Some(ScalarValue::Int64(Some(21))), - min_value: Some(ScalarValue::Int64(Some(-4))), - null_count: Some(0), + distinct_count: Precision::Exact(5), + max_value: Precision::Exact(ScalarValue::Int64(Some(21))), + min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), + null_count: Precision::Exact(0), }, ColumnStatistics { - distinct_count: Some(1), - max_value: Some(ScalarValue::Utf8(Some(String::from("x")))), - min_value: Some(ScalarValue::Utf8(Some(String::from("a")))), - null_count: Some(3), + distinct_count: Precision::Exact(1), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), + null_count: Precision::Exact(3), }, - ]), + ], }; let right = Statistics { - is_exact: true, - num_rows: Some(right_row_count), - total_byte_size: Some(right_bytes), - column_statistics: Some(vec![ColumnStatistics { - distinct_count: Some(3), - max_value: Some(ScalarValue::Int64(Some(12))), - min_value: Some(ScalarValue::Int64(Some(0))), - null_count: Some(2), - }]), + num_rows: Precision::Exact(right_row_count), + total_byte_size: Precision::Exact(right_bytes), + column_statistics: vec![ColumnStatistics { + distinct_count: Precision::Exact(3), + max_value: Precision::Exact(ScalarValue::Int64(Some(12))), + min_value: Precision::Exact(ScalarValue::Int64(Some(0))), + null_count: Precision::Exact(2), + }], }; - let result = stats_cartesian_product(left, 3, right, 2); + let result = stats_cartesian_product(left, right); let expected = Statistics { - is_exact: true, - num_rows: Some(left_row_count * right_row_count), - total_byte_size: Some(2 * left_bytes * right_bytes), - column_statistics: Some(vec![ + num_rows: Precision::Exact(left_row_count * right_row_count), + total_byte_size: Precision::Exact(2 * left_bytes * right_bytes), + column_statistics: vec![ ColumnStatistics { - distinct_count: Some(5), - max_value: Some(ScalarValue::Int64(Some(21))), - min_value: Some(ScalarValue::Int64(Some(-4))), - null_count: Some(0), + distinct_count: Precision::Exact(5), + max_value: Precision::Exact(ScalarValue::Int64(Some(21))), + min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), + null_count: Precision::Exact(0), }, ColumnStatistics { - distinct_count: Some(1), - max_value: Some(ScalarValue::Utf8(Some(String::from("x")))), - min_value: Some(ScalarValue::Utf8(Some(String::from("a")))), - null_count: Some(3 * right_row_count), + distinct_count: Precision::Exact(1), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), + null_count: Precision::Exact(3 * right_row_count), }, ColumnStatistics { - distinct_count: Some(3), - max_value: Some(ScalarValue::Int64(Some(12))), - min_value: Some(ScalarValue::Int64(Some(0))), - null_count: Some(2 * left_row_count), + distinct_count: Precision::Exact(3), + max_value: Precision::Exact(ScalarValue::Int64(Some(12))), + min_value: Precision::Exact(ScalarValue::Int64(Some(0))), + null_count: Precision::Exact(2 * left_row_count), }, - ]), + ], }; assert_eq!(result, expected); @@ -551,63 +529,60 @@ mod tests { let left_row_count = 11; let left = Statistics { - is_exact: true, - num_rows: Some(left_row_count), - total_byte_size: Some(23), - column_statistics: Some(vec![ + num_rows: Precision::Exact(left_row_count), + total_byte_size: Precision::Exact(23), + column_statistics: vec![ ColumnStatistics { - distinct_count: Some(5), - max_value: Some(ScalarValue::Int64(Some(21))), - min_value: Some(ScalarValue::Int64(Some(-4))), - null_count: Some(0), + distinct_count: Precision::Exact(5), + max_value: Precision::Exact(ScalarValue::Int64(Some(21))), + min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), + null_count: Precision::Exact(0), }, ColumnStatistics { - distinct_count: Some(1), - max_value: Some(ScalarValue::Utf8(Some(String::from("x")))), - min_value: Some(ScalarValue::Utf8(Some(String::from("a")))), - null_count: Some(3), + distinct_count: Precision::Exact(1), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), + null_count: Precision::Exact(3), }, - ]), + ], }; let right = Statistics { - is_exact: true, - num_rows: None, // not defined! - total_byte_size: None, // not defined! - column_statistics: Some(vec![ColumnStatistics { - distinct_count: Some(3), - max_value: Some(ScalarValue::Int64(Some(12))), - min_value: Some(ScalarValue::Int64(Some(0))), - null_count: Some(2), - }]), + num_rows: Precision::Absent, + total_byte_size: Precision::Absent, + column_statistics: vec![ColumnStatistics { + distinct_count: Precision::Exact(3), + max_value: Precision::Exact(ScalarValue::Int64(Some(12))), + min_value: Precision::Exact(ScalarValue::Int64(Some(0))), + null_count: Precision::Exact(2), + }], }; - let result = stats_cartesian_product(left, 3, right, 2); + let result = stats_cartesian_product(left, right); let expected = Statistics { - is_exact: true, - num_rows: None, - total_byte_size: None, - column_statistics: Some(vec![ + num_rows: Precision::Absent, + total_byte_size: Precision::Absent, + column_statistics: vec![ ColumnStatistics { - distinct_count: Some(5), - max_value: Some(ScalarValue::Int64(Some(21))), - min_value: Some(ScalarValue::Int64(Some(-4))), - null_count: None, // we don't know the row count on the right + distinct_count: Precision::Exact(5), + max_value: Precision::Exact(ScalarValue::Int64(Some(21))), + min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), + null_count: Precision::Absent, // we don't know the row count on the right }, ColumnStatistics { - distinct_count: Some(1), - max_value: Some(ScalarValue::Utf8(Some(String::from("x")))), - min_value: Some(ScalarValue::Utf8(Some(String::from("a")))), - null_count: None, // we don't know the row count on the right + distinct_count: Precision::Exact(1), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), + null_count: Precision::Absent, // we don't know the row count on the right }, ColumnStatistics { - distinct_count: Some(3), - max_value: Some(ScalarValue::Int64(Some(12))), - min_value: Some(ScalarValue::Int64(Some(0))), - null_count: Some(2 * left_row_count), + distinct_count: Precision::Exact(3), + max_value: Precision::Exact(ScalarValue::Int64(Some(12))), + min_value: Precision::Exact(ScalarValue::Int64(Some(0))), + null_count: Precision::Exact(2 * left_row_count), }, - ]), + ], }; assert_eq!(result, expected); diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 8e204634f3d9..374a0ad50700 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -15,8 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Defines the join plan for executing partitions in parallel and then joining the results -//! into a set of partitions. +//! [`HashJoinExec`] Partitioned Hash Join Operator use std::fmt; use std::mem::size_of; @@ -26,27 +25,24 @@ use std::{any::Any, usize, vec}; use crate::joins::utils::{ adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_from_indices, - calculate_join_output_ordering, combine_join_ordering_equivalence_properties, - get_final_indices_from_bit_map, need_produce_result_in_final, JoinSide, + calculate_join_output_ordering, get_final_indices_from_bit_map, + need_produce_result_in_final, JoinHashMap, JoinHashMapType, }; -use crate::DisplayAs; use crate::{ - coalesce_batches::concat_batches, coalesce_partitions::CoalescePartitionsExec, expressions::Column, expressions::PhysicalSortExpr, hash_utils::create_hashes, - joins::hash_join_utils::{JoinHashMap, JoinHashMapType}, joins::utils::{ adjust_right_output_partitioning, build_join_schema, check_join_is_valid, - combine_join_equivalence_properties, estimate_join_statistics, - partitioned_join_output_partitioning, BuildProbeJoinMetrics, ColumnIndex, - JoinFilter, JoinOn, + estimate_join_statistics, partitioned_join_output_partitioning, + BuildProbeJoinMetrics, ColumnIndex, JoinFilter, JoinOn, StatefulStreamResult, }, metrics::{ExecutionPlanMetricsSet, MetricsSet}, - DisplayFormatType, Distribution, EquivalenceProperties, ExecutionPlan, Partitioning, - PhysicalExpr, RecordBatchStream, SendableRecordBatchStream, Statistics, + DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalExpr, + RecordBatchStream, SendableRecordBatchStream, Statistics, }; +use crate::{handle_state, DisplayAs}; use super::{ utils::{OnceAsync, OnceFut}, @@ -55,50 +51,242 @@ use super::{ use arrow::array::{ Array, ArrayRef, BooleanArray, BooleanBufferBuilder, PrimitiveArray, UInt32Array, - UInt32BufferBuilder, UInt64Array, UInt64BufferBuilder, + UInt64Array, }; -use arrow::compute::{and, take, FilterBuilder}; +use arrow::compute::kernels::cmp::{eq, not_distinct}; +use arrow::compute::{and, concat_batches, take, FilterBuilder}; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use arrow::util::bit_util; use arrow_array::cast::downcast_array; use arrow_schema::ArrowError; use datafusion_common::{ - exec_err, internal_err, plan_err, DataFusionError, JoinType, Result, + exec_err, internal_err, plan_err, DataFusionError, JoinSide, JoinType, Result, }; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; -use datafusion_physical_expr::OrderingEquivalenceProperties; +use datafusion_physical_expr::equivalence::join_equivalence_properties; +use datafusion_physical_expr::EquivalenceProperties; use ahash::RandomState; -use arrow::compute::kernels::cmp::{eq, not_distinct}; use futures::{ready, Stream, StreamExt, TryStreamExt}; -type JoinLeftData = (JoinHashMap, RecordBatch, MemoryReservation); +/// HashTable and input data for the left (build side) of a join +struct JoinLeftData { + /// The hash table with indices into `batch` + hash_map: JoinHashMap, + /// The input rows for the build side + batch: RecordBatch, + /// Memory reservation that tracks memory used by `hash_map` hash table + /// `batch`. Cleared on drop. + #[allow(dead_code)] + reservation: MemoryReservation, +} + +impl JoinLeftData { + /// Create a new `JoinLeftData` from its parts + fn new( + hash_map: JoinHashMap, + batch: RecordBatch, + reservation: MemoryReservation, + ) -> Self { + Self { + hash_map, + batch, + reservation, + } + } -/// Join execution plan executes partitions in parallel and combines them into a set of -/// partitions. + /// Returns the number of rows in the build side + fn num_rows(&self) -> usize { + self.batch.num_rows() + } + + /// return a reference to the hash map + fn hash_map(&self) -> &JoinHashMap { + &self.hash_map + } + + /// returns a reference to the build side batch + fn batch(&self) -> &RecordBatch { + &self.batch + } +} + +/// Join execution plan: Evaluates eqijoin predicates in parallel on multiple +/// partitions using a hash table and an optional filter list to apply post +/// join. +/// +/// # Join Expressions +/// +/// This implementation is optimized for evaluating eqijoin predicates ( +/// ` = `) expressions, which are represented as a list of `Columns` +/// in [`Self::on`]. +/// +/// Non-equality predicates, which can not pushed down to a join inputs (e.g. +/// ` != `) are known as "filter expressions" and are evaluated +/// after the equijoin predicates. +/// +/// # "Build Side" vs "Probe Side" +/// +/// HashJoin takes two inputs, which are referred to as the "build" and the +/// "probe". The build side is the first child, and the probe side is the second +/// child. +/// +/// The two inputs are treated differently and it is VERY important that the +/// *smaller* input is placed on the build side to minimize the work of creating +/// the hash table. +/// +/// ```text +/// ┌───────────┐ +/// │ HashJoin │ +/// │ │ +/// └───────────┘ +/// │ │ +/// ┌─────┘ └─────┐ +/// ▼ ▼ +/// ┌────────────┐ ┌─────────────┐ +/// │ Input │ │ Input │ +/// │ [0] │ │ [1] │ +/// └────────────┘ └─────────────┘ +/// +/// "build side" "probe side" +/// ``` +/// +/// Execution proceeds in 2 stages: +/// +/// 1. the **build phase** creates a hash table from the tuples of the build side, +/// and single concatenated batch containing data from all fetched record batches. +/// Resulting hash table stores hashed join-key fields for each row as a key, and +/// indices of corresponding rows in concatenated batch. +/// +/// Hash join uses LIFO data structure as a hash table, and in order to retain +/// original build-side input order while obtaining data during probe phase, hash +/// table is updated by iterating batch sequence in reverse order -- it allows to +/// keep rows with smaller indices "on the top" of hash table, and still maintain +/// correct indexing for concatenated build-side data batch. +/// +/// Example of build phase for 3 record batches: +/// +/// +/// ```text +/// +/// Original build-side data Inserting build-side values into hashmap Concatenated build-side batch +/// ┌───────────────────────────┐ +/// hasmap.insert(row-hash, row-idx + offset) │ idx │ +/// ┌───────┐ │ ┌───────┐ │ +/// │ Row 1 │ 1) update_hash for batch 3 with offset 0 │ │ Row 6 │ 0 │ +/// Batch 1 │ │ - hashmap.insert(Row 7, idx 1) │ Batch 3 │ │ │ +/// │ Row 2 │ - hashmap.insert(Row 6, idx 0) │ │ Row 7 │ 1 │ +/// └───────┘ │ └───────┘ │ +/// │ │ +/// ┌───────┐ │ ┌───────┐ │ +/// │ Row 3 │ 2) update_hash for batch 2 with offset 2 │ │ Row 3 │ 2 │ +/// │ │ - hashmap.insert(Row 5, idx 4) │ │ │ │ +/// Batch 2 │ Row 4 │ - hashmap.insert(Row 4, idx 3) │ Batch 2 │ Row 4 │ 3 │ +/// │ │ - hashmap.insert(Row 3, idx 2) │ │ │ │ +/// │ Row 5 │ │ │ Row 5 │ 4 │ +/// └───────┘ │ └───────┘ │ +/// │ │ +/// ┌───────┐ │ ┌───────┐ │ +/// │ Row 6 │ 3) update_hash for batch 1 with offset 5 │ │ Row 1 │ 5 │ +/// Batch 3 │ │ - hashmap.insert(Row 2, idx 5) │ Batch 1 │ │ │ +/// │ Row 7 │ - hashmap.insert(Row 1, idx 6) │ │ Row 2 │ 6 │ +/// └───────┘ │ └───────┘ │ +/// │ │ +/// └───────────────────────────┘ +/// +/// ``` +/// +/// 2. the **probe phase** where the tuples of the probe side are streamed +/// through, checking for matches of the join keys in the hash table. +/// +/// ```text +/// ┌────────────────┐ ┌────────────────┐ +/// │ ┌─────────┐ │ │ ┌─────────┐ │ +/// │ │ Hash │ │ │ │ Hash │ │ +/// │ │ Table │ │ │ │ Table │ │ +/// │ │(keys are│ │ │ │(keys are│ │ +/// │ │equi join│ │ │ │equi join│ │ Stage 2: batches from +/// Stage 1: the │ │columns) │ │ │ │columns) │ │ the probe side are +/// *entire* build │ │ │ │ │ │ │ │ streamed through, and +/// side is read │ └─────────┘ │ │ └─────────┘ │ checked against the +/// into the hash │ ▲ │ │ ▲ │ contents of the hash +/// table │ HashJoin │ │ HashJoin │ table +/// └──────┼─────────┘ └──────────┼─────┘ +/// ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// │ │ +/// +/// │ │ +/// ┌────────────┐ ┌────────────┐ +/// │RecordBatch │ │RecordBatch │ +/// └────────────┘ └────────────┘ +/// ┌────────────┐ ┌────────────┐ +/// │RecordBatch │ │RecordBatch │ +/// └────────────┘ └────────────┘ +/// ... ... +/// ┌────────────┐ ┌────────────┐ +/// │RecordBatch │ │RecordBatch │ +/// └────────────┘ └────────────┘ +/// +/// build side probe side +/// +/// ``` +/// +/// # Example "Optimal" Plans /// -/// Filter expression expected to contain non-equality predicates that can not be pushed -/// down to any of join inputs. -/// In case of outer join, filter applied to only matched rows. +/// The differences in the inputs means that for classic "Star Schema Query", +/// the optimal plan will be a **"Right Deep Tree"** . A Star Schema Query is +/// one where there is one large table and several smaller "dimension" tables, +/// joined on `Foreign Key = Primary Key` predicates. +/// +/// A "Right Deep Tree" looks like this large table as the probe side on the +/// lowest join: +/// +/// ```text +/// ┌───────────┐ +/// │ HashJoin │ +/// │ │ +/// └───────────┘ +/// │ │ +/// ┌───────┘ └──────────┐ +/// ▼ ▼ +/// ┌───────────────┐ ┌───────────┐ +/// │ small table 1 │ │ HashJoin │ +/// │ "dimension" │ │ │ +/// └───────────────┘ └───┬───┬───┘ +/// ┌──────────┘ └───────┐ +/// │ │ +/// ▼ ▼ +/// ┌───────────────┐ ┌───────────┐ +/// │ small table 2 │ │ HashJoin │ +/// │ "dimension" │ │ │ +/// └───────────────┘ └───┬───┬───┘ +/// ┌────────┘ └────────┐ +/// │ │ +/// ▼ ▼ +/// ┌───────────────┐ ┌───────────────┐ +/// │ small table 3 │ │ large table │ +/// │ "dimension" │ │ "fact" │ +/// └───────────────┘ └───────────────┘ +/// ``` #[derive(Debug)] pub struct HashJoinExec { /// left (build) side which gets hashed pub left: Arc, /// right (probe) side which are filtered by the hash table pub right: Arc, - /// Set of common columns used to join on + /// Set of equijoin columns from the relations: `(left_col, right_col)` pub on: Vec<(Column, Column)>, /// Filters which are applied while finding matching rows pub filter: Option, - /// How the join is performed + /// How the join is performed (`OUTER`, `INNER`, etc) pub join_type: JoinType, - /// The schema once the join is applied + /// The output schema for the join schema: SchemaRef, - /// Build-side data + /// Future that consumes left input and builds the hash table left_fut: OnceAsync, - /// Shares the `RandomState` for the hashing algorithm + /// Shared the `RandomState` for the hashing algorithm random_state: RandomState, /// Output order output_order: Option>, @@ -108,12 +296,16 @@ pub struct HashJoinExec { metrics: ExecutionPlanMetricsSet, /// Information of index and left / right placement of columns column_indices: Vec, - /// If null_equals_null is true, null == null else null != null + /// Null matching behavior: If `null_equals_null` is true, rows that have + /// `null`s in both left and right equijoin columns will be matched. + /// Otherwise, rows that have `null`s in the join columns will not be + /// matched and thus will not appear in the output. pub null_equals_null: bool, } impl HashJoinExec { /// Tries to create a new [HashJoinExec]. + /// /// # Error /// This function errors when it is not possible to join the left and right sides on keys `on`. pub fn try_new( @@ -146,7 +338,7 @@ impl HashJoinExec { left_schema.fields.len(), &Self::maintains_input_order(*join_type), Some(Self::probe_side()), - )?; + ); Ok(HashJoinExec { left, @@ -366,28 +558,15 @@ impl ExecutionPlan for HashJoinExec { } fn equivalence_properties(&self) -> EquivalenceProperties { - let left_columns_len = self.left.schema().fields.len(); - combine_join_equivalence_properties( - self.join_type, + join_equivalence_properties( self.left.equivalence_properties(), self.right.equivalence_properties(), - left_columns_len, - self.on(), - self.schema(), - ) - } - - fn ordering_equivalence_properties(&self) -> OrderingEquivalenceProperties { - combine_join_ordering_equivalence_properties( &self.join_type, - &self.left, - &self.right, self.schema(), &self.maintains_input_order(), Some(Self::probe_side()), - self.equivalence_properties(), + self.on(), ) - .unwrap() } fn children(&self) -> Vec> { @@ -418,6 +597,7 @@ impl ExecutionPlan for HashJoinExec { let on_right = self.on.iter().map(|on| on.1.clone()).collect::>(); let left_partitions = self.left.output_partitioning().partition_count(); let right_partitions = self.right.output_partitioning().partition_count(); + if self.mode == PartitionMode::Partitioned && left_partitions != right_partitions { return internal_err!( @@ -477,15 +657,14 @@ impl ExecutionPlan for HashJoinExec { on_right, filter: self.filter.clone(), join_type: self.join_type, - left_fut, - visited_left_side: None, right: right_stream, column_indices: self.column_indices.clone(), random_state: self.random_state.clone(), join_metrics, null_equals_null: self.null_equals_null, - is_exhausted: false, reservation, + state: HashJoinStreamState::WaitBuildSide, + build_side: BuildSide::Initial(BuildSideInitialState { left_fut }), })) } @@ -493,7 +672,7 @@ impl ExecutionPlan for HashJoinExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Statistics { + fn statistics(&self) -> Result { // TODO stats: it is not possible in general to know the output size of joins // There are some special cases though, for example: // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)` @@ -502,10 +681,13 @@ impl ExecutionPlan for HashJoinExec { self.right.clone(), self.on.clone(), &self.join_type, + &self.schema, ) } } +/// Reads the left (build) side of the input, buffering it in memory, to build a +/// hash table (`LeftJoinData`) async fn collect_left_input( partition: Option, random_state: RandomState, @@ -519,16 +701,10 @@ async fn collect_left_input( let (left_input, left_input_partition) = if let Some(partition) = partition { (left, partition) + } else if left.output_partitioning().partition_count() != 1 { + (Arc::new(CoalescePartitionsExec::new(left)) as _, 0) } else { - let merge = { - if left.output_partitioning().partition_count() != 1 { - Arc::new(CoalescePartitionsExec::new(left)) - } else { - left - } - }; - - (merge, 0) + (left, 0) }; // Depending on partition argument load single partition or whole left side in memory @@ -578,7 +754,10 @@ async fn collect_left_input( let mut hashmap = JoinHashMap::with_capacity(num_rows); let mut hashes_buffer = Vec::new(); let mut offset = 0; - for batch in batches.iter() { + + // Updating hashmap starting from the last batch + let batches_iter = batches.iter().rev(); + for batch in batches_iter.clone() { hashes_buffer.clear(); hashes_buffer.resize(batch.num_rows(), 0); update_hash( @@ -589,18 +768,25 @@ async fn collect_left_input( &random_state, &mut hashes_buffer, 0, + true, )?; offset += batch.num_rows(); } // Merge all batches into a single batch, so we // can directly index into the arrays - let single_batch = concat_batches(&schema, &batches, num_rows)?; + let single_batch = concat_batches(&schema, batches_iter)?; + let data = JoinLeftData::new(hashmap, single_batch, reservation); - Ok((hashmap, single_batch, reservation)) + Ok(data) } -/// Updates `hash` with new entries from [RecordBatch] evaluated against the expressions `on`, -/// assuming that the [RecordBatch] corresponds to the `index`th +/// Updates `hash_map` with new entries from `batch` evaluated against the expressions `on` +/// using `offset` as a start value for `batch` row indices. +/// +/// `fifo_hashmap` sets the order of iteration over `batch` rows while updating hashmap, +/// which allows to keep either first (if set to true) or last (if set to false) row index +/// as a chain head for rows with equal hash values. +#[allow(clippy::too_many_arguments)] pub fn update_hash( on: &[Column], batch: &RecordBatch, @@ -609,6 +795,7 @@ pub fn update_hash( random_state: &RandomState, hashes_buffer: &mut Vec, deleted_offset: usize, + fifo_hashmap: bool, ) -> Result<()> where T: JoinHashMapType, @@ -616,7 +803,7 @@ where // evaluate the keys let keys_values = on .iter() - .map(|c| Ok(c.evaluate(batch)?.into_array(batch.num_rows()))) + .map(|c| c.evaluate(batch)?.into_array(batch.num_rows())) .collect::>>()?; // calculate the hash values @@ -625,53 +812,142 @@ where // For usual JoinHashmap, the implementation is void. hash_map.extend_zero(batch.num_rows()); - // insert hashes to key of the hashmap - let (mut_map, mut_list) = hash_map.get_mut(); - for (row, hash_value) in hash_values.iter().enumerate() { - let item = mut_map.get_mut(*hash_value, |(hash, _)| *hash_value == *hash); - if let Some((_, index)) = item { - // Already exists: add index to next array - let prev_index = *index; - // Store new value inside hashmap - *index = (row + offset + 1) as u64; - // Update chained Vec at row + offset with previous value - mut_list[row + offset - deleted_offset] = prev_index; - } else { - mut_map.insert( - *hash_value, - // store the value + 1 as 0 value reserved for end of list - (*hash_value, (row + offset + 1) as u64), - |(hash, _)| *hash, - ); - // chained list at (row + offset) is already initialized with 0 - // meaning end of list - } + // Updating JoinHashMap from hash values iterator + let hash_values_iter = hash_values + .iter() + .enumerate() + .map(|(i, val)| (i + offset, val)); + + if fifo_hashmap { + hash_map.update_from_iter(hash_values_iter.rev(), deleted_offset); + } else { + hash_map.update_from_iter(hash_values_iter, deleted_offset); } + Ok(()) } -/// A stream that issues [RecordBatch]es as they arrive from the right of the join. +/// Represents build-side of hash join. +enum BuildSide { + /// Indicates that build-side not collected yet + Initial(BuildSideInitialState), + /// Indicates that build-side data has been collected + Ready(BuildSideReadyState), +} + +/// Container for BuildSide::Initial related data +struct BuildSideInitialState { + /// Future for building hash table from build-side input + left_fut: OnceFut, +} + +/// Container for BuildSide::Ready related data +struct BuildSideReadyState { + /// Collected build-side data + left_data: Arc, + /// Which build-side rows have been matched while creating output. + /// For some OUTER joins, we need to know which rows have not been matched + /// to produce the correct output. + visited_left_side: BooleanBufferBuilder, +} + +impl BuildSide { + /// Tries to extract BuildSideInitialState from BuildSide enum. + /// Returns an error if state is not Initial. + fn try_as_initial_mut(&mut self) -> Result<&mut BuildSideInitialState> { + match self { + BuildSide::Initial(state) => Ok(state), + _ => internal_err!("Expected build side in initial state"), + } + } + + /// Tries to extract BuildSideReadyState from BuildSide enum. + /// Returns an error if state is not Ready. + fn try_as_ready(&self) -> Result<&BuildSideReadyState> { + match self { + BuildSide::Ready(state) => Ok(state), + _ => internal_err!("Expected build side in ready state"), + } + } + + /// Tries to extract BuildSideReadyState from BuildSide enum. + /// Returns an error if state is not Ready. + fn try_as_ready_mut(&mut self) -> Result<&mut BuildSideReadyState> { + match self { + BuildSide::Ready(state) => Ok(state), + _ => internal_err!("Expected build side in ready state"), + } + } +} + +/// Represents state of HashJoinStream +/// +/// Expected state transitions performed by HashJoinStream are: +/// +/// ```text +/// +/// WaitBuildSide +/// │ +/// ▼ +/// ┌─► FetchProbeBatch ───► ExhaustedProbeSide ───► Completed +/// │ │ +/// │ ▼ +/// └─ ProcessProbeBatch +/// +/// ``` +enum HashJoinStreamState { + /// Initial state for HashJoinStream indicating that build-side data not collected yet + WaitBuildSide, + /// Indicates that build-side has been collected, and stream is ready for fetching probe-side + FetchProbeBatch, + /// Indicates that non-empty batch has been fetched from probe-side, and is ready to be processed + ProcessProbeBatch(ProcessProbeBatchState), + /// Indicates that probe-side has been fully processed + ExhaustedProbeSide, + /// Indicates that HashJoinStream execution is completed + Completed, +} + +/// Container for HashJoinStreamState::ProcessProbeBatch related data +struct ProcessProbeBatchState { + /// Current probe-side batch + batch: RecordBatch, +} + +impl HashJoinStreamState { + /// Tries to extract ProcessProbeBatchState from HashJoinStreamState enum. + /// Returns an error if state is not ProcessProbeBatchState. + fn try_as_process_probe_batch(&self) -> Result<&ProcessProbeBatchState> { + match self { + HashJoinStreamState::ProcessProbeBatch(state) => Ok(state), + _ => internal_err!("Expected hash join stream in ProcessProbeBatch state"), + } + } +} + +/// [`Stream`] for [`HashJoinExec`] that does the actual join. +/// +/// This stream: +/// +/// 1. Reads the entire left input (build) and constructs a hash table +/// +/// 2. Streams [RecordBatch]es as they arrive from the right input (probe) and joins +/// them with the contents of the hash table struct HashJoinStream { /// Input schema schema: Arc, - /// columns from the left + /// equijoin columns from the left (build side) on_left: Vec, - /// columns from the right used to compute the hash + /// equijoin columns from the right (probe side) on_right: Vec, - /// join filter + /// optional join filter filter: Option, - /// type of the join + /// type of the join (left, right, semi, etc) join_type: JoinType, - /// future for data from left side - left_fut: OnceFut, - /// Keeps track of the left side rows whether they are visited - visited_left_side: Option, - /// right + /// right (probe) input right: SendableRecordBatchStream, /// Random state used for hashing initialization random_state: RandomState, - /// There is nothing to process anymore and left side is processed in case of left join - is_exhausted: bool, /// Metrics join_metrics: BuildProbeJoinMetrics, /// Information of index and left / right placement of columns @@ -680,6 +956,10 @@ struct HashJoinStream { null_equals_null: bool, /// Memory reservation reservation: MemoryReservation, + /// State of the stream + state: HashJoinStreamState, + /// Build side + build_side: BuildSide, } impl RecordBatchStream for HashJoinStream { @@ -688,37 +968,51 @@ impl RecordBatchStream for HashJoinStream { } } -// Returns build/probe indices satisfying the equality condition. -// On LEFT.b1 = RIGHT.b2 -// LEFT Table: -// a1 b1 c1 -// 1 1 10 -// 3 3 30 -// 5 5 50 -// 7 7 70 -// 9 8 90 -// 11 8 110 -// 13 10 130 -// RIGHT Table: -// a2 b2 c2 -// 2 2 20 -// 4 4 40 -// 6 6 60 -// 8 8 80 -// 10 10 100 -// 12 10 120 -// The result is -// "+----+----+-----+----+----+-----+", -// "| a1 | b1 | c1 | a2 | b2 | c2 |", -// "+----+----+-----+----+----+-----+", -// "| 9 | 8 | 90 | 8 | 8 | 80 |", -// "| 11 | 8 | 110 | 8 | 8 | 80 |", -// "| 13 | 10 | 130 | 10 | 10 | 100 |", -// "| 13 | 10 | 130 | 12 | 10 | 120 |", -// "+----+----+-----+----+----+-----+" -// And the result of build and probe indices are: -// Build indices: 4, 5, 6, 6 -// Probe indices: 3, 3, 4, 5 +/// Returns build/probe indices satisfying the equality condition. +/// +/// # Example +/// +/// For `LEFT.b1 = RIGHT.b2`: +/// LEFT (build) Table: +/// ```text +/// a1 b1 c1 +/// 1 1 10 +/// 3 3 30 +/// 5 5 50 +/// 7 7 70 +/// 9 8 90 +/// 11 8 110 +/// 13 10 130 +/// ``` +/// +/// RIGHT (probe) Table: +/// ```text +/// a2 b2 c2 +/// 2 2 20 +/// 4 4 40 +/// 6 6 60 +/// 8 8 80 +/// 10 10 100 +/// 12 10 120 +/// ``` +/// +/// The result is +/// ```text +/// "+----+----+-----+----+----+-----+", +/// "| a1 | b1 | c1 | a2 | b2 | c2 |", +/// "+----+----+-----+----+----+-----+", +/// "| 9 | 8 | 90 | 8 | 8 | 80 |", +/// "| 11 | 8 | 110 | 8 | 8 | 80 |", +/// "| 13 | 10 | 130 | 10 | 10 | 100 |", +/// "| 13 | 10 | 130 | 12 | 10 | 120 |", +/// "+----+----+-----+----+----+-----+" +/// ``` +/// +/// And the result of build and probe indices are: +/// ```text +/// Build indices: 4, 5, 6, 6 +/// Probe indices: 3, 3, 4, 5 +/// ``` #[allow(clippy::too_many_arguments)] pub fn build_equal_condition_join_indices( build_hashmap: &T, @@ -732,25 +1026,25 @@ pub fn build_equal_condition_join_indices( filter: Option<&JoinFilter>, build_side: JoinSide, deleted_offset: Option, + fifo_hashmap: bool, ) -> Result<(UInt64Array, UInt32Array)> { let keys_values = probe_on .iter() - .map(|c| Ok(c.evaluate(probe_batch)?.into_array(probe_batch.num_rows()))) + .map(|c| c.evaluate(probe_batch)?.into_array(probe_batch.num_rows())) .collect::>>()?; let build_join_values = build_on .iter() .map(|c| { - Ok(c.evaluate(build_input_buffer)? - .into_array(build_input_buffer.num_rows())) + c.evaluate(build_input_buffer)? + .into_array(build_input_buffer.num_rows()) }) .collect::>>()?; hashes_buffer.clear(); hashes_buffer.resize(probe_batch.num_rows(), 0); let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?; - // Using a buffer builder to avoid slower normal builder - let mut build_indices = UInt64BufferBuilder::new(0); - let mut probe_indices = UInt32BufferBuilder::new(0); - // The chained list algorithm generates build indices for each probe row in a reversed sequence as such: + + // In case build-side input has not been inverted while JoinHashMap creation, the chained list algorithm + // will return build indices for each probe row in a reverse order as such: // Build Indices: [5, 4, 3] // Probe Indices: [1, 1, 1] // @@ -779,44 +1073,17 @@ pub fn build_equal_condition_join_indices( // (5,1) // // With this approach, the lexicographic order on both the probe side and the build side is preserved. - let hash_map = build_hashmap.get_map(); - let next_chain = build_hashmap.get_list(); - for (row, hash_value) in hash_values.iter().enumerate().rev() { - // Get the hash and find it in the build index - - // For every item on the build and probe we check if it matches - // This possibly contains rows with hash collisions, - // So we have to check here whether rows are equal or not - if let Some((_, index)) = - hash_map.get(*hash_value, |(hash, _)| *hash_value == *hash) - { - let mut i = *index - 1; - loop { - let build_row_value = if let Some(offset) = deleted_offset { - // This arguments means that we prune the next index way before here. - if i < offset as u64 { - // End of the list due to pruning - break; - } - i - offset as u64 - } else { - i - }; - build_indices.append(build_row_value); - probe_indices.append(row as u32); - // Follow the chain to get the next index value - let next = next_chain[build_row_value as usize]; - if next == 0 { - // end of list - break; - } - i = next - 1; - } - } - } - // Reversing both sets of indices - build_indices.as_slice_mut().reverse(); - probe_indices.as_slice_mut().reverse(); + let (mut probe_indices, mut build_indices) = if fifo_hashmap { + build_hashmap.get_matched_indices(hash_values.iter().enumerate(), deleted_offset) + } else { + let (mut matched_probe, mut matched_build) = build_hashmap + .get_matched_indices(hash_values.iter().enumerate().rev(), deleted_offset); + + matched_probe.as_slice_mut().reverse(); + matched_build.as_slice_mut().reverse(); + + (matched_probe, matched_build) + }; let left: UInt64Array = PrimitiveArray::new(build_indices.finish().into(), None); let right: UInt32Array = PrimitiveArray::new(probe_indices.finish().into(), None); @@ -905,142 +1172,213 @@ impl HashJoinStream { &mut self, cx: &mut std::task::Context<'_>, ) -> Poll>> { + loop { + return match self.state { + HashJoinStreamState::WaitBuildSide => { + handle_state!(ready!(self.collect_build_side(cx))) + } + HashJoinStreamState::FetchProbeBatch => { + handle_state!(ready!(self.fetch_probe_batch(cx))) + } + HashJoinStreamState::ProcessProbeBatch(_) => { + handle_state!(self.process_probe_batch()) + } + HashJoinStreamState::ExhaustedProbeSide => { + handle_state!(self.process_unmatched_build_batch()) + } + HashJoinStreamState::Completed => Poll::Ready(None), + }; + } + } + + /// Collects build-side data by polling `OnceFut` future from initialized build-side + /// + /// Updates build-side to `Ready`, and state to `FetchProbeSide` + fn collect_build_side( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { let build_timer = self.join_metrics.build_time.timer(); - let left_data = match ready!(self.left_fut.get(cx)) { - Ok(left_data) => left_data, - Err(e) => return Poll::Ready(Some(Err(e))), - }; + // build hash table from left (build) side, if not yet done + let left_data = ready!(self + .build_side + .try_as_initial_mut()? + .left_fut + .get_shared(cx))?; build_timer.done(); - // Reserving memory for visited_left_side bitmap in case it hasn't been initialied yet + // Reserving memory for visited_left_side bitmap in case it hasn't been initialized yet // and join_type requires to store it - if self.visited_left_side.is_none() - && need_produce_result_in_final(self.join_type) - { + if need_produce_result_in_final(self.join_type) { // TODO: Replace `ceil` wrapper with stable `div_cell` after // https://github.com/rust-lang/rust/issues/88581 - let visited_bitmap_size = bit_util::ceil(left_data.1.num_rows(), 8); + let visited_bitmap_size = bit_util::ceil(left_data.num_rows(), 8); self.reservation.try_grow(visited_bitmap_size)?; self.join_metrics.build_mem_used.add(visited_bitmap_size); } - let visited_left_side = self.visited_left_side.get_or_insert_with(|| { - let num_rows = left_data.1.num_rows(); - if need_produce_result_in_final(self.join_type) { - // these join type need the bitmap to identify which row has be matched or unmatched. - // For the `left semi` join, need to use the bitmap to produce the matched row in the left side - // For the `left` join, need to use the bitmap to produce the unmatched row in the left side with null - // For the `left anti` join, need to use the bitmap to produce the unmatched row in the left side - // For the `full` join, need to use the bitmap to produce the unmatched row in the left side with null - let mut buffer = BooleanBufferBuilder::new(num_rows); - buffer.append_n(num_rows, false); - buffer - } else { - BooleanBufferBuilder::new(0) - } + let visited_left_side = if need_produce_result_in_final(self.join_type) { + let num_rows = left_data.num_rows(); + // Some join types need to track which row has be matched or unmatched: + // `left semi` join: need to use the bitmap to produce the matched row in the left side + // `left` join: need to use the bitmap to produce the unmatched row in the left side with null + // `left anti` join: need to use the bitmap to produce the unmatched row in the left side + // `full` join: need to use the bitmap to produce the unmatched row in the left side with null + let mut buffer = BooleanBufferBuilder::new(num_rows); + buffer.append_n(num_rows, false); + buffer + } else { + BooleanBufferBuilder::new(0) + }; + + self.state = HashJoinStreamState::FetchProbeBatch; + self.build_side = BuildSide::Ready(BuildSideReadyState { + left_data, + visited_left_side, }); + + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + + /// Fetches next batch from probe-side + /// + /// If non-empty batch has been fetched, updates state to `ProcessProbeBatchState`, + /// otherwise updates state to `ExhaustedProbeSide` + fn fetch_probe_batch( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { + match ready!(self.right.poll_next_unpin(cx)) { + None => { + self.state = HashJoinStreamState::ExhaustedProbeSide; + } + Some(Ok(batch)) => { + self.state = + HashJoinStreamState::ProcessProbeBatch(ProcessProbeBatchState { + batch, + }); + } + Some(Err(err)) => return Poll::Ready(Err(err)), + }; + + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + + /// Joins current probe batch with build-side data and produces batch with matched output + /// + /// Updates state to `FetchProbeBatch` + fn process_probe_batch( + &mut self, + ) -> Result>> { + let state = self.state.try_as_process_probe_batch()?; + let build_side = self.build_side.try_as_ready_mut()?; + + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(state.batch.num_rows()); + let timer = self.join_metrics.join_time.timer(); + let mut hashes_buffer = vec![]; - self.right - .poll_next_unpin(cx) - .map(|maybe_batch| match maybe_batch { - // one right batch in the join loop - Some(Ok(batch)) => { - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(batch.num_rows()); - let timer = self.join_metrics.join_time.timer(); - - // get the matched two indices for the on condition - let left_right_indices = build_equal_condition_join_indices( - &left_data.0, - &left_data.1, - &batch, - &self.on_left, - &self.on_right, - &self.random_state, - self.null_equals_null, - &mut hashes_buffer, - self.filter.as_ref(), - JoinSide::Left, - None, - ); - - let result = match left_right_indices { - Ok((left_side, right_side)) => { - // set the left bitmap - // and only left, full, left semi, left anti need the left bitmap - if need_produce_result_in_final(self.join_type) { - left_side.iter().flatten().for_each(|x| { - visited_left_side.set_bit(x as usize, true); - }); - } - - // adjust the two side indices base on the join type - let (left_side, right_side) = adjust_indices_by_join_type( - left_side, - right_side, - batch.num_rows(), - self.join_type, - ); - - let result = build_batch_from_indices( - &self.schema, - &left_data.1, - &batch, - &left_side, - &right_side, - &self.column_indices, - JoinSide::Left, - ); - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - Some(result) - } - Err(err) => Some(exec_err!( - "Fail to build join indices in HashJoinExec, error:{err}" - )), - }; - timer.done(); - result - } - None => { - let timer = self.join_metrics.join_time.timer(); - if need_produce_result_in_final(self.join_type) && !self.is_exhausted - { - // use the global left bitmap to produce the left indices and right indices - let (left_side, right_side) = get_final_indices_from_bit_map( - visited_left_side, - self.join_type, - ); - let empty_right_batch = - RecordBatch::new_empty(self.right.schema()); - // use the left and right indices to produce the batch result - let result = build_batch_from_indices( - &self.schema, - &left_data.1, - &empty_right_batch, - &left_side, - &right_side, - &self.column_indices, - JoinSide::Left, - ); - - if let Ok(ref batch) = result { - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(batch.num_rows()); - - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - } - timer.done(); - self.is_exhausted = true; - Some(result) - } else { - // end of the join loop - None - } + // get the matched two indices for the on condition + let left_right_indices = build_equal_condition_join_indices( + build_side.left_data.hash_map(), + build_side.left_data.batch(), + &state.batch, + &self.on_left, + &self.on_right, + &self.random_state, + self.null_equals_null, + &mut hashes_buffer, + self.filter.as_ref(), + JoinSide::Left, + None, + true, + ); + + let result = match left_right_indices { + Ok((left_side, right_side)) => { + // set the left bitmap + // and only left, full, left semi, left anti need the left bitmap + if need_produce_result_in_final(self.join_type) { + left_side.iter().flatten().for_each(|x| { + build_side.visited_left_side.set_bit(x as usize, true); + }); } - Some(err) => Some(err), - }) + + // adjust the two side indices base on the join type + let (left_side, right_side) = adjust_indices_by_join_type( + left_side, + right_side, + state.batch.num_rows(), + self.join_type, + ); + + let result = build_batch_from_indices( + &self.schema, + build_side.left_data.batch(), + &state.batch, + &left_side, + &right_side, + &self.column_indices, + JoinSide::Left, + ); + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(state.batch.num_rows()); + result + } + Err(err) => { + exec_err!("Fail to build join indices in HashJoinExec, error:{err}") + } + }; + timer.done(); + + self.state = HashJoinStreamState::FetchProbeBatch; + + Ok(StatefulStreamResult::Ready(Some(result?))) + } + + /// Processes unmatched build-side rows for certain join types and produces output batch + /// + /// Updates state to `Completed` + fn process_unmatched_build_batch( + &mut self, + ) -> Result>> { + let timer = self.join_metrics.join_time.timer(); + + if !need_produce_result_in_final(self.join_type) { + self.state = HashJoinStreamState::Completed; + + return Ok(StatefulStreamResult::Continue); + } + + let build_side = self.build_side.try_as_ready()?; + + // use the global left bitmap to produce the left indices and right indices + let (left_side, right_side) = + get_final_indices_from_bit_map(&build_side.visited_left_side, self.join_type); + let empty_right_batch = RecordBatch::new_empty(self.right.schema()); + // use the left and right indices to produce the batch result + let result = build_batch_from_indices( + &self.schema, + build_side.left_data.batch(), + &empty_right_batch, + &left_side, + &right_side, + &self.column_indices, + JoinSide::Left, + ); + + if let Ok(ref batch) = result { + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(batch.num_rows()); + + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); + } + timer.done(); + + self.state = HashJoinStreamState::Completed; + + Ok(StatefulStreamResult::Ready(Some(result?))) } } @@ -1059,29 +1397,24 @@ impl Stream for HashJoinStream { mod tests { use std::sync::Arc; + use super::*; + use crate::{ + common, expressions::Column, hash_utils::create_hashes, + joins::hash_join::build_equal_condition_join_indices, memory::MemoryExec, + repartition::RepartitionExec, test::build_table_i32, test::exec::MockExec, + }; + use arrow::array::{ArrayRef, Date32Array, Int32Array, UInt32Builder, UInt64Builder}; use arrow::datatypes::{DataType, Field, Schema}; - - use datafusion_common::{assert_batches_sorted_eq, assert_contains, ScalarValue}; - use datafusion_expr::Operator; - use datafusion_physical_expr::expressions::Literal; - use hashbrown::raw::RawTable; - - use crate::{ - common, - expressions::Column, - hash_utils::create_hashes, - joins::{hash_join::build_equal_condition_join_indices, utils::JoinSide}, - memory::MemoryExec, - repartition::RepartitionExec, - test::build_table_i32, - test::exec::MockExec, + use datafusion_common::{ + assert_batches_eq, assert_batches_sorted_eq, assert_contains, ScalarValue, }; use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; - use datafusion_physical_expr::expressions::BinaryExpr; + use datafusion_expr::Operator; + use datafusion_physical_expr::expressions::{BinaryExpr, Literal}; - use super::*; + use hashbrown::raw::RawTable; fn build_table( a: (&str, &Vec), @@ -1240,7 +1573,9 @@ mod tests { "| 3 | 5 | 9 | 20 | 5 | 80 |", "+----+----+----+----+----+----+", ]; - assert_batches_sorted_eq!(expected, &batches); + + // Inner join output is expected to preserve both inputs order + assert_batches_eq!(expected, &batches); Ok(()) } @@ -1322,7 +1657,48 @@ mod tests { "+----+----+----+----+----+----+", ]; - assert_batches_sorted_eq!(expected, &batches); + // Inner join output is expected to preserve both inputs order + assert_batches_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn join_inner_one_randomly_ordered() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let left = build_table( + ("a1", &vec![0, 3, 2, 1]), + ("b1", &vec![4, 5, 5, 4]), + ("c1", &vec![6, 9, 8, 7]), + ); + let right = build_table( + ("a2", &vec![20, 30, 10]), + ("b2", &vec![5, 6, 4]), + ("c2", &vec![80, 90, 70]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b2", &right.schema())?, + )]; + + let (columns, batches) = + join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; + + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); + + let expected = [ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 3 | 5 | 9 | 20 | 5 | 80 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 0 | 4 | 6 | 10 | 4 | 70 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "+----+----+----+----+----+----+", + ]; + + // Inner join output is expected to preserve both inputs order + assert_batches_eq!(expected, &batches); Ok(()) } @@ -1368,7 +1744,8 @@ mod tests { "+----+----+----+----+----+----+", ]; - assert_batches_sorted_eq!(expected, &batches); + // Inner join output is expected to preserve both inputs order + assert_batches_eq!(expected, &batches); Ok(()) } @@ -1422,7 +1799,58 @@ mod tests { "+----+----+----+----+----+----+", ]; - assert_batches_sorted_eq!(expected, &batches); + // Inner join output is expected to preserve both inputs order + assert_batches_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn join_inner_one_two_parts_left_randomly_ordered() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let batch1 = build_table_i32( + ("a1", &vec![0, 3]), + ("b1", &vec![4, 5]), + ("c1", &vec![6, 9]), + ); + let batch2 = build_table_i32( + ("a1", &vec![2, 1]), + ("b1", &vec![5, 4]), + ("c1", &vec![8, 7]), + ); + let schema = batch1.schema(); + + let left = Arc::new( + MemoryExec::try_new(&[vec![batch1], vec![batch2]], schema, None).unwrap(), + ); + let right = build_table( + ("a2", &vec![20, 30, 10]), + ("b2", &vec![5, 6, 4]), + ("c2", &vec![80, 90, 70]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b2", &right.schema())?, + )]; + + let (columns, batches) = + join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; + + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); + + let expected = [ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 3 | 5 | 9 | 20 | 5 | 80 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 0 | 4 | 6 | 10 | 4 | 70 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "+----+----+----+----+----+----+", + ]; + + // Inner join output is expected to preserve both inputs order + assert_batches_eq!(expected, &batches); Ok(()) } @@ -1471,7 +1899,9 @@ mod tests { "| 1 | 4 | 7 | 10 | 4 | 70 |", "+----+----+----+----+----+----+", ]; - assert_batches_sorted_eq!(expected, &batches); + + // Inner join output is expected to preserve both inputs order + assert_batches_eq!(expected, &batches); // second part let stream = join.execute(1, task_ctx.clone())?; @@ -1486,7 +1916,8 @@ mod tests { "+----+----+----+----+----+----+", ]; - assert_batches_sorted_eq!(expected, &batches); + // Inner join output is expected to preserve both inputs order + assert_batches_eq!(expected, &batches); Ok(()) } @@ -1910,12 +2341,14 @@ mod tests { "+----+----+-----+", "| a2 | b2 | c2 |", "+----+----+-----+", - "| 10 | 10 | 100 |", - "| 12 | 10 | 40 |", "| 8 | 8 | 20 |", + "| 12 | 10 | 40 |", + "| 10 | 10 | 100 |", "+----+----+-----+", ]; - assert_batches_sorted_eq!(expected, &batches); + + // RightSemi join output is expected to preserve right input order + assert_batches_eq!(expected, &batches); Ok(()) } @@ -1970,12 +2403,14 @@ mod tests { "+----+----+-----+", "| a2 | b2 | c2 |", "+----+----+-----+", - "| 10 | 10 | 100 |", - "| 12 | 10 | 40 |", "| 8 | 8 | 20 |", + "| 12 | 10 | 40 |", + "| 10 | 10 | 100 |", "+----+----+-----+", ]; - assert_batches_sorted_eq!(expected, &batches); + + // RightSemi join output is expected to preserve right input order + assert_batches_eq!(expected, &batches); // left_table right semi join right_table on left_table.b1 = right_table.b2 on left_table.a1!=9 let filter_expression = Arc::new(BinaryExpr::new( @@ -1996,11 +2431,13 @@ mod tests { "+----+----+-----+", "| a2 | b2 | c2 |", "+----+----+-----+", - "| 10 | 10 | 100 |", "| 12 | 10 | 40 |", + "| 10 | 10 | 100 |", "+----+----+-----+", ]; - assert_batches_sorted_eq!(expected, &batches); + + // RightSemi join output is expected to preserve right input order + assert_batches_eq!(expected, &batches); Ok(()) } @@ -2153,12 +2590,14 @@ mod tests { "+----+----+-----+", "| a2 | b2 | c2 |", "+----+----+-----+", + "| 6 | 6 | 60 |", "| 2 | 2 | 80 |", "| 4 | 4 | 120 |", - "| 6 | 6 | 60 |", "+----+----+-----+", ]; - assert_batches_sorted_eq!(expected, &batches); + + // RightAnti join output is expected to preserve right input order + assert_batches_eq!(expected, &batches); Ok(()) } @@ -2211,14 +2650,16 @@ mod tests { "+----+----+-----+", "| a2 | b2 | c2 |", "+----+----+-----+", - "| 10 | 10 | 100 |", "| 12 | 10 | 40 |", + "| 6 | 6 | 60 |", "| 2 | 2 | 80 |", + "| 10 | 10 | 100 |", "| 4 | 4 | 120 |", - "| 6 | 6 | 60 |", "+----+----+-----+", ]; - assert_batches_sorted_eq!(expected, &batches); + + // RightAnti join output is expected to preserve right input order + assert_batches_eq!(expected, &batches); // left_table right anti join right_table on left_table.b1 = right_table.b2 and right_table.b2!=8 let column_indices = vec![ColumnIndex { @@ -2247,13 +2688,15 @@ mod tests { "+----+----+-----+", "| a2 | b2 | c2 |", "+----+----+-----+", + "| 8 | 8 | 20 |", + "| 6 | 6 | 60 |", "| 2 | 2 | 80 |", "| 4 | 4 | 120 |", - "| 6 | 6 | 60 |", - "| 8 | 8 | 20 |", "+----+----+-----+", ]; - assert_batches_sorted_eq!(expected, &batches); + + // RightAnti join output is expected to preserve right input order + assert_batches_eq!(expected, &batches); Ok(()) } @@ -2402,16 +2845,11 @@ mod tests { ("c", &vec![30, 40]), ); - let left_data = ( - JoinHashMap { - map: hashmap_left, - next, - }, - left, - ); + let join_hash_map = JoinHashMap::new(hashmap_left, next); + let (l, r) = build_equal_condition_join_indices( - &left_data.0, - &left_data.1, + &join_hash_map, + &left, &right, &[Column::new("a", 0)], &[Column::new("a", 0)], @@ -2421,6 +2859,7 @@ mod tests { None, JoinSide::Left, None, + false, )?; let mut left_ids = UInt64Builder::with_capacity(0); diff --git a/datafusion/physical-plan/src/joins/mod.rs b/datafusion/physical-plan/src/joins/mod.rs index 19f10d06e1ef..6ddf19c51193 100644 --- a/datafusion/physical-plan/src/joins/mod.rs +++ b/datafusion/physical-plan/src/joins/mod.rs @@ -25,9 +25,9 @@ pub use sort_merge_join::SortMergeJoinExec; pub use symmetric_hash_join::SymmetricHashJoinExec; mod cross_join; mod hash_join; -mod hash_join_utils; mod nested_loop_join; mod sort_merge_join; +mod stream_join_utils; mod symmetric_hash_join; pub mod utils; diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index c49c16dba313..6951642ff801 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -19,39 +19,39 @@ //! The nested loop join can execute in parallel by partitions and it is //! determined by the [`JoinType`]. +use std::any::Any; +use std::fmt::Formatter; +use std::sync::Arc; +use std::task::Poll; + +use crate::coalesce_batches::concat_batches; use crate::joins::utils::{ append_right_indices, apply_join_filter_to_indices, build_batch_from_indices, - build_join_schema, check_join_is_valid, combine_join_equivalence_properties, - estimate_join_statistics, get_anti_indices, get_anti_u64_indices, - get_final_indices_from_bit_map, get_semi_indices, get_semi_u64_indices, - partitioned_join_output_partitioning, BuildProbeJoinMetrics, ColumnIndex, JoinFilter, - JoinSide, OnceAsync, OnceFut, + build_join_schema, check_join_is_valid, estimate_join_statistics, get_anti_indices, + get_anti_u64_indices, get_final_indices_from_bit_map, get_semi_indices, + get_semi_u64_indices, partitioned_join_output_partitioning, BuildProbeJoinMetrics, + ColumnIndex, JoinFilter, OnceAsync, OnceFut, }; use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::{ DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, }; + use arrow::array::{ BooleanBufferBuilder, UInt32Array, UInt32Builder, UInt64Array, UInt64Builder, }; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use arrow::util::bit_util; -use datafusion_common::{exec_err, DataFusionError, Statistics}; -use datafusion_execution::memory_pool::MemoryReservation; +use datafusion_common::{exec_err, DataFusionError, JoinSide, Result, Statistics}; +use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; +use datafusion_execution::TaskContext; use datafusion_expr::JoinType; +use datafusion_physical_expr::equivalence::join_equivalence_properties; use datafusion_physical_expr::{EquivalenceProperties, PhysicalSortExpr}; -use futures::{ready, Stream, StreamExt, TryStreamExt}; -use std::any::Any; -use std::fmt::Formatter; -use std::sync::Arc; -use std::task::Poll; -use crate::coalesce_batches::concat_batches; -use datafusion_common::Result; -use datafusion_execution::memory_pool::MemoryConsumer; -use datafusion_execution::TaskContext; +use futures::{ready, Stream, StreamExt, TryStreamExt}; /// Data of the inner table side type JoinLeftData = (RecordBatch, MemoryReservation); @@ -192,14 +192,15 @@ impl ExecutionPlan for NestedLoopJoinExec { } fn equivalence_properties(&self) -> EquivalenceProperties { - let left_columns_len = self.left.schema().fields.len(); - combine_join_equivalence_properties( - self.join_type, + join_equivalence_properties( self.left.equivalence_properties(), self.right.equivalence_properties(), - left_columns_len, - &[], // empty join keys + &self.join_type, self.schema(), + &self.maintains_input_order(), + None, + // No on columns in nested loop join + &[], ) } @@ -282,12 +283,13 @@ impl ExecutionPlan for NestedLoopJoinExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Statistics { + fn statistics(&self) -> Result { estimate_join_statistics( self.left.clone(), self.right.clone(), vec![], &self.join_type, + &self.schema, ) } } @@ -739,21 +741,20 @@ impl RecordBatchStream for NestedLoopJoinStream { #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; use crate::{ common, expressions::Column, memory::MemoryExec, repartition::RepartitionExec, test::build_table_i32, }; + use arrow::datatypes::{DataType, Field}; + use datafusion_common::{assert_batches_sorted_eq, assert_contains, ScalarValue}; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_expr::Operator; - use datafusion_physical_expr::expressions::BinaryExpr; - - use crate::joins::utils::JoinSide; - use datafusion_common::{assert_batches_sorted_eq, assert_contains, ScalarValue}; - use datafusion_physical_expr::expressions::Literal; + use datafusion_physical_expr::expressions::{BinaryExpr, Literal}; use datafusion_physical_expr::PhysicalExpr; - use std::sync::Arc; fn build_table( a: (&str, &Vec), diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 4de723ab73ea..f6fdc6d77c0c 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -30,18 +30,15 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use crate::expressions::Column; -use crate::expressions::PhysicalSortExpr; +use crate::expressions::{Column, PhysicalSortExpr}; use crate::joins::utils::{ build_join_schema, calculate_join_output_ordering, check_join_is_valid, - combine_join_equivalence_properties, combine_join_ordering_equivalence_properties, - estimate_join_statistics, partitioned_join_output_partitioning, JoinOn, JoinSide, + estimate_join_statistics, partitioned_join_output_partitioning, JoinOn, }; use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; use crate::{ - metrics, DisplayAs, DisplayFormatType, Distribution, EquivalenceProperties, - ExecutionPlan, Partitioning, PhysicalExpr, RecordBatchStream, - SendableRecordBatchStream, Statistics, + metrics, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, + PhysicalExpr, RecordBatchStream, SendableRecordBatchStream, Statistics, }; use arrow::array::*; @@ -50,11 +47,12 @@ use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; use arrow::error::ArrowError; use arrow::record_batch::RecordBatch; use datafusion_common::{ - internal_err, not_impl_err, plan_err, DataFusionError, JoinType, Result, + internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, Result, }; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; -use datafusion_physical_expr::{OrderingEquivalenceProperties, PhysicalSortRequirement}; +use datafusion_physical_expr::equivalence::join_equivalence_properties; +use datafusion_physical_expr::{EquivalenceProperties, PhysicalSortRequirement}; use futures::{Stream, StreamExt}; @@ -141,7 +139,7 @@ impl SortMergeJoinExec { left_schema.fields.len(), &Self::maintains_input_order(join_type), Some(Self::probe_side(&join_type)), - )?; + ); let schema = Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0); @@ -284,28 +282,15 @@ impl ExecutionPlan for SortMergeJoinExec { } fn equivalence_properties(&self) -> EquivalenceProperties { - let left_columns_len = self.left.schema().fields.len(); - combine_join_equivalence_properties( - self.join_type, + join_equivalence_properties( self.left.equivalence_properties(), self.right.equivalence_properties(), - left_columns_len, - self.on(), - self.schema(), - ) - } - - fn ordering_equivalence_properties(&self) -> OrderingEquivalenceProperties { - combine_join_ordering_equivalence_properties( &self.join_type, - &self.left, - &self.right, self.schema(), &self.maintains_input_order(), Some(Self::probe_side(&self.join_type)), - self.equivalence_properties(), + self.on(), ) - .unwrap() } fn children(&self) -> Vec> { @@ -381,7 +366,7 @@ impl ExecutionPlan for SortMergeJoinExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Statistics { + fn statistics(&self) -> Result { // TODO stats: it is not possible in general to know the output size of joins // There are some special cases though, for example: // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)` @@ -390,6 +375,7 @@ impl ExecutionPlan for SortMergeJoinExec { self.right.clone(), self.on.clone(), &self.join_type, + &self.schema, ) } } @@ -1397,24 +1383,23 @@ fn is_join_arrays_equal( mod tests { use std::sync::Arc; - use arrow::array::{Date32Array, Date64Array, Int32Array}; - use arrow::compute::SortOptions; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow::record_batch::RecordBatch; - use datafusion_execution::config::SessionConfig; - use datafusion_execution::TaskContext; - use crate::expressions::Column; use crate::joins::utils::JoinOn; use crate::joins::SortMergeJoinExec; use crate::memory::MemoryExec; use crate::test::build_table_i32; use crate::{common, ExecutionPlan}; - use datafusion_common::Result; + + use arrow::array::{Date32Array, Date64Array, Int32Array}; + use arrow::compute::SortOptions; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; use datafusion_common::{ - assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType, + assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType, Result, }; + use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; + use datafusion_execution::TaskContext; fn build_table( a: (&str, &Vec), diff --git a/datafusion/physical-plan/src/joins/hash_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs similarity index 60% rename from datafusion/physical-plan/src/joins/hash_join_utils.rs rename to datafusion/physical-plan/src/joins/stream_join_utils.rs index 525c1a7145b9..9a4c98927683 100644 --- a/datafusion/physical-plan/src/joins/hash_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -15,137 +15,39 @@ // specific language governing permissions and limitations // under the License. -//! This file contains common subroutines for regular and symmetric hash join +//! This file contains common subroutines for symmetric hash join //! related functionality, used both in join calculations and optimization rules. use std::collections::{HashMap, VecDeque}; -use std::fmt::Debug; -use std::ops::IndexMut; use std::sync::Arc; -use std::{fmt, usize}; +use std::task::{Context, Poll}; +use std::usize; -use crate::joins::utils::{JoinFilter, JoinSide}; +use crate::joins::utils::{JoinFilter, JoinHashMapType, StatefulStreamResult}; +use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder}; +use crate::{handle_async_state, handle_state, metrics, ExecutionPlan}; use arrow::compute::concat_batches; -use arrow::datatypes::{ArrowNativeType, SchemaRef}; -use arrow_array::builder::BooleanBufferBuilder; use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray, RecordBatch}; +use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder}; +use arrow_schema::{Schema, SchemaRef}; use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_common::{ + arrow_datafusion_err, plan_datafusion_err, DataFusionError, JoinSide, Result, + ScalarValue, +}; +use datafusion_execution::SendableRecordBatchStream; +use datafusion_expr::interval_arithmetic::Interval; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::intervals::{Interval, IntervalBound}; +use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph; use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; +use async_trait::async_trait; +use futures::{ready, FutureExt, StreamExt}; use hashbrown::raw::RawTable; use hashbrown::HashSet; -// Maps a `u64` hash value based on the build side ["on" values] to a list of indices with this key's value. -// By allocating a `HashMap` with capacity for *at least* the number of rows for entries at the build side, -// we make sure that we don't have to re-hash the hashmap, which needs access to the key (the hash in this case) value. -// E.g. 1 -> [3, 6, 8] indicates that the column values map to rows 3, 6 and 8 for hash value 1 -// As the key is a hash value, we need to check possible hash collisions in the probe stage -// During this stage it might be the case that a row is contained the same hashmap value, -// but the values don't match. Those are checked in the [equal_rows] macro -// The indices (values) are stored in a separate chained list stored in the `Vec`. -// The first value (+1) is stored in the hashmap, whereas the next value is stored in array at the position value. -// The chain can be followed until the value "0" has been reached, meaning the end of the list. -// Also see chapter 5.3 of [Balancing vectorized query execution with bandwidth-optimized storage](https://dare.uva.nl/search?identifier=5ccbb60a-38b8-4eeb-858a-e7735dd37487) -// See the example below: -// Insert (1,1) -// map: -// --------- -// | 1 | 2 | -// --------- -// next: -// --------------------- -// | 0 | 0 | 0 | 0 | 0 | -// --------------------- -// Insert (2,2) -// map: -// --------- -// | 1 | 2 | -// | 2 | 3 | -// --------- -// next: -// --------------------- -// | 0 | 0 | 0 | 0 | 0 | -// --------------------- -// Insert (1,3) -// map: -// --------- -// | 1 | 4 | -// | 2 | 3 | -// --------- -// next: -// --------------------- -// | 0 | 0 | 0 | 2 | 0 | <--- hash value 1 maps to 4,2 (which means indices values 3,1) -// --------------------- -// Insert (1,4) -// map: -// --------- -// | 1 | 5 | -// | 2 | 3 | -// --------- -// next: -// --------------------- -// | 0 | 0 | 0 | 2 | 4 | <--- hash value 1 maps to 5,4,2 (which means indices values 4,3,1) -// --------------------- -// TODO: speed up collision checks -// https://github.com/apache/arrow-datafusion/issues/50 -pub struct JoinHashMap { - // Stores hash value to last row index - pub map: RawTable<(u64, u64)>, - // Stores indices in chained list data structure - pub next: Vec, -} - -impl JoinHashMap { - pub(crate) fn with_capacity(capacity: usize) -> Self { - JoinHashMap { - map: RawTable::with_capacity(capacity), - next: vec![0; capacity], - } - } -} - -/// Trait defining methods that must be implemented by a hash map type to be used for joins. -pub trait JoinHashMapType { - /// The type of list used to store the hash values. - type NextType: IndexMut; - /// Extend with zero - fn extend_zero(&mut self, len: usize); - /// Returns mutable references to the hash map and the next. - fn get_mut(&mut self) -> (&mut RawTable<(u64, u64)>, &mut Self::NextType); - /// Returns a reference to the hash map. - fn get_map(&self) -> &RawTable<(u64, u64)>; - /// Returns a reference to the next. - fn get_list(&self) -> &Self::NextType; -} - -/// Implementation of `JoinHashMapType` for `JoinHashMap`. -impl JoinHashMapType for JoinHashMap { - type NextType = Vec; - - // Void implementation - fn extend_zero(&mut self, _: usize) {} - - /// Get mutable references to the hash map and the next. - fn get_mut(&mut self) -> (&mut RawTable<(u64, u64)>, &mut Self::NextType) { - (&mut self.map, &mut self.next) - } - - /// Get a reference to the hash map. - fn get_map(&self) -> &RawTable<(u64, u64)> { - &self.map - } - - /// Get a reference to the next. - fn get_list(&self) -> &Self::NextType { - &self.next - } -} - /// Implementation of `JoinHashMapType` for `PruningJoinHashMap`. impl JoinHashMapType for PruningJoinHashMap { type NextType = VecDeque; @@ -171,12 +73,6 @@ impl JoinHashMapType for PruningJoinHashMap { } } -impl fmt::Debug for JoinHashMap { - fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result { - Ok(()) - } -} - /// The `PruningJoinHashMap` is similar to a regular `JoinHashMap`, but with /// the capability of pruning elements in an efficient manner. This structure /// is particularly useful for cases where it's necessary to remove elements @@ -188,15 +84,15 @@ impl fmt::Debug for JoinHashMap { /// Let's continue the example of `JoinHashMap` and then show how `PruningJoinHashMap` would /// handle the pruning scenario. /// -/// Insert the pair (1,4) into the `PruningJoinHashMap`: +/// Insert the pair (10,4) into the `PruningJoinHashMap`: /// map: -/// --------- -/// | 1 | 5 | -/// | 2 | 3 | -/// --------- +/// ---------- +/// | 10 | 5 | +/// | 20 | 3 | +/// ---------- /// list: /// --------------------- -/// | 0 | 0 | 0 | 2 | 4 | <--- hash value 1 maps to 5,4,2 (which means indices values 4,3,1) +/// | 0 | 0 | 0 | 2 | 4 | <--- hash value 10 maps to 5,4,2 (which means indices values 4,3,1) /// --------------------- /// /// Now, let's prune 3 rows from `PruningJoinHashMap`: @@ -206,7 +102,7 @@ impl fmt::Debug for JoinHashMap { /// --------- /// list: /// --------- -/// | 2 | 4 | <--- hash value 1 maps to 2 (5 - 3), 1 (4 - 3), NA (2 - 3) (which means indices values 1,0) +/// | 2 | 4 | <--- hash value 10 maps to 2 (5 - 3), 1 (4 - 3), NA (2 - 3) (which means indices values 1,0) /// --------- /// /// After pruning, the | 2 | 3 | entry is deleted from `PruningJoinHashMap` since @@ -281,7 +177,7 @@ impl PruningJoinHashMap { prune_length: usize, deleting_offset: u64, shrink_factor: usize, - ) -> Result<()> { + ) { // Remove elements from the list based on the pruning length. self.next.drain(0..prune_length); @@ -304,7 +200,6 @@ impl PruningJoinHashMap { // Shrink the map if necessary. self.shrink_if_necessary(shrink_factor); - Ok(()) } } @@ -333,7 +228,7 @@ pub fn map_origin_col_to_filter_col( side: &JoinSide, ) -> Result> { let filter_schema = filter.schema(); - let mut col_to_col_map: HashMap = HashMap::new(); + let mut col_to_col_map = HashMap::::new(); for (filter_schema_index, index) in filter.column_indices().iter().enumerate() { if index.side.eq(side) { // Get the main field from column index: @@ -425,7 +320,11 @@ pub fn build_filter_input_order( order: &PhysicalSortExpr, ) -> Result> { let opt_expr = convert_sort_expr_with_filter_schema(&side, filter, schema, order)?; - Ok(opt_expr.map(|filter_expr| SortedFilterExpr::new(order.clone(), filter_expr))) + opt_expr + .map(|filter_expr| { + SortedFilterExpr::try_new(order.clone(), filter_expr, filter.schema()) + }) + .transpose() } /// Convert a physical expression into a filter expression using the given @@ -468,16 +367,18 @@ pub struct SortedFilterExpr { impl SortedFilterExpr { /// Constructor - pub fn new( + pub fn try_new( origin_sorted_expr: PhysicalSortExpr, filter_expr: Arc, - ) -> Self { - Self { + filter_schema: &Schema, + ) -> Result { + let dt = &filter_expr.data_type(filter_schema)?; + Ok(Self { origin_sorted_expr, filter_expr, - interval: Interval::default(), + interval: Interval::make_unbounded(dt)?, node_index: 0, - } + }) } /// Get origin expr information pub fn origin_sorted_expr(&self) -> &PhysicalSortExpr { @@ -599,16 +500,16 @@ pub fn update_filter_expr_interval( .origin_sorted_expr() .expr .evaluate(batch)? - .into_array(1); + .into_array(1)?; // Convert the array to a ScalarValue: let value = ScalarValue::try_from_array(&array, 0)?; // Create a ScalarValue representing positive or negative infinity for the same data type: - let unbounded = IntervalBound::make_unbounded(value.data_type())?; + let inf = ScalarValue::try_from(value.data_type())?; // Update the interval with lower and upper bounds based on the sort option: let interval = if sorted_expr.origin_sorted_expr().options.descending { - Interval::new(unbounded, IntervalBound::new(value, false)) + Interval::try_new(inf, value)? } else { - Interval::new(IntervalBound::new(value, false), unbounded) + Interval::try_new(value, inf)? }; // Set the calculated interval for the sorted filter expression: sorted_expr.set_interval(interval); @@ -681,7 +582,7 @@ where // get the semi index (0..prune_length) .filter_map(|idx| (bitmap.get_bit(idx)).then_some(T::Native::from_usize(idx))) - .collect::>() + .collect() } pub fn combine_two_batches( @@ -697,7 +598,7 @@ pub fn combine_two_batches( (Some(left_batch), Some(right_batch)) => { // If both batches are present, concatenate them: concat_batches(output_schema, &[left_batch, right_batch]) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) .map(Some) } (None, None) => { @@ -726,68 +627,516 @@ pub fn record_visited_indices( } } +/// Represents the various states of an eager join stream operation. +/// +/// This enum is used to track the current state of streaming during a join +/// operation. It provides indicators as to which side of the join needs to be +/// pulled next or if one (or both) sides have been exhausted. This allows +/// for efficient management of resources and optimal performance during the +/// join process. +#[derive(Clone, Debug)] +pub enum EagerJoinStreamState { + /// Indicates that the next step should pull from the right side of the join. + PullRight, + + /// Indicates that the next step should pull from the left side of the join. + PullLeft, + + /// State representing that the right side of the join has been fully processed. + RightExhausted, + + /// State representing that the left side of the join has been fully processed. + LeftExhausted, + + /// Represents a state where both sides of the join are exhausted. + /// + /// The `final_result` field indicates whether the join operation has + /// produced a final result or not. + BothExhausted { final_result: bool }, +} + +/// `EagerJoinStream` is an asynchronous trait designed for managing incremental +/// join operations between two streams, such as those used in `SymmetricHashJoinExec` +/// and `SortMergeJoinExec`. Unlike traditional join approaches that need to scan +/// one side of the join fully before proceeding, `EagerJoinStream` facilitates +/// more dynamic join operations by working with streams as they emit data. This +/// approach allows for more efficient processing, particularly in scenarios +/// where waiting for complete data materialization is not feasible or optimal. +/// The trait provides a framework for handling various states of such a join +/// process, ensuring that join logic is efficiently executed as data becomes +/// available from either stream. +/// +/// Implementors of this trait can perform eager joins of data from two different +/// asynchronous streams, typically referred to as left and right streams. The +/// trait provides a comprehensive set of methods to control and execute the join +/// process, leveraging the states defined in `EagerJoinStreamState`. Methods are +/// primarily focused on asynchronously fetching data batches from each stream, +/// processing them, and managing transitions between various states of the join. +/// +/// This trait's default implementations use a state machine approach to navigate +/// different stages of the join operation, handling data from both streams and +/// determining when the join completes. +/// +/// State Transitions: +/// - From `PullLeft` to `PullRight` or `LeftExhausted`: +/// - In `fetch_next_from_left_stream`, when fetching a batch from the left stream: +/// - On success (`Some(Ok(batch))`), state transitions to `PullRight` for +/// processing the batch. +/// - On error (`Some(Err(e))`), the error is returned, and the state remains +/// unchanged. +/// - On no data (`None`), state changes to `LeftExhausted`, returning `Continue` +/// to proceed with the join process. +/// - From `PullRight` to `PullLeft` or `RightExhausted`: +/// - In `fetch_next_from_right_stream`, when fetching from the right stream: +/// - If a batch is available, state changes to `PullLeft` for processing. +/// - On error, the error is returned without changing the state. +/// - If right stream is exhausted (`None`), state transitions to `RightExhausted`, +/// with a `Continue` result. +/// - Handling `RightExhausted` and `LeftExhausted`: +/// - Methods `handle_right_stream_end` and `handle_left_stream_end` manage scenarios +/// when streams are exhausted: +/// - They attempt to continue processing with the other stream. +/// - If both streams are exhausted, state changes to `BothExhausted { final_result: false }`. +/// - Transition to `BothExhausted { final_result: true }`: +/// - Occurs in `prepare_for_final_results_after_exhaustion` when both streams are +/// exhausted, indicating completion of processing and availability of final results. +#[async_trait] +pub trait EagerJoinStream { + /// Implements the main polling logic for the join stream. + /// + /// This method continuously checks the state of the join stream and + /// acts accordingly by delegating the handling to appropriate sub-methods + /// depending on the current state. + /// + /// # Arguments + /// + /// * `cx` - A context that facilitates cooperative non-blocking execution within a task. + /// + /// # Returns + /// + /// * `Poll>>` - A polled result, either a `RecordBatch` or None. + fn poll_next_impl( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> + where + Self: Send, + { + loop { + return match self.state() { + EagerJoinStreamState::PullRight => { + handle_async_state!(self.fetch_next_from_right_stream(), cx) + } + EagerJoinStreamState::PullLeft => { + handle_async_state!(self.fetch_next_from_left_stream(), cx) + } + EagerJoinStreamState::RightExhausted => { + handle_async_state!(self.handle_right_stream_end(), cx) + } + EagerJoinStreamState::LeftExhausted => { + handle_async_state!(self.handle_left_stream_end(), cx) + } + EagerJoinStreamState::BothExhausted { + final_result: false, + } => { + handle_state!(self.prepare_for_final_results_after_exhaustion()) + } + EagerJoinStreamState::BothExhausted { final_result: true } => { + Poll::Ready(None) + } + }; + } + } + /// Asynchronously pulls the next batch from the right stream. + /// + /// This default implementation checks for the next value in the right stream. + /// If a batch is found, the state is switched to `PullLeft`, and the batch handling + /// is delegated to `process_batch_from_right`. If the stream ends, the state is set to `RightExhausted`. + /// + /// # Returns + /// + /// * `Result>>` - The state result after pulling the batch. + async fn fetch_next_from_right_stream( + &mut self, + ) -> Result>> { + match self.right_stream().next().await { + Some(Ok(batch)) => { + if batch.num_rows() == 0 { + return Ok(StatefulStreamResult::Continue); + } + self.set_state(EagerJoinStreamState::PullLeft); + self.process_batch_from_right(batch) + } + Some(Err(e)) => Err(e), + None => { + self.set_state(EagerJoinStreamState::RightExhausted); + Ok(StatefulStreamResult::Continue) + } + } + } + + /// Asynchronously pulls the next batch from the left stream. + /// + /// This default implementation checks for the next value in the left stream. + /// If a batch is found, the state is switched to `PullRight`, and the batch handling + /// is delegated to `process_batch_from_left`. If the stream ends, the state is set to `LeftExhausted`. + /// + /// # Returns + /// + /// * `Result>>` - The state result after pulling the batch. + async fn fetch_next_from_left_stream( + &mut self, + ) -> Result>> { + match self.left_stream().next().await { + Some(Ok(batch)) => { + if batch.num_rows() == 0 { + return Ok(StatefulStreamResult::Continue); + } + self.set_state(EagerJoinStreamState::PullRight); + self.process_batch_from_left(batch) + } + Some(Err(e)) => Err(e), + None => { + self.set_state(EagerJoinStreamState::LeftExhausted); + Ok(StatefulStreamResult::Continue) + } + } + } + + /// Asynchronously handles the scenario when the right stream is exhausted. + /// + /// In this default implementation, when the right stream is exhausted, it attempts + /// to pull from the left stream. If a batch is found in the left stream, it delegates + /// the handling to `process_batch_from_left`. If both streams are exhausted, the state is set + /// to indicate both streams are exhausted without final results yet. + /// + /// # Returns + /// + /// * `Result>>` - The state result after checking the exhaustion state. + async fn handle_right_stream_end( + &mut self, + ) -> Result>> { + match self.left_stream().next().await { + Some(Ok(batch)) => { + if batch.num_rows() == 0 { + return Ok(StatefulStreamResult::Continue); + } + self.process_batch_after_right_end(batch) + } + Some(Err(e)) => Err(e), + None => { + self.set_state(EagerJoinStreamState::BothExhausted { + final_result: false, + }); + Ok(StatefulStreamResult::Continue) + } + } + } + + /// Asynchronously handles the scenario when the left stream is exhausted. + /// + /// When the left stream is exhausted, this default + /// implementation tries to pull from the right stream and delegates the batch + /// handling to `process_batch_after_left_end`. If both streams are exhausted, the state + /// is updated to indicate so. + /// + /// # Returns + /// + /// * `Result>>` - The state result after checking the exhaustion state. + async fn handle_left_stream_end( + &mut self, + ) -> Result>> { + match self.right_stream().next().await { + Some(Ok(batch)) => { + if batch.num_rows() == 0 { + return Ok(StatefulStreamResult::Continue); + } + self.process_batch_after_left_end(batch) + } + Some(Err(e)) => Err(e), + None => { + self.set_state(EagerJoinStreamState::BothExhausted { + final_result: false, + }); + Ok(StatefulStreamResult::Continue) + } + } + } + + /// Handles the state when both streams are exhausted and final results are yet to be produced. + /// + /// This default implementation switches the state to indicate both streams are + /// exhausted with final results and then invokes the handling for this specific + /// scenario via `process_batches_before_finalization`. + /// + /// # Returns + /// + /// * `Result>>` - The state result after both streams are exhausted. + fn prepare_for_final_results_after_exhaustion( + &mut self, + ) -> Result>> { + self.set_state(EagerJoinStreamState::BothExhausted { final_result: true }); + self.process_batches_before_finalization() + } + + /// Handles a pulled batch from the right stream. + /// + /// # Arguments + /// + /// * `batch` - The pulled `RecordBatch` from the right stream. + /// + /// # Returns + /// + /// * `Result>>` - The state result after processing the batch. + fn process_batch_from_right( + &mut self, + batch: RecordBatch, + ) -> Result>>; + + /// Handles a pulled batch from the left stream. + /// + /// # Arguments + /// + /// * `batch` - The pulled `RecordBatch` from the left stream. + /// + /// # Returns + /// + /// * `Result>>` - The state result after processing the batch. + fn process_batch_from_left( + &mut self, + batch: RecordBatch, + ) -> Result>>; + + /// Handles the situation when only the left stream is exhausted. + /// + /// # Arguments + /// + /// * `right_batch` - The `RecordBatch` from the right stream. + /// + /// # Returns + /// + /// * `Result>>` - The state result after the left stream is exhausted. + fn process_batch_after_left_end( + &mut self, + right_batch: RecordBatch, + ) -> Result>>; + + /// Handles the situation when only the right stream is exhausted. + /// + /// # Arguments + /// + /// * `left_batch` - The `RecordBatch` from the left stream. + /// + /// # Returns + /// + /// * `Result>>` - The state result after the right stream is exhausted. + fn process_batch_after_right_end( + &mut self, + left_batch: RecordBatch, + ) -> Result>>; + + /// Handles the final state after both streams are exhausted. + /// + /// # Returns + /// + /// * `Result>>` - The final state result after processing. + fn process_batches_before_finalization( + &mut self, + ) -> Result>>; + + /// Provides mutable access to the right stream. + /// + /// # Returns + /// + /// * `&mut SendableRecordBatchStream` - Returns a mutable reference to the right stream. + fn right_stream(&mut self) -> &mut SendableRecordBatchStream; + + /// Provides mutable access to the left stream. + /// + /// # Returns + /// + /// * `&mut SendableRecordBatchStream` - Returns a mutable reference to the left stream. + fn left_stream(&mut self) -> &mut SendableRecordBatchStream; + + /// Sets the current state of the join stream. + /// + /// # Arguments + /// + /// * `state` - The new state to be set. + fn set_state(&mut self, state: EagerJoinStreamState); + + /// Fetches the current state of the join stream. + /// + /// # Returns + /// + /// * `EagerJoinStreamState` - The current state of the join stream. + fn state(&mut self) -> EagerJoinStreamState; +} + +#[derive(Debug)] +pub struct StreamJoinSideMetrics { + /// Number of batches consumed by this operator + pub(crate) input_batches: metrics::Count, + /// Number of rows consumed by this operator + pub(crate) input_rows: metrics::Count, +} + +/// Metrics for HashJoinExec +#[derive(Debug)] +pub struct StreamJoinMetrics { + /// Number of left batches/rows consumed by this operator + pub(crate) left: StreamJoinSideMetrics, + /// Number of right batches/rows consumed by this operator + pub(crate) right: StreamJoinSideMetrics, + /// Memory used by sides in bytes + pub(crate) stream_memory_usage: metrics::Gauge, + /// Number of batches produced by this operator + pub(crate) output_batches: metrics::Count, + /// Number of rows produced by this operator + pub(crate) output_rows: metrics::Count, +} + +impl StreamJoinMetrics { + pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self { + let input_batches = + MetricBuilder::new(metrics).counter("input_batches", partition); + let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); + let left = StreamJoinSideMetrics { + input_batches, + input_rows, + }; + + let input_batches = + MetricBuilder::new(metrics).counter("input_batches", partition); + let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); + let right = StreamJoinSideMetrics { + input_batches, + input_rows, + }; + + let stream_memory_usage = + MetricBuilder::new(metrics).gauge("stream_memory_usage", partition); + + let output_batches = + MetricBuilder::new(metrics).counter("output_batches", partition); + + let output_rows = MetricBuilder::new(metrics).output_rows(partition); + + Self { + left, + right, + output_batches, + stream_memory_usage, + output_rows, + } + } +} + +/// Updates sorted filter expressions with corresponding node indices from the +/// expression interval graph. +/// +/// This function iterates through the provided sorted filter expressions, +/// gathers the corresponding node indices from the expression interval graph, +/// and then updates the sorted expressions with these indices. It ensures +/// that these sorted expressions are aligned with the structure of the graph. +fn update_sorted_exprs_with_node_indices( + graph: &mut ExprIntervalGraph, + sorted_exprs: &mut [SortedFilterExpr], +) { + // Extract filter expressions from the sorted expressions: + let filter_exprs = sorted_exprs + .iter() + .map(|expr| expr.filter_expr().clone()) + .collect::>(); + + // Gather corresponding node indices for the extracted filter expressions from the graph: + let child_node_indices = graph.gather_node_indices(&filter_exprs); + + // Iterate through the sorted expressions and the gathered node indices: + for (sorted_expr, (_, index)) in sorted_exprs.iter_mut().zip(child_node_indices) { + // Update each sorted expression with the corresponding node index: + sorted_expr.set_node_index(index); + } +} + +/// Prepares and sorts expressions based on a given filter, left and right execution plans, and sort expressions. +/// +/// # Arguments +/// +/// * `filter` - The join filter to base the sorting on. +/// * `left` - The left execution plan. +/// * `right` - The right execution plan. +/// * `left_sort_exprs` - The expressions to sort on the left side. +/// * `right_sort_exprs` - The expressions to sort on the right side. +/// +/// # Returns +/// +/// * A tuple consisting of the sorted filter expression for the left and right sides, and an expression interval graph. +pub fn prepare_sorted_exprs( + filter: &JoinFilter, + left: &Arc, + right: &Arc, + left_sort_exprs: &[PhysicalSortExpr], + right_sort_exprs: &[PhysicalSortExpr], +) -> Result<(SortedFilterExpr, SortedFilterExpr, ExprIntervalGraph)> { + // Build the filter order for the left side + let err = || plan_datafusion_err!("Filter does not include the child order"); + + let left_temp_sorted_filter_expr = build_filter_input_order( + JoinSide::Left, + filter, + &left.schema(), + &left_sort_exprs[0], + )? + .ok_or_else(err)?; + + // Build the filter order for the right side + let right_temp_sorted_filter_expr = build_filter_input_order( + JoinSide::Right, + filter, + &right.schema(), + &right_sort_exprs[0], + )? + .ok_or_else(err)?; + + // Collect the sorted expressions + let mut sorted_exprs = + vec![left_temp_sorted_filter_expr, right_temp_sorted_filter_expr]; + + // Build the expression interval graph + let mut graph = + ExprIntervalGraph::try_new(filter.expression().clone(), filter.schema())?; + + // Update sorted expressions with node indices + update_sorted_exprs_with_node_indices(&mut graph, &mut sorted_exprs); + + // Swap and remove to get the final sorted filter expressions + let right_sorted_filter_expr = sorted_exprs.swap_remove(1); + let left_sorted_filter_expr = sorted_exprs.swap_remove(0); + + Ok((left_sorted_filter_expr, right_sorted_filter_expr, graph)) +} + #[cfg(test)] pub mod tests { + use std::sync::Arc; + use super::*; + use crate::joins::stream_join_utils::{ + build_filter_input_order, check_filter_expr_contains_sort_information, + convert_sort_expr_with_filter_schema, PruningJoinHashMap, + }; use crate::{ - expressions::Column, - expressions::PhysicalSortExpr, - joins::utils::{ColumnIndex, JoinFilter, JoinSide}, + expressions::{Column, PhysicalSortExpr}, + joins::test_utils::complicated_filter, + joins::utils::{ColumnIndex, JoinFilter}, }; + use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::ScalarValue; + use datafusion_common::JoinSide; use datafusion_expr::Operator; - use datafusion_physical_expr::expressions::{binary, cast, col, lit}; - use std::sync::Arc; - - /// Filter expr for a + b > c + 10 AND a + b < c + 100 - pub(crate) fn complicated_filter( - filter_schema: &Schema, - ) -> Result> { - let left_expr = binary( - cast( - binary( - col("0", filter_schema)?, - Operator::Plus, - col("1", filter_schema)?, - filter_schema, - )?, - filter_schema, - DataType::Int64, - )?, - Operator::Gt, - binary( - cast(col("2", filter_schema)?, filter_schema, DataType::Int64)?, - Operator::Plus, - lit(ScalarValue::Int64(Some(10))), - filter_schema, - )?, - filter_schema, - )?; - - let right_expr = binary( - cast( - binary( - col("0", filter_schema)?, - Operator::Plus, - col("1", filter_schema)?, - filter_schema, - )?, - filter_schema, - DataType::Int64, - )?, - Operator::Lt, - binary( - cast(col("2", filter_schema)?, filter_schema, DataType::Int64)?, - Operator::Plus, - lit(ScalarValue::Int64(Some(100))), - filter_schema, - )?, - filter_schema, - )?; - binary(left_expr, Operator::And, right_expr, filter_schema) - } + use datafusion_physical_expr::expressions::{binary, cast, col}; #[test] fn test_column_exchange() -> Result<()> { diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index e6eb5dd69582..2d38c2bd16c3 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -18,43 +18,39 @@ //! This file implements the symmetric hash join algorithm with range-based //! data pruning to join two (potentially infinite) streams. //! -//! A [SymmetricHashJoinExec] plan takes two children plan (with appropriate +//! A [`SymmetricHashJoinExec`] plan takes two children plan (with appropriate //! output ordering) and produces the join output according to the given join //! type and other options. //! -//! This plan uses the [OneSideHashJoiner] object to facilitate join calculations +//! This plan uses the [`OneSideHashJoiner`] object to facilitate join calculations //! for both its children. -use std::fmt; -use std::fmt::Debug; +use std::any::Any; +use std::fmt::{self, Debug}; use std::sync::Arc; use std::task::Poll; -use std::vec; -use std::{any::Any, usize}; +use std::{usize, vec}; use crate::common::SharedMemoryReservation; use crate::joins::hash_join::{build_equal_condition_join_indices, update_hash}; -use crate::joins::hash_join_utils::{ +use crate::joins::stream_join_utils::{ calculate_filter_expr_intervals, combine_two_batches, convert_sort_expr_with_filter_schema, get_pruning_anti_indices, - get_pruning_semi_indices, record_visited_indices, PruningJoinHashMap, + get_pruning_semi_indices, prepare_sorted_exprs, record_visited_indices, + EagerJoinStream, EagerJoinStreamState, PruningJoinHashMap, SortedFilterExpr, + StreamJoinMetrics, +}; +use crate::joins::utils::{ + build_batch_from_indices, build_join_schema, check_join_is_valid, + partitioned_join_output_partitioning, ColumnIndex, JoinFilter, JoinOn, + StatefulStreamResult, }; -use crate::joins::StreamJoinPartitionMode; -use crate::DisplayAs; use crate::{ - expressions::Column, - expressions::PhysicalSortExpr, - joins::{ - hash_join_utils::SortedFilterExpr, - utils::{ - build_batch_from_indices, build_join_schema, check_join_is_valid, - combine_join_equivalence_properties, partitioned_join_output_partitioning, - ColumnIndex, JoinFilter, JoinOn, JoinSide, - }, - }, - metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}, - DisplayFormatType, Distribution, EquivalenceProperties, ExecutionPlan, Partitioning, - RecordBatchStream, SendableRecordBatchStream, Statistics, + expressions::{Column, PhysicalSortExpr}, + joins::StreamJoinPartitionMode, + metrics::{ExecutionPlanMetricsSet, MetricsSet}, + DisplayAs, DisplayFormatType, Distribution, EquivalenceProperties, ExecutionPlan, + Partitioning, RecordBatchStream, SendableRecordBatchStream, Statistics, }; use arrow::array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray, PrimitiveBuilder}; @@ -62,16 +58,17 @@ use arrow::compute::concat_batches; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::utils::bisect; -use datafusion_common::{internal_err, plan_err, JoinType}; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{ + internal_err, plan_err, DataFusionError, JoinSide, JoinType, Result, +}; use datafusion_execution::memory_pool::MemoryConsumer; use datafusion_execution::TaskContext; -use datafusion_physical_expr::intervals::ExprIntervalGraph; +use datafusion_expr::interval_arithmetic::Interval; +use datafusion_physical_expr::equivalence::join_equivalence_properties; +use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph; -use crate::joins::utils::prepare_sorted_exprs; use ahash::RandomState; -use futures::stream::{select, BoxStream}; -use futures::{Stream, StreamExt}; +use futures::Stream; use hashbrown::HashSet; use parking_lot::Mutex; @@ -188,65 +185,6 @@ pub struct SymmetricHashJoinExec { mode: StreamJoinPartitionMode, } -#[derive(Debug)] -struct SymmetricHashJoinSideMetrics { - /// Number of batches consumed by this operator - input_batches: metrics::Count, - /// Number of rows consumed by this operator - input_rows: metrics::Count, -} - -/// Metrics for HashJoinExec -#[derive(Debug)] -struct SymmetricHashJoinMetrics { - /// Number of left batches/rows consumed by this operator - left: SymmetricHashJoinSideMetrics, - /// Number of right batches/rows consumed by this operator - right: SymmetricHashJoinSideMetrics, - /// Memory used by sides in bytes - pub(crate) stream_memory_usage: metrics::Gauge, - /// Number of batches produced by this operator - output_batches: metrics::Count, - /// Number of rows produced by this operator - output_rows: metrics::Count, -} - -impl SymmetricHashJoinMetrics { - pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self { - let input_batches = - MetricBuilder::new(metrics).counter("input_batches", partition); - let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); - let left = SymmetricHashJoinSideMetrics { - input_batches, - input_rows, - }; - - let input_batches = - MetricBuilder::new(metrics).counter("input_batches", partition); - let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); - let right = SymmetricHashJoinSideMetrics { - input_batches, - input_rows, - }; - - let stream_memory_usage = - MetricBuilder::new(metrics).gauge("stream_memory_usage", partition); - - let output_batches = - MetricBuilder::new(metrics).counter("output_batches", partition); - - let output_rows = MetricBuilder::new(metrics).output_rows(partition); - - Self { - left, - right, - output_batches, - stream_memory_usage, - output_rows, - } - } -} - impl SymmetricHashJoinExec { /// Tries to create a new [SymmetricHashJoinExec]. /// # Error @@ -328,6 +266,11 @@ impl SymmetricHashJoinExec { self.null_equals_null } + /// Get partition mode + pub fn partition_mode(&self) -> StreamJoinPartitionMode { + self.mode + } + /// Check if order information covers every column in the filter expression. pub fn check_if_order_information_available(&self) -> Result { if let Some(filter) = self.filter() { @@ -433,14 +376,15 @@ impl ExecutionPlan for SymmetricHashJoinExec { } fn equivalence_properties(&self) -> EquivalenceProperties { - let left_columns_len = self.left.schema().fields.len(); - combine_join_equivalence_properties( - self.join_type, + join_equivalence_properties( self.left.equivalence_properties(), self.right.equivalence_properties(), - left_columns_len, - self.on(), + &self.join_type, self.schema(), + &self.maintains_input_order(), + // Has alternating probe side + None, + self.on(), ) } @@ -467,9 +411,9 @@ impl ExecutionPlan for SymmetricHashJoinExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Statistics { + fn statistics(&self) -> Result { // TODO stats: it is not possible in general to know the output size of joins - Statistics::default() + Ok(Statistics::new_unknown(&self.schema())) } fn execute( @@ -513,21 +457,9 @@ impl ExecutionPlan for SymmetricHashJoinExec { let right_side_joiner = OneSideHashJoiner::new(JoinSide::Right, on_right, self.right.schema()); - let left_stream = self - .left - .execute(partition, context.clone())? - .map(|val| (JoinSide::Left, val)); - - let right_stream = self - .right - .execute(partition, context.clone())? - .map(|val| (JoinSide::Right, val)); - // This function will attempt to pull items from both streams. - // Each stream will be polled in a round-robin fashion, and whenever a stream is - // ready to yield an item that item is yielded. - // After one of the two input streams completes, the remaining one will be polled exclusively. - // The returned stream completes when both input streams have completed. - let input_stream = select(left_stream, right_stream).boxed(); + let left_stream = self.left.execute(partition, context.clone())?; + + let right_stream = self.right.execute(partition, context.clone())?; let reservation = Arc::new(Mutex::new( MemoryConsumer::new(format!("SymmetricHashJoinStream[{partition}]")) @@ -538,7 +470,8 @@ impl ExecutionPlan for SymmetricHashJoinExec { } Ok(Box::pin(SymmetricHashJoinStream { - input_stream, + left_stream, + right_stream, schema: self.schema(), filter: self.filter.clone(), join_type: self.join_type, @@ -546,12 +479,12 @@ impl ExecutionPlan for SymmetricHashJoinExec { left: left_side_joiner, right: right_side_joiner, column_indices: self.column_indices.clone(), - metrics: SymmetricHashJoinMetrics::new(partition, &self.metrics), + metrics: StreamJoinMetrics::new(partition, &self.metrics), graph, left_sorted_filter_expr, right_sorted_filter_expr, null_equals_null: self.null_equals_null, - final_result: false, + state: EagerJoinStreamState::PullRight, reservation, })) } @@ -559,8 +492,9 @@ impl ExecutionPlan for SymmetricHashJoinExec { /// A stream that issues [RecordBatch]es as they arrive from the right of the join. struct SymmetricHashJoinStream { - /// Input stream - input_stream: BoxStream<'static, (JoinSide, Result)>, + /// Input streams + left_stream: SendableRecordBatchStream, + right_stream: SendableRecordBatchStream, /// Input schema schema: Arc, /// join filter @@ -584,11 +518,11 @@ struct SymmetricHashJoinStream { /// If null_equals_null is true, null == null else null != null null_equals_null: bool, /// Metrics - metrics: SymmetricHashJoinMetrics, + metrics: StreamJoinMetrics, /// Memory reservation reservation: SharedMemoryReservation, - /// Flag indicating whether there is nothing to process anymore - final_result: bool, + /// State machine for input execution + state: EagerJoinStreamState, } impl RecordBatchStream for SymmetricHashJoinStream { @@ -623,7 +557,9 @@ impl Stream for SymmetricHashJoinStream { /// # Returns /// /// A [Result] object that contains the pruning length. The function will return -/// an error if there is an issue evaluating the build side filter expression. +/// an error if +/// - there is an issue evaluating the build side filter expression; +/// - there is an issue converting the build side filter expression into an array fn determine_prune_length( buffer: &RecordBatch, build_side_filter_expr: &SortedFilterExpr, @@ -634,13 +570,13 @@ fn determine_prune_length( let batch_arr = origin_sorted_expr .expr .evaluate(buffer)? - .into_array(buffer.num_rows()); + .into_array(buffer.num_rows())?; // Get the lower or upper interval based on the sort direction let target = if origin_sorted_expr.options.descending { - interval.upper.value.clone() + interval.upper().clone() } else { - interval.lower.value.clone() + interval.lower().clone() }; // Perform binary search on the array to determine the length of the record batch to be pruned @@ -758,7 +694,9 @@ pub(crate) fn build_side_determined_results( column_indices: &[ColumnIndex], ) -> Result> { // Check if we need to produce a result in the final output: - if need_to_produce_result_in_final(build_hash_joiner.build_side, join_type) { + if prune_length > 0 + && need_to_produce_result_in_final(build_hash_joiner.build_side, join_type) + { // Calculate the indices for build and probe sides based on join type and build side: let (build_indices, probe_indices) = calculate_indices_by_join_type( build_hash_joiner.build_side, @@ -833,6 +771,7 @@ pub(crate) fn join_with_probe_batch( filter, build_hash_joiner.build_side, Some(build_hash_joiner.deleted_offset), + false, )?; if need_to_produce_result_in_final(build_hash_joiner.build_side, join_type) { record_visited_indices( @@ -945,31 +884,22 @@ impl OneSideHashJoiner { random_state, &mut self.hashes_buffer, self.deleted_offset, + false, )?; Ok(()) } - /// Prunes the internal buffer. - /// - /// Argument `probe_batch` is used to update the intervals of the sorted - /// filter expressions. The updated build interval determines the new length - /// of the build side. If there are rows to prune, they are removed from the - /// internal buffer. + /// Calculate prune length. /// /// # Arguments /// - /// * `schema` - The schema of the final output record batch - /// * `probe_batch` - Incoming RecordBatch of the probe side. + /// * `build_side_sorted_filter_expr` - Build side mutable sorted filter expression.. /// * `probe_side_sorted_filter_expr` - Probe side mutable sorted filter expression. - /// * `join_type` - The type of join (e.g. inner, left, right, etc.). - /// * `column_indices` - A vector of column indices that specifies which columns from the - /// build side should be included in the output. /// * `graph` - A mutable reference to the physical expression graph. /// /// # Returns /// - /// If there are rows to prune, returns the pruned build side record batch wrapped in an `Ok` variant. - /// Otherwise, returns `Ok(None)`. + /// A Result object that contains the pruning length. pub(crate) fn calculate_prune_length_with_probe_batch( &mut self, build_side_sorted_filter_expr: &mut SortedFilterExpr, @@ -990,7 +920,7 @@ impl OneSideHashJoiner { filter_intervals.push((expr.node_index(), expr.interval().clone())) } // Update the physical expression graph using the join filter intervals: - graph.update_ranges(&mut filter_intervals)?; + graph.update_ranges(&mut filter_intervals, Interval::CERTAINLY_TRUE)?; // Extract the new join filter interval for the build side: let calculated_build_side_interval = filter_intervals.remove(0).1; // If the intervals have not changed, return early without pruning: @@ -1009,7 +939,7 @@ impl OneSideHashJoiner { prune_length, self.deleted_offset as u64, HASHMAP_SHRINK_SCALE_FACTOR, - )?; + ); // Remove pruned rows from the visited rows set: for row in self.deleted_offset..(self.deleted_offset + prune_length) { self.visited_rows.remove(&row); @@ -1024,10 +954,104 @@ impl OneSideHashJoiner { } } +impl EagerJoinStream for SymmetricHashJoinStream { + fn process_batch_from_right( + &mut self, + batch: RecordBatch, + ) -> Result>> { + self.perform_join_for_given_side(batch, JoinSide::Right) + .map(|maybe_batch| { + if maybe_batch.is_some() { + StatefulStreamResult::Ready(maybe_batch) + } else { + StatefulStreamResult::Continue + } + }) + } + + fn process_batch_from_left( + &mut self, + batch: RecordBatch, + ) -> Result>> { + self.perform_join_for_given_side(batch, JoinSide::Left) + .map(|maybe_batch| { + if maybe_batch.is_some() { + StatefulStreamResult::Ready(maybe_batch) + } else { + StatefulStreamResult::Continue + } + }) + } + + fn process_batch_after_left_end( + &mut self, + right_batch: RecordBatch, + ) -> Result>> { + self.process_batch_from_right(right_batch) + } + + fn process_batch_after_right_end( + &mut self, + left_batch: RecordBatch, + ) -> Result>> { + self.process_batch_from_left(left_batch) + } + + fn process_batches_before_finalization( + &mut self, + ) -> Result>> { + // Get the left side results: + let left_result = build_side_determined_results( + &self.left, + &self.schema, + self.left.input_buffer.num_rows(), + self.right.input_buffer.schema(), + self.join_type, + &self.column_indices, + )?; + // Get the right side results: + let right_result = build_side_determined_results( + &self.right, + &self.schema, + self.right.input_buffer.num_rows(), + self.left.input_buffer.schema(), + self.join_type, + &self.column_indices, + )?; + + // Combine the left and right results: + let result = combine_two_batches(&self.schema, left_result, right_result)?; + + // Update the metrics and return the result: + if let Some(batch) = &result { + // Update the metrics: + self.metrics.output_batches.add(1); + self.metrics.output_rows.add(batch.num_rows()); + return Ok(StatefulStreamResult::Ready(result)); + } + Ok(StatefulStreamResult::Continue) + } + + fn right_stream(&mut self) -> &mut SendableRecordBatchStream { + &mut self.right_stream + } + + fn left_stream(&mut self) -> &mut SendableRecordBatchStream { + &mut self.left_stream + } + + fn set_state(&mut self, state: EagerJoinStreamState) { + self.state = state; + } + + fn state(&mut self) -> EagerJoinStreamState { + self.state.clone() + } +} + impl SymmetricHashJoinStream { fn size(&self) -> usize { let mut size = 0; - size += std::mem::size_of_val(&self.input_stream); size += std::mem::size_of_val(&self.schema); size += std::mem::size_of_val(&self.filter); size += std::mem::size_of_val(&self.join_type); @@ -1040,194 +1064,138 @@ impl SymmetricHashJoinStream { size += std::mem::size_of_val(&self.random_state); size += std::mem::size_of_val(&self.null_equals_null); size += std::mem::size_of_val(&self.metrics); - size += std::mem::size_of_val(&self.final_result); size } - /// Polls the next result of the join operation. - /// - /// If the result of the join is ready, it returns the next record batch. - /// If the join has completed and there are no more results, it returns - /// `Poll::Ready(None)`. If the join operation is not complete, but the - /// current stream is not ready yet, it returns `Poll::Pending`. - fn poll_next_impl( + + /// Performs a join operation for the specified `probe_side` (either left or right). + /// This function: + /// 1. Determines which side is the probe and which is the build side. + /// 2. Updates metrics based on the batch that was polled. + /// 3. Executes the join with the given `probe_batch`. + /// 4. Optionally computes anti-join results if all conditions are met. + /// 5. Combines the results and returns a combined batch or `None` if no batch was produced. + fn perform_join_for_given_side( &mut self, - cx: &mut std::task::Context<'_>, - ) -> Poll>> { - loop { - // Poll the next batch from `input_stream`: - match self.input_stream.poll_next_unpin(cx) { - // Batch is available - Poll::Ready(Some((side, Ok(probe_batch)))) => { - // Determine which stream should be polled next. The side the - // RecordBatch comes from becomes the probe side. - let ( - probe_hash_joiner, - build_hash_joiner, - probe_side_sorted_filter_expr, - build_side_sorted_filter_expr, - probe_side_metrics, - ) = if side.eq(&JoinSide::Left) { - ( - &mut self.left, - &mut self.right, - &mut self.left_sorted_filter_expr, - &mut self.right_sorted_filter_expr, - &mut self.metrics.left, - ) - } else { - ( - &mut self.right, - &mut self.left, - &mut self.right_sorted_filter_expr, - &mut self.left_sorted_filter_expr, - &mut self.metrics.right, - ) - }; - // Update the metrics for the stream that was polled: - probe_side_metrics.input_batches.add(1); - probe_side_metrics.input_rows.add(probe_batch.num_rows()); - // Update the internal state of the hash joiner for the build side: - probe_hash_joiner - .update_internal_state(&probe_batch, &self.random_state)?; - // Join the two sides: - let equal_result = join_with_probe_batch( - build_hash_joiner, - probe_hash_joiner, - &self.schema, - self.join_type, - self.filter.as_ref(), - &probe_batch, - &self.column_indices, - &self.random_state, - self.null_equals_null, - )?; - // Increment the offset for the probe hash joiner: - probe_hash_joiner.offset += probe_batch.num_rows(); - - let anti_result = if let ( - Some(build_side_sorted_filter_expr), - Some(probe_side_sorted_filter_expr), - Some(graph), - ) = ( - build_side_sorted_filter_expr.as_mut(), - probe_side_sorted_filter_expr.as_mut(), - self.graph.as_mut(), - ) { - // Calculate filter intervals: - calculate_filter_expr_intervals( - &build_hash_joiner.input_buffer, - build_side_sorted_filter_expr, - &probe_batch, - probe_side_sorted_filter_expr, - )?; - let prune_length = build_hash_joiner - .calculate_prune_length_with_probe_batch( - build_side_sorted_filter_expr, - probe_side_sorted_filter_expr, - graph, - )?; - - if prune_length > 0 { - let res = build_side_determined_results( - build_hash_joiner, - &self.schema, - prune_length, - probe_batch.schema(), - self.join_type, - &self.column_indices, - )?; - build_hash_joiner.prune_internal_state(prune_length)?; - res - } else { - None - } - } else { - None - }; - - // Combine results: - let result = - combine_two_batches(&self.schema, equal_result, anti_result)?; - let capacity = self.size(); - self.metrics.stream_memory_usage.set(capacity); - self.reservation.lock().try_resize(capacity)?; - // Update the metrics if we have a batch; otherwise, continue the loop. - if let Some(batch) = &result { - self.metrics.output_batches.add(1); - self.metrics.output_rows.add(batch.num_rows()); - return Poll::Ready(Ok(result).transpose()); - } - } - Poll::Ready(Some((_, Err(e)))) => return Poll::Ready(Some(Err(e))), - Poll::Ready(None) => { - // If the final result has already been obtained, return `Poll::Ready(None)`: - if self.final_result { - return Poll::Ready(None); - } - self.final_result = true; - // Get the left side results: - let left_result = build_side_determined_results( - &self.left, - &self.schema, - self.left.input_buffer.num_rows(), - self.right.input_buffer.schema(), - self.join_type, - &self.column_indices, - )?; - // Get the right side results: - let right_result = build_side_determined_results( - &self.right, - &self.schema, - self.right.input_buffer.num_rows(), - self.left.input_buffer.schema(), - self.join_type, - &self.column_indices, - )?; - - // Combine the left and right results: - let result = - combine_two_batches(&self.schema, left_result, right_result)?; - - // Update the metrics and return the result: - if let Some(batch) = &result { - // Update the metrics: - self.metrics.output_batches.add(1); - self.metrics.output_rows.add(batch.num_rows()); - return Poll::Ready(Ok(result).transpose()); - } - } - Poll::Pending => return Poll::Pending, - } + probe_batch: RecordBatch, + probe_side: JoinSide, + ) -> Result> { + let ( + probe_hash_joiner, + build_hash_joiner, + probe_side_sorted_filter_expr, + build_side_sorted_filter_expr, + probe_side_metrics, + ) = if probe_side.eq(&JoinSide::Left) { + ( + &mut self.left, + &mut self.right, + &mut self.left_sorted_filter_expr, + &mut self.right_sorted_filter_expr, + &mut self.metrics.left, + ) + } else { + ( + &mut self.right, + &mut self.left, + &mut self.right_sorted_filter_expr, + &mut self.left_sorted_filter_expr, + &mut self.metrics.right, + ) + }; + // Update the metrics for the stream that was polled: + probe_side_metrics.input_batches.add(1); + probe_side_metrics.input_rows.add(probe_batch.num_rows()); + // Update the internal state of the hash joiner for the build side: + probe_hash_joiner.update_internal_state(&probe_batch, &self.random_state)?; + // Join the two sides: + let equal_result = join_with_probe_batch( + build_hash_joiner, + probe_hash_joiner, + &self.schema, + self.join_type, + self.filter.as_ref(), + &probe_batch, + &self.column_indices, + &self.random_state, + self.null_equals_null, + )?; + // Increment the offset for the probe hash joiner: + probe_hash_joiner.offset += probe_batch.num_rows(); + + let anti_result = if let ( + Some(build_side_sorted_filter_expr), + Some(probe_side_sorted_filter_expr), + Some(graph), + ) = ( + build_side_sorted_filter_expr.as_mut(), + probe_side_sorted_filter_expr.as_mut(), + self.graph.as_mut(), + ) { + // Calculate filter intervals: + calculate_filter_expr_intervals( + &build_hash_joiner.input_buffer, + build_side_sorted_filter_expr, + &probe_batch, + probe_side_sorted_filter_expr, + )?; + let prune_length = build_hash_joiner + .calculate_prune_length_with_probe_batch( + build_side_sorted_filter_expr, + probe_side_sorted_filter_expr, + graph, + )?; + let result = build_side_determined_results( + build_hash_joiner, + &self.schema, + prune_length, + probe_batch.schema(), + self.join_type, + &self.column_indices, + )?; + build_hash_joiner.prune_internal_state(prune_length)?; + result + } else { + None + }; + + // Combine results: + let result = combine_two_batches(&self.schema, equal_result, anti_result)?; + let capacity = self.size(); + self.metrics.stream_memory_usage.set(capacity); + self.reservation.lock().try_resize(capacity)?; + // Update the metrics if we have a batch; otherwise, continue the loop. + if let Some(batch) = &result { + self.metrics.output_batches.add(1); + self.metrics.output_rows.add(batch.num_rows()); } + Ok(result) } } #[cfg(test)] mod tests { + use std::collections::HashMap; + use std::sync::Mutex; + use super::*; + use crate::joins::test_utils::{ + build_sides_record_batches, compare_batches, complicated_filter, + create_memory_table, join_expr_tests_fixture_f64, join_expr_tests_fixture_i32, + join_expr_tests_fixture_temporal, partitioned_hash_join_with_filter, + partitioned_sym_join_with_filter, split_record_batches, + }; + use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit}; use datafusion_execution::config::SessionConfig; - use rstest::*; - use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{binary, col, Column}; - use crate::joins::hash_join_utils::tests::complicated_filter; - - use crate::joins::test_utils::{ - build_sides_record_batches, compare_batches, create_memory_table, - join_expr_tests_fixture_f64, join_expr_tests_fixture_i32, - join_expr_tests_fixture_temporal, partitioned_hash_join_with_filter, - partitioned_sym_join_with_filter, split_record_batches, - }; + use once_cell::sync::Lazy; + use rstest::*; const TABLE_SIZE: i32 = 30; - use once_cell::sync::Lazy; - use std::collections::HashMap; - use std::sync::Mutex; - type TableKey = (i32, i32, usize); // (cardinality.0, cardinality.1, batch_size) type TableValue = (Vec, Vec); // (left, right) @@ -1839,6 +1807,73 @@ mod tests { Ok(()) } + #[tokio::test(flavor = "multi_thread")] + async fn complex_join_all_one_ascending_equivalence() -> Result<()> { + let cardinality = (3, 4); + let join_type = JoinType::Full; + + // a + b > c + 10 AND a + b < c + 100 + let config = SessionConfig::new().with_repartition_joins(false); + // let session_ctx = SessionContext::with_config(config); + // let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default().with_session_config(config)); + let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?; + let left_schema = &left_partition[0].schema(); + let right_schema = &right_partition[0].schema(); + let left_sorted = vec![ + vec![PhysicalSortExpr { + expr: col("la1", left_schema)?, + options: SortOptions::default(), + }], + vec![PhysicalSortExpr { + expr: col("la2", left_schema)?, + options: SortOptions::default(), + }], + ]; + + let right_sorted = vec![PhysicalSortExpr { + expr: col("ra1", right_schema)?, + options: SortOptions::default(), + }]; + + let (left, right) = create_memory_table( + left_partition, + right_partition, + left_sorted, + vec![right_sorted], + )?; + + let on = vec![( + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, + )]; + + let intermediate_schema = Schema::new(vec![ + Field::new("0", DataType::Int32, true), + Field::new("1", DataType::Int32, true), + Field::new("2", DataType::Int32, true), + ]); + let filter_expr = complicated_filter(&intermediate_schema)?; + let column_indices = vec![ + ColumnIndex { + index: 0, + side: JoinSide::Left, + }, + ColumnIndex { + index: 4, + side: JoinSide::Left, + }, + ColumnIndex { + index: 0, + side: JoinSide::Right, + }, + ]; + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + + experiment(left, right, Some(filter), join_type, on, task_ctx).await?; + Ok(()) + } + #[rstest] #[tokio::test(flavor = "multi_thread")] async fn testing_with_temporal_columns( @@ -1858,7 +1893,7 @@ mod tests { (12, 17), )] cardinality: (i32, i32), - #[values(0, 1)] case_expr: usize, + #[values(0, 1, 2)] case_expr: usize, ) -> Result<()> { let session_config = SessionConfig::new().with_repartition_joins(false); let task_ctx = TaskContext::default().with_session_config(session_config); @@ -1923,6 +1958,7 @@ mod tests { experiment(left, right, Some(filter), join_type, on, task_ctx).await?; Ok(()) } + #[rstest] #[tokio::test(flavor = "multi_thread")] async fn test_with_interval_columns( diff --git a/datafusion/physical-plan/src/joins/test_utils.rs b/datafusion/physical-plan/src/joins/test_utils.rs index bb4a86199112..fbd52ddf0c70 100644 --- a/datafusion/physical-plan/src/joins/test_utils.rs +++ b/datafusion/physical-plan/src/joins/test_utils.rs @@ -17,6 +17,9 @@ //! This file has test utils for hash joins +use std::sync::Arc; +use std::usize; + use crate::joins::utils::{JoinFilter, JoinOn}; use crate::joins::{ HashJoinExec, PartitionMode, StreamJoinPartitionMode, SymmetricHashJoinExec, @@ -24,24 +27,24 @@ use crate::joins::{ use crate::memory::MemoryExec; use crate::repartition::RepartitionExec; use crate::{common, ExecutionPlan, Partitioning}; + use arrow::util::pretty::pretty_format_batches; use arrow_array::{ ArrayRef, Float64Array, Int32Array, IntervalDayTimeArray, RecordBatch, TimestampMillisecondArray, }; -use arrow_schema::Schema; -use datafusion_common::Result; -use datafusion_common::ScalarValue; +use arrow_schema::{DataType, Schema}; +use datafusion_common::{Result, ScalarValue}; use datafusion_execution::TaskContext; use datafusion_expr::{JoinType, Operator}; +use datafusion_physical_expr::expressions::{binary, cast, col, lit}; use datafusion_physical_expr::intervals::test_utils::{ gen_conjunctive_numerical_expr, gen_conjunctive_temporal_expr, }; use datafusion_physical_expr::{LexOrdering, PhysicalExpr}; + use rand::prelude::StdRng; use rand::{Rng, SeedableRng}; -use std::sync::Arc; -use std::usize; pub fn compare_batches(collected_1: &[RecordBatch], collected_2: &[RecordBatch]) { // compare @@ -240,6 +243,20 @@ pub fn join_expr_tests_fixture_temporal( ScalarValue::TimestampMillisecond(Some(1672574402000), None), // 2023-01-01:12.00.02 schema, ), + // constructs ((left_col - DURATION '3 secs') > (right_col - DURATION '2 secs')) AND ((left_col - DURATION '5 secs') < (right_col - DURATION '4 secs')) + 2 => gen_conjunctive_temporal_expr( + left_col, + right_col, + Operator::Minus, + Operator::Minus, + Operator::Minus, + Operator::Minus, + ScalarValue::DurationMillisecond(Some(3000)), // 3 secs + ScalarValue::DurationMillisecond(Some(2000)), // 2 secs + ScalarValue::DurationMillisecond(Some(5000)), // 5 secs + ScalarValue::DurationMillisecond(Some(4000)), // 4 secs + schema, + ), _ => unreachable!(), } } @@ -500,3 +517,51 @@ pub fn create_memory_table( .with_sort_information(right_sorted); Ok((Arc::new(left), Arc::new(right))) } + +/// Filter expr for a + b > c + 10 AND a + b < c + 100 +pub(crate) fn complicated_filter( + filter_schema: &Schema, +) -> Result> { + let left_expr = binary( + cast( + binary( + col("0", filter_schema)?, + Operator::Plus, + col("1", filter_schema)?, + filter_schema, + )?, + filter_schema, + DataType::Int64, + )?, + Operator::Gt, + binary( + cast(col("2", filter_schema)?, filter_schema, DataType::Int64)?, + Operator::Plus, + lit(ScalarValue::Int64(Some(10))), + filter_schema, + )?, + filter_schema, + )?; + + let right_expr = binary( + cast( + binary( + col("0", filter_schema)?, + Operator::Plus, + col("1", filter_schema)?, + filter_schema, + )?, + filter_schema, + DataType::Int64, + )?, + Operator::Lt, + binary( + cast(col("2", filter_schema)?, filter_schema, DataType::Int64)?, + Operator::Plus, + lit(ScalarValue::Int64(Some(100))), + filter_schema, + )?, + filter_schema, + )?; + binary(left_expr, Operator::And, right_expr, filter_schema) +} diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index daaa16e0552d..1e3cf5abb477 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -17,46 +17,244 @@ //! Join related functionality used both on logical and physical plans -use std::cmp::max; use std::collections::HashSet; -use std::fmt::{Display, Formatter}; +use std::fmt::{self, Debug}; use std::future::Future; +use std::ops::IndexMut; use std::sync::Arc; use std::task::{Context, Poll}; use std::usize; use crate::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder}; -use crate::SchemaRef; -use crate::{ - ColumnStatistics, EquivalenceProperties, ExecutionPlan, Partitioning, Statistics, -}; +use crate::{ColumnStatistics, ExecutionPlan, Partitioning, Statistics}; use arrow::array::{ downcast_array, new_null_array, Array, BooleanBufferBuilder, UInt32Array, - UInt32Builder, UInt64Array, + UInt32BufferBuilder, UInt32Builder, UInt64Array, UInt64BufferBuilder, }; use arrow::compute; use arrow::datatypes::{Field, Schema, SchemaBuilder}; use arrow::record_batch::{RecordBatch, RecordBatchOptions}; use datafusion_common::cast::as_boolean_array; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::stats::Precision; use datafusion_common::{ - exec_err, plan_err, DataFusionError, JoinType, Result, ScalarValue, SharedResult, + plan_err, DataFusionError, JoinSide, JoinType, Result, SharedResult, }; +use datafusion_expr::interval_arithmetic::Interval; +use datafusion_physical_expr::equivalence::add_offset_to_expr; use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::utils::merge_vectors; use datafusion_physical_expr::{ - add_offset_to_lex_ordering, EquivalentClass, LexOrdering, LexOrderingRef, - OrderingEquivalenceProperties, OrderingEquivalentClass, PhysicalExpr, - PhysicalSortExpr, + LexOrdering, LexOrderingRef, PhysicalExpr, PhysicalSortExpr, }; -use crate::joins::hash_join_utils::{build_filter_input_order, SortedFilterExpr}; -use datafusion_physical_expr::intervals::ExprIntervalGraph; -use datafusion_physical_expr::utils::merge_vectors; use futures::future::{BoxFuture, Shared}; use futures::{ready, FutureExt}; +use hashbrown::raw::RawTable; use parking_lot::Mutex; +/// Maps a `u64` hash value based on the build side ["on" values] to a list of indices with this key's value. +/// +/// By allocating a `HashMap` with capacity for *at least* the number of rows for entries at the build side, +/// we make sure that we don't have to re-hash the hashmap, which needs access to the key (the hash in this case) value. +/// +/// E.g. 1 -> [3, 6, 8] indicates that the column values map to rows 3, 6 and 8 for hash value 1 +/// As the key is a hash value, we need to check possible hash collisions in the probe stage +/// During this stage it might be the case that a row is contained the same hashmap value, +/// but the values don't match. Those are checked in the [`equal_rows_arr`](crate::joins::hash_join::equal_rows_arr) method. +/// +/// The indices (values) are stored in a separate chained list stored in the `Vec`. +/// +/// The first value (+1) is stored in the hashmap, whereas the next value is stored in array at the position value. +/// +/// The chain can be followed until the value "0" has been reached, meaning the end of the list. +/// Also see chapter 5.3 of [Balancing vectorized query execution with bandwidth-optimized storage](https://dare.uva.nl/search?identifier=5ccbb60a-38b8-4eeb-858a-e7735dd37487) +/// +/// # Example +/// +/// ``` text +/// See the example below: +/// +/// Insert (10,1) <-- insert hash value 10 with row index 1 +/// map: +/// ---------- +/// | 10 | 2 | +/// ---------- +/// next: +/// --------------------- +/// | 0 | 0 | 0 | 0 | 0 | +/// --------------------- +/// Insert (20,2) +/// map: +/// ---------- +/// | 10 | 2 | +/// | 20 | 3 | +/// ---------- +/// next: +/// --------------------- +/// | 0 | 0 | 0 | 0 | 0 | +/// --------------------- +/// Insert (10,3) <-- collision! row index 3 has a hash value of 10 as well +/// map: +/// ---------- +/// | 10 | 4 | +/// | 20 | 3 | +/// ---------- +/// next: +/// --------------------- +/// | 0 | 0 | 0 | 2 | 0 | <--- hash value 10 maps to 4,2 (which means indices values 3,1) +/// --------------------- +/// Insert (10,4) <-- another collision! row index 4 ALSO has a hash value of 10 +/// map: +/// --------- +/// | 10 | 5 | +/// | 20 | 3 | +/// --------- +/// next: +/// --------------------- +/// | 0 | 0 | 0 | 2 | 4 | <--- hash value 10 maps to 5,4,2 (which means indices values 4,3,1) +/// --------------------- +/// ``` +pub struct JoinHashMap { + // Stores hash value to last row index + map: RawTable<(u64, u64)>, + // Stores indices in chained list data structure + next: Vec, +} + +impl JoinHashMap { + #[cfg(test)] + pub(crate) fn new(map: RawTable<(u64, u64)>, next: Vec) -> Self { + Self { map, next } + } + + pub(crate) fn with_capacity(capacity: usize) -> Self { + JoinHashMap { + map: RawTable::with_capacity(capacity), + next: vec![0; capacity], + } + } +} + +// Trait defining methods that must be implemented by a hash map type to be used for joins. +pub trait JoinHashMapType { + /// The type of list used to store the next list + type NextType: IndexMut; + /// Extend with zero + fn extend_zero(&mut self, len: usize); + /// Returns mutable references to the hash map and the next. + fn get_mut(&mut self) -> (&mut RawTable<(u64, u64)>, &mut Self::NextType); + /// Returns a reference to the hash map. + fn get_map(&self) -> &RawTable<(u64, u64)>; + /// Returns a reference to the next. + fn get_list(&self) -> &Self::NextType; + + /// Updates hashmap from iterator of row indices & row hashes pairs. + fn update_from_iter<'a>( + &mut self, + iter: impl Iterator, + deleted_offset: usize, + ) { + let (mut_map, mut_list) = self.get_mut(); + for (row, hash_value) in iter { + let item = mut_map.get_mut(*hash_value, |(hash, _)| *hash_value == *hash); + if let Some((_, index)) = item { + // Already exists: add index to next array + let prev_index = *index; + // Store new value inside hashmap + *index = (row + 1) as u64; + // Update chained Vec at `row` with previous value + mut_list[row - deleted_offset] = prev_index; + } else { + mut_map.insert( + *hash_value, + // store the value + 1 as 0 value reserved for end of list + (*hash_value, (row + 1) as u64), + |(hash, _)| *hash, + ); + // chained list at `row` is already initialized with 0 + // meaning end of list + } + } + } + + /// Returns all pairs of row indices matched by hash. + /// + /// This method only compares hashes, so additional further check for actual values + /// equality may be required. + fn get_matched_indices<'a>( + &self, + iter: impl Iterator, + deleted_offset: Option, + ) -> (UInt32BufferBuilder, UInt64BufferBuilder) { + let mut input_indices = UInt32BufferBuilder::new(0); + let mut match_indices = UInt64BufferBuilder::new(0); + + let hash_map = self.get_map(); + let next_chain = self.get_list(); + for (row_idx, hash_value) in iter { + // Get the hash and find it in the index + if let Some((_, index)) = + hash_map.get(*hash_value, |(hash, _)| *hash_value == *hash) + { + let mut i = *index - 1; + loop { + let match_row_idx = if let Some(offset) = deleted_offset { + // This arguments means that we prune the next index way before here. + if i < offset as u64 { + // End of the list due to pruning + break; + } + i - offset as u64 + } else { + i + }; + match_indices.append(match_row_idx); + input_indices.append(row_idx as u32); + // Follow the chain to get the next index value + let next = next_chain[match_row_idx as usize]; + if next == 0 { + // end of list + break; + } + i = next - 1; + } + } + } + + (input_indices, match_indices) + } +} + +/// Implementation of `JoinHashMapType` for `JoinHashMap`. +impl JoinHashMapType for JoinHashMap { + type NextType = Vec; + + // Void implementation + fn extend_zero(&mut self, _: usize) {} + + /// Get mutable references to the hash map and the next. + fn get_mut(&mut self) -> (&mut RawTable<(u64, u64)>, &mut Self::NextType) { + (&mut self.map, &mut self.next) + } + + /// Get a reference to the hash map. + fn get_map(&self) -> &RawTable<(u64, u64)> { + &self.map + } + + /// Get a reference to the next. + fn get_list(&self) -> &Self::NextType { + &self.next + } +} + +impl fmt::Debug for JoinHashMap { + fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result { + Ok(()) + } +} + /// The on clause of the join, as vector of (left, right) columns. pub type JoinOn = Vec<(Column, Column)>; /// Reference for JoinOn. @@ -96,8 +294,8 @@ fn check_join_set_is_valid( if !left_missing.is_empty() | !right_missing.is_empty() { return plan_err!( - "The left or right side of the join does not have all columns on \"on\": \nMissing on the left: {left_missing:?}\nMissing on the right: {right_missing:?}" - ); + "The left or right side of the join does not have all columns on \"on\": \nMissing on the left: {left_missing:?}\nMissing on the right: {right_missing:?}" + ); }; Ok(()) @@ -137,17 +335,8 @@ pub fn adjust_right_output_partitioning( Partitioning::Hash(exprs, size) => { let new_exprs = exprs .into_iter() - .map(|expr| { - expr.transform_down(&|e| match e.as_any().downcast_ref::() { - Some(col) => Ok(Transformed::Yes(Arc::new(Column::new( - col.name(), - left_columns_len + col.index(), - )))), - None => Ok(Transformed::No(e)), - }) - .unwrap() - }) - .collect::>(); + .map(|expr| add_offset_to_expr(expr, left_columns_len)) + .collect(); Partitioning::Hash(new_exprs, size) } } @@ -182,24 +371,23 @@ pub fn calculate_join_output_ordering( left_columns_len: usize, maintains_input_order: &[bool], probe_side: Option, -) -> Result> { - // All joins have 2 children: - assert_eq!(maintains_input_order.len(), 2); - let left_maintains = maintains_input_order[0]; - let right_maintains = maintains_input_order[1]; +) -> Option { let mut right_ordering = match join_type { // In the case below, right ordering should be offseted with the left // side length, since we append the right table to the left table. JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { - add_offset_to_lex_ordering(right_ordering, left_columns_len)? + right_ordering + .iter() + .map(|sort_expr| PhysicalSortExpr { + expr: add_offset_to_expr(sort_expr.expr.clone(), left_columns_len), + options: sort_expr.options, + }) + .collect() } _ => right_ordering.to_vec(), }; - let output_ordering = match (left_maintains, right_maintains) { - (true, true) => { - return exec_err!("Cannot maintain ordering of both sides"); - } - (true, false) => { + let output_ordering = match maintains_input_order { + [true, false] => { // Special case, we can prefix ordering of right side with the ordering of left side. if join_type == JoinType::Inner && probe_side == Some(JoinSide::Left) { replace_on_columns_of_right_ordering( @@ -212,7 +400,7 @@ pub fn calculate_join_output_ordering( left_ordering.to_vec() } } - (false, true) => { + [false, true] => { // Special case, we can prefix ordering of left side with the ordering of right side. if join_type == JoinType::Inner && probe_side == Some(JoinSide::Right) { replace_on_columns_of_right_ordering( @@ -226,269 +414,15 @@ pub fn calculate_join_output_ordering( } } // Doesn't maintain ordering, output ordering is None. - (false, false) => return Ok(None), - }; - Ok((!output_ordering.is_empty()).then_some(output_ordering)) -} - -/// Combine equivalence properties of the given join inputs. -pub fn combine_join_equivalence_properties( - join_type: JoinType, - left_properties: EquivalenceProperties, - right_properties: EquivalenceProperties, - left_columns_len: usize, - on: &[(Column, Column)], - schema: SchemaRef, -) -> EquivalenceProperties { - let mut new_properties = EquivalenceProperties::new(schema); - match join_type { - JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { - new_properties.extend(left_properties.classes().to_vec()); - let new_right_properties = right_properties - .classes() - .iter() - .map(|prop| { - let new_head = Column::new( - prop.head().name(), - left_columns_len + prop.head().index(), - ); - let new_others = prop - .others() - .iter() - .map(|col| { - Column::new(col.name(), left_columns_len + col.index()) - }) - .collect::>(); - EquivalentClass::new(new_head, new_others) - }) - .collect::>(); - - new_properties.extend(new_right_properties); - } - JoinType::LeftSemi | JoinType::LeftAnti => { - new_properties.extend(left_properties.classes().to_vec()) - } - JoinType::RightSemi | JoinType::RightAnti => { - new_properties.extend(right_properties.classes().to_vec()) - } - } - - if join_type == JoinType::Inner { - on.iter().for_each(|(column1, column2)| { - let new_column2 = - Column::new(column2.name(), left_columns_len + column2.index()); - new_properties.add_equal_conditions((column1, &new_column2)) - }) - } - new_properties -} - -/// Calculate equivalence properties for the given cross join operation. -pub fn cross_join_equivalence_properties( - left_properties: EquivalenceProperties, - right_properties: EquivalenceProperties, - left_columns_len: usize, - schema: SchemaRef, -) -> EquivalenceProperties { - let mut new_properties = EquivalenceProperties::new(schema); - new_properties.extend(left_properties.classes().to_vec()); - let new_right_properties = right_properties - .classes() - .iter() - .map(|prop| { - let new_head = - Column::new(prop.head().name(), left_columns_len + prop.head().index()); - let new_others = prop - .others() - .iter() - .map(|col| Column::new(col.name(), left_columns_len + col.index())) - .collect::>(); - EquivalentClass::new(new_head, new_others) - }) - .collect::>(); - new_properties.extend(new_right_properties); - new_properties -} - -/// Update right table ordering equivalences so that: -/// - They point to valid indices at the output of the join schema, and -/// - They are normalized with respect to equivalence columns. -/// -/// To do so, we increment column indices by the size of the left table when -/// join schema consists of a combination of left and right schema (Inner, -/// Left, Full, Right joins). Then, we normalize the sort expressions of -/// ordering equivalences one by one. We make sure that each expression in the -/// ordering equivalence is either: -/// - The head of the one of the equivalent classes, or -/// - Doesn't have an equivalent column. -/// -/// This way; once we normalize an expression according to equivalence properties, -/// it can thereafter safely be used for ordering equivalence normalization. -fn get_updated_right_ordering_equivalent_class( - join_type: &JoinType, - right_oeq_class: &OrderingEquivalentClass, - left_columns_len: usize, - join_eq_properties: &EquivalenceProperties, -) -> Result { - match join_type { - // In these modes, indices of the right schema should be offset by - // the left table size. - JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { - let right_oeq_class = right_oeq_class.add_offset(left_columns_len)?; - return Ok( - right_oeq_class.normalize_with_equivalence_properties(join_eq_properties) - ); - } - _ => {} + [false, false] => return None, + [true, true] => unreachable!("Cannot maintain ordering of both sides"), + _ => unreachable!("Join operators can not have more than two children"), }; - Ok(right_oeq_class.normalize_with_equivalence_properties(join_eq_properties)) -} - -/// Calculate ordering equivalence properties for the given join operation. -pub fn combine_join_ordering_equivalence_properties( - join_type: &JoinType, - left: &Arc, - right: &Arc, - schema: SchemaRef, - maintains_input_order: &[bool], - probe_side: Option, - join_eq_properties: EquivalenceProperties, -) -> Result { - let mut new_properties = OrderingEquivalenceProperties::new(schema); - let left_columns_len = left.schema().fields.len(); - let left_oeq_properties = left.ordering_equivalence_properties(); - let right_oeq_properties = right.ordering_equivalence_properties(); - // All joins have 2 children - assert_eq!(maintains_input_order.len(), 2); - let left_maintains = maintains_input_order[0]; - let right_maintains = maintains_input_order[1]; - match (left_maintains, right_maintains) { - (true, true) => { - return Err(DataFusionError::Plan( - "Cannot maintain ordering of both sides".to_string(), - )) - } - (true, false) => { - new_properties.extend(left_oeq_properties.oeq_class().cloned()); - // In this special case, right side ordering can be prefixed with left side ordering. - if let ( - Some(JoinSide::Left), - // right side have an ordering - Some(_), - JoinType::Inner, - Some(oeq_class), - ) = ( - probe_side, - right.output_ordering(), - join_type, - right_oeq_properties.oeq_class(), - ) { - let left_output_ordering = left.output_ordering().unwrap_or(&[]); - - let updated_right_oeq = get_updated_right_ordering_equivalent_class( - join_type, - oeq_class, - left_columns_len, - &join_eq_properties, - )?; - - // Right side ordering equivalence properties should be prepended with - // those of the left side while constructing output ordering equivalence - // properties since stream side is the left side. - // - // If the right table ordering equivalences contain `b ASC`, and the output - // ordering of the left table is `a ASC`, then the ordering equivalence `b ASC` - // for the right table should be converted to `a ASC, b ASC` before it is added - // to the ordering equivalences of the join. - let updated_right_oeq_class = updated_right_oeq - .prefix_ordering_equivalent_class_with_existing_ordering( - left_output_ordering, - &join_eq_properties, - ); - new_properties.extend(Some(updated_right_oeq_class)); - } - } - (false, true) => { - let updated_right_oeq = right_oeq_properties - .oeq_class() - .map(|right_oeq_class| { - get_updated_right_ordering_equivalent_class( - join_type, - right_oeq_class, - left_columns_len, - &join_eq_properties, - ) - }) - .transpose()?; - new_properties.extend(updated_right_oeq); - // In this special case, left side ordering can be prefixed with right side ordering. - if let ( - Some(JoinSide::Right), - // left side have an ordering - Some(_), - JoinType::Inner, - Some(left_oeq_class), - ) = ( - probe_side, - left.output_ordering(), - join_type, - left_oeq_properties.oeq_class(), - ) { - let right_output_ordering = right.output_ordering().unwrap_or(&[]); - let right_output_ordering = - add_offset_to_lex_ordering(right_output_ordering, left_columns_len)?; - - // Left side ordering equivalence properties should be prepended with - // those of the right side while constructing output ordering equivalence - // properties since stream side is the right side. - // - // If the right table ordering equivalences contain `b ASC`, and the output - // ordering of the left table is `a ASC`, then the ordering equivalence `b ASC` - // for the right table should be converted to `a ASC, b ASC` before it is added - // to the ordering equivalences of the join. - let updated_left_oeq_class = left_oeq_class - .prefix_ordering_equivalent_class_with_existing_ordering( - &right_output_ordering, - &join_eq_properties, - ); - new_properties.extend(Some(updated_left_oeq_class)); - } - } - (false, false) => {} - } - Ok(new_properties) -} - -impl Display for JoinSide { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - JoinSide::Left => write!(f, "left"), - JoinSide::Right => write!(f, "right"), - } - } -} - -/// Used in ColumnIndex to distinguish which side the index is for -#[derive(Debug, Clone, Copy, PartialEq)] -pub enum JoinSide { - /// Left side of the join - Left, - /// Right side of the join - Right, -} - -impl JoinSide { - /// Inverse the join side - pub fn negate(&self) -> Self { - match self { - JoinSide::Left => JoinSide::Right, - JoinSide::Right => JoinSide::Left, - } - } + (!output_ordering.is_empty()).then_some(output_ordering) } /// Information about the index and placement (left or right) of the columns -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct ColumnIndex { /// Index of the column pub index: usize, @@ -726,21 +660,21 @@ pub(crate) fn estimate_join_statistics( right: Arc, on: JoinOn, join_type: &JoinType, -) -> Statistics { - let left_stats = left.statistics(); - let right_stats = right.statistics(); + schema: &Schema, +) -> Result { + let left_stats = left.statistics()?; + let right_stats = right.statistics()?; let join_stats = estimate_join_cardinality(join_type, left_stats, right_stats, &on); let (num_rows, column_statistics) = match join_stats { - Some(stats) => (Some(stats.num_rows), Some(stats.column_statistics)), - None => (None, None), + Some(stats) => (Precision::Inexact(stats.num_rows), stats.column_statistics), + None => (Precision::Absent, Statistics::unknown_column(schema)), }; - Statistics { + Ok(Statistics { num_rows, - total_byte_size: None, + total_byte_size: Precision::Absent, column_statistics, - is_exact: false, - } + }) } // Estimate the cardinality for the given join with input statistics. @@ -752,29 +686,27 @@ fn estimate_join_cardinality( ) -> Option { match join_type { JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { - let left_num_rows = left_stats.num_rows?; - let right_num_rows = right_stats.num_rows?; - - // Take the left_col_stats and right_col_stats using the index - // obtained from index() method of the each element of 'on'. - let all_left_col_stats = left_stats.column_statistics?; - let all_right_col_stats = right_stats.column_statistics?; let (left_col_stats, right_col_stats) = on .iter() .map(|(left, right)| { ( - all_left_col_stats[left.index()].clone(), - all_right_col_stats[right.index()].clone(), + left_stats.column_statistics[left.index()].clone(), + right_stats.column_statistics[right.index()].clone(), ) }) .unzip::<_, _, Vec<_>, Vec<_>>(); let ij_cardinality = estimate_inner_join_cardinality( - left_num_rows, - right_num_rows, - left_col_stats, - right_col_stats, - left_stats.is_exact && right_stats.is_exact, + Statistics { + num_rows: left_stats.num_rows.clone(), + total_byte_size: Precision::Absent, + column_statistics: left_col_stats, + }, + Statistics { + num_rows: right_stats.num_rows.clone(), + total_byte_size: Precision::Absent, + column_statistics: right_col_stats, + }, )?; // The cardinality for inner join can also be used to estimate @@ -783,25 +715,25 @@ fn estimate_join_cardinality( // joins (so that we don't underestimate the cardinality). let cardinality = match join_type { JoinType::Inner => ij_cardinality, - JoinType::Left => max(ij_cardinality, left_num_rows), - JoinType::Right => max(ij_cardinality, right_num_rows), - JoinType::Full => { - max(ij_cardinality, left_num_rows) - + max(ij_cardinality, right_num_rows) - - ij_cardinality - } + JoinType::Left => ij_cardinality.max(&left_stats.num_rows), + JoinType::Right => ij_cardinality.max(&right_stats.num_rows), + JoinType::Full => ij_cardinality + .max(&left_stats.num_rows) + .add(&ij_cardinality.max(&right_stats.num_rows)) + .sub(&ij_cardinality), _ => unreachable!(), }; Some(PartialJoinStatistics { - num_rows: cardinality, + num_rows: *cardinality.get_value()?, // We don't do anything specific here, just combine the existing // statistics which might yield subpar results (although it is // true, esp regarding min/max). For a better estimation, we need // filter selectivity analysis first. - column_statistics: all_left_col_stats + column_statistics: left_stats + .column_statistics .into_iter() - .chain(all_right_col_stats) + .chain(right_stats.column_statistics) .collect(), }) } @@ -818,30 +750,47 @@ fn estimate_join_cardinality( /// a very conservative implementation that can quickly give up if there is not /// enough input statistics. fn estimate_inner_join_cardinality( - left_num_rows: usize, - right_num_rows: usize, - left_col_stats: Vec, - right_col_stats: Vec, - is_exact: bool, -) -> Option { + left_stats: Statistics, + right_stats: Statistics, +) -> Option> { // The algorithm here is partly based on the non-histogram selectivity estimation // from Spark's Catalyst optimizer. - - let mut join_selectivity = None; - for (left_stat, right_stat) in left_col_stats.iter().zip(right_col_stats.iter()) { - if (left_stat.min_value.clone()? > right_stat.max_value.clone()?) - || (left_stat.max_value.clone()? < right_stat.min_value.clone()?) - { - // If there is no overlap in any of the join columns, that means the join - // itself is disjoint and the cardinality is 0. Though we can only assume - // this when the statistics are exact (since it is a very strong assumption). - return if is_exact { Some(0) } else { None }; + let mut join_selectivity = Precision::Absent; + for (left_stat, right_stat) in left_stats + .column_statistics + .iter() + .zip(right_stats.column_statistics.iter()) + { + // If there is no overlap in any of the join columns, this means the join + // itself is disjoint and the cardinality is 0. Though we can only assume + // this when the statistics are exact (since it is a very strong assumption). + if left_stat.min_value.get_value()? > right_stat.max_value.get_value()? { + return Some( + if left_stat.min_value.is_exact().unwrap_or(false) + && right_stat.max_value.is_exact().unwrap_or(false) + { + Precision::Exact(0) + } else { + Precision::Inexact(0) + }, + ); + } + if left_stat.max_value.get_value()? < right_stat.min_value.get_value()? { + return Some( + if left_stat.max_value.is_exact().unwrap_or(false) + && right_stat.min_value.is_exact().unwrap_or(false) + { + Precision::Exact(0) + } else { + Precision::Inexact(0) + }, + ); } - let left_max_distinct = max_distinct_count(left_num_rows, left_stat.clone()); - let right_max_distinct = max_distinct_count(right_num_rows, right_stat.clone()); - let max_distinct = max(left_max_distinct, right_max_distinct); - if max_distinct > join_selectivity { + let left_max_distinct = max_distinct_count(&left_stats.num_rows, left_stat); + let right_max_distinct = max_distinct_count(&right_stats.num_rows, right_stat); + let max_distinct = left_max_distinct.max(&right_max_distinct); + if max_distinct.get_value().is_some() { // Seems like there are a few implementations of this algorithm that implement // exponential decay for the selectivity (like Hive's Optiq Optimizer). Needs // further exploration. @@ -852,9 +801,14 @@ fn estimate_inner_join_cardinality( // With the assumption that the smaller input's domain is generally represented in the bigger // input's domain, we can estimate the inner join's cardinality by taking the cartesian product // of the two inputs and normalizing it by the selectivity factor. + let left_num_rows = left_stats.num_rows.get_value()?; + let right_num_rows = right_stats.num_rows.get_value()?; match join_selectivity { - Some(selectivity) if selectivity > 0 => { - Some((left_num_rows * right_num_rows) / selectivity) + Precision::Exact(value) if value > 0 => { + Some(Precision::Exact((left_num_rows * right_num_rows) / value)) + } + Precision::Inexact(value) if value > 0 => { + Some(Precision::Inexact((left_num_rows * right_num_rows) / value)) } // Since we don't have any information about the selectivity (which is derived // from the number of distinct rows information) we can give up here for now. @@ -865,47 +819,61 @@ fn estimate_inner_join_cardinality( } /// Estimate the number of maximum distinct values that can be present in the -/// given column from its statistics. -/// -/// If distinct_count is available, uses it directly. If the column numeric, and -/// has min/max values, then they might be used as a fallback option. Otherwise, -/// returns None. -fn max_distinct_count(num_rows: usize, stats: ColumnStatistics) -> Option { - match (stats.distinct_count, stats.max_value, stats.min_value) { - (Some(_), _, _) => stats.distinct_count, - (_, Some(max), Some(min)) => { - // Note that float support is intentionally omitted here, since the computation - // of a range between two float values is not trivial and the result would be - // highly inaccurate. - let numeric_range = get_int_range(min, max)?; - - // The number can never be greater than the number of rows we have (minus - // the nulls, since they don't count as distinct values). - let ceiling = num_rows - stats.null_count.unwrap_or(0); - Some(numeric_range.min(ceiling)) - } - _ => None, - } -} +/// given column from its statistics. If distinct_count is available, uses it +/// directly. Otherwise, if the column is numeric and has min/max values, it +/// estimates the maximum distinct count from those. +fn max_distinct_count( + num_rows: &Precision, + stats: &ColumnStatistics, +) -> Precision { + match &stats.distinct_count { + dc @ (Precision::Exact(_) | Precision::Inexact(_)) => dc.clone(), + _ => { + // The number can never be greater than the number of rows we have + // minus the nulls (since they don't count as distinct values). + let result = match num_rows { + Precision::Absent => Precision::Absent, + Precision::Inexact(count) => { + Precision::Inexact(count - stats.null_count.get_value().unwrap_or(&0)) + } + Precision::Exact(count) => { + let count = count - stats.null_count.get_value().unwrap_or(&0); + if stats.null_count.is_exact().unwrap_or(false) { + Precision::Exact(count) + } else { + Precision::Inexact(count) + } + } + }; + // Cap the estimate using the number of possible values: + if let (Some(min), Some(max)) = + (stats.min_value.get_value(), stats.max_value.get_value()) + { + if let Some(range_dc) = Interval::try_new(min.clone(), max.clone()) + .ok() + .and_then(|e| e.cardinality()) + { + let range_dc = range_dc as usize; + // Note that the `unwrap` calls in the below statement are safe. + return if matches!(result, Precision::Absent) + || &range_dc < result.get_value().unwrap() + { + if stats.min_value.is_exact().unwrap() + && stats.max_value.is_exact().unwrap() + { + Precision::Exact(range_dc) + } else { + Precision::Inexact(range_dc) + } + } else { + result + }; + } + } -/// Return the numeric range between the given min and max values. -fn get_int_range(min: ScalarValue, max: ScalarValue) -> Option { - let delta = &max.sub(&min).ok()?; - match delta { - ScalarValue::Int8(Some(delta)) if *delta >= 0 => Some(*delta as usize), - ScalarValue::Int16(Some(delta)) if *delta >= 0 => Some(*delta as usize), - ScalarValue::Int32(Some(delta)) if *delta >= 0 => Some(*delta as usize), - ScalarValue::Int64(Some(delta)) if *delta >= 0 => Some(*delta as usize), - ScalarValue::UInt8(Some(delta)) => Some(*delta as usize), - ScalarValue::UInt16(Some(delta)) => Some(*delta as usize), - ScalarValue::UInt32(Some(delta)) => Some(*delta as usize), - ScalarValue::UInt64(Some(delta)) => Some(*delta as usize), - _ => None, + result + } } - // The delta (directly) is not the real range, since it does not include the - // first term. - // E.g. (min=2, max=4) -> (4 - 2) -> 2, but the actual result should be 3 (1, 2, 3). - .map(|open_ended_range| open_ended_range + 1) } enum OnceFutState { @@ -954,6 +922,22 @@ impl OnceFut { ), } } + + /// Get shared reference to the result of the computation if it is ready, without consuming it + pub(crate) fn get_shared(&mut self, cx: &mut Context<'_>) -> Poll>> { + if let OnceFutState::Pending(fut) = &mut self.state { + let r = ready!(fut.poll_unpin(cx)); + self.state = OnceFutState::Ready(r); + } + + match &self.state { + OnceFutState::Pending(_) => unreachable!(), + OnceFutState::Ready(r) => Poll::Ready( + r.clone() + .map_err(|e| DataFusionError::External(Box::new(e))), + ), + } + } } /// Some type `join_type` of join need to maintain the matched indices bit map for the left side, and @@ -1025,7 +1009,7 @@ pub(crate) fn apply_join_filter_to_indices( let filter_result = filter .expression() .evaluate(&intermediate_batch)? - .into_array(intermediate_batch.num_rows()); + .into_array(intermediate_batch.num_rows())?; let mask = as_boolean_array(&filter_result)?; let left_filtered = compute::filter(&build_indices, mask)?; @@ -1297,100 +1281,84 @@ impl BuildProbeJoinMetrics { } } -/// Updates sorted filter expressions with corresponding node indices from the -/// expression interval graph. +/// The `handle_state` macro is designed to process the result of a state-changing +/// operation, encountered e.g. in implementations of `EagerJoinStream`. It +/// operates on a `StatefulStreamResult` by matching its variants and executing +/// corresponding actions. This macro is used to streamline code that deals with +/// state transitions, reducing boilerplate and improving readability. /// -/// This function iterates through the provided sorted filter expressions, -/// gathers the corresponding node indices from the expression interval graph, -/// and then updates the sorted expressions with these indices. It ensures -/// that these sorted expressions are aligned with the structure of the graph. -fn update_sorted_exprs_with_node_indices( - graph: &mut ExprIntervalGraph, - sorted_exprs: &mut [SortedFilterExpr], -) { - // Extract filter expressions from the sorted expressions: - let filter_exprs = sorted_exprs - .iter() - .map(|expr| expr.filter_expr().clone()) - .collect::>(); - - // Gather corresponding node indices for the extracted filter expressions from the graph: - let child_node_indices = graph.gather_node_indices(&filter_exprs); - - // Iterate through the sorted expressions and the gathered node indices: - for (sorted_expr, (_, index)) in sorted_exprs.iter_mut().zip(child_node_indices) { - // Update each sorted expression with the corresponding node index: - sorted_expr.set_node_index(index); - } +/// # Cases +/// +/// - `Ok(StatefulStreamResult::Continue)`: Continues the loop, indicating the +/// stream join operation should proceed to the next step. +/// - `Ok(StatefulStreamResult::Ready(result))`: Returns a `Poll::Ready` with the +/// result, either yielding a value or indicating the stream is awaiting more +/// data. +/// - `Err(e)`: Returns a `Poll::Ready` containing an error, signaling an issue +/// during the stream join operation. +/// +/// # Arguments +/// +/// * `$match_case`: An expression that evaluates to a `Result>`. +#[macro_export] +macro_rules! handle_state { + ($match_case:expr) => { + match $match_case { + Ok(StatefulStreamResult::Continue) => continue, + Ok(StatefulStreamResult::Ready(result)) => { + Poll::Ready(Ok(result).transpose()) + } + Err(e) => Poll::Ready(Some(Err(e))), + } + }; } -/// Prepares and sorts expressions based on a given filter, left and right execution plans, and sort expressions. +/// The `handle_async_state` macro adapts the `handle_state` macro for use in +/// asynchronous operations, particularly when dealing with `Poll` results within +/// async traits like `EagerJoinStream`. It polls the asynchronous state-changing +/// function using `poll_unpin` and then passes the result to `handle_state` for +/// further processing. /// /// # Arguments /// -/// * `filter` - The join filter to base the sorting on. -/// * `left` - The left execution plan. -/// * `right` - The right execution plan. -/// * `left_sort_exprs` - The expressions to sort on the left side. -/// * `right_sort_exprs` - The expressions to sort on the right side. +/// * `$state_func`: An async function or future that returns a +/// `Result>`. +/// * `$cx`: The context to be passed for polling, usually of type `&mut Context`. +/// +#[macro_export] +macro_rules! handle_async_state { + ($state_func:expr, $cx:expr) => { + $crate::handle_state!(ready!($state_func.poll_unpin($cx))) + }; +} + +/// Represents the result of an operation on stateful join stream. /// -/// # Returns +/// This enumueration indicates whether the state produced a result that is +/// ready for use (`Ready`) or if the operation requires continuation (`Continue`). /// -/// * A tuple consisting of the sorted filter expression for the left and right sides, and an expression interval graph. -pub fn prepare_sorted_exprs( - filter: &JoinFilter, - left: &Arc, - right: &Arc, - left_sort_exprs: &[PhysicalSortExpr], - right_sort_exprs: &[PhysicalSortExpr], -) -> Result<(SortedFilterExpr, SortedFilterExpr, ExprIntervalGraph)> { - // Build the filter order for the left side - let err = - || DataFusionError::Plan("Filter does not include the child order".to_owned()); - - let left_temp_sorted_filter_expr = build_filter_input_order( - JoinSide::Left, - filter, - &left.schema(), - &left_sort_exprs[0], - )? - .ok_or_else(err)?; - - // Build the filter order for the right side - let right_temp_sorted_filter_expr = build_filter_input_order( - JoinSide::Right, - filter, - &right.schema(), - &right_sort_exprs[0], - )? - .ok_or_else(err)?; - - // Collect the sorted expressions - let mut sorted_exprs = - vec![left_temp_sorted_filter_expr, right_temp_sorted_filter_expr]; - - // Build the expression interval graph - let mut graph = ExprIntervalGraph::try_new(filter.expression().clone())?; - - // Update sorted expressions with node indices - update_sorted_exprs_with_node_indices(&mut graph, &mut sorted_exprs); - - // Swap and remove to get the final sorted filter expressions - let right_sorted_filter_expr = sorted_exprs.swap_remove(1); - let left_sorted_filter_expr = sorted_exprs.swap_remove(0); - - Ok((left_sorted_filter_expr, right_sorted_filter_expr, graph)) +/// Variants: +/// - `Ready(T)`: Indicates that the operation is complete with a result of type `T`. +/// - `Continue`: Indicates that the operation is not yet complete and requires further +/// processing or more data. When this variant is returned, it typically means that the +/// current invocation of the state did not produce a final result, and the operation +/// should be invoked again later with more data and possibly with a different state. +pub enum StatefulStreamResult { + Ready(T), + Continue, } #[cfg(test)] mod tests { + use std::pin::Pin; + use super::*; - use arrow::datatypes::Fields; - use arrow::error::Result as ArrowResult; - use arrow::{datatypes::DataType, error::ArrowError}; + + use arrow::datatypes::{DataType, Fields}; + use arrow::error::{ArrowError, Result as ArrowResult}; use arrow_schema::SortOptions; - use datafusion_common::ScalarValue; - use std::pin::Pin; + + use datafusion_common::{arrow_datafusion_err, arrow_err, ScalarValue}; fn check(left: &[Column], right: &[Column], on: &[(Column, Column)]) -> Result<()> { let left = left @@ -1426,9 +1394,7 @@ mod tests { #[tokio::test] async fn check_error_nesting() { let once_fut = OnceFut::<()>::new(async { - Err(DataFusionError::ArrowError(ArrowError::CsvError( - "some error".to_string(), - ))) + arrow_err!(ArrowError::CsvError("some error".to_string())) }); struct TestFut(OnceFut<()>); @@ -1452,10 +1418,10 @@ mod tests { let wrapped_err = DataFusionError::from(arrow_err_from_fut); let root_err = wrapped_err.find_root(); - assert!(matches!( - root_err, - DataFusionError::ArrowError(ArrowError::CsvError(_)) - )) + let _expected = + arrow_datafusion_err!(ArrowError::CsvError("some error".to_owned())); + + assert!(matches!(root_err, _expected)) } #[test] @@ -1543,14 +1509,18 @@ mod tests { fn create_stats( num_rows: Option, - column_stats: Option>, + column_stats: Vec, is_exact: bool, ) -> Statistics { Statistics { - num_rows, + num_rows: if is_exact { + num_rows.map(Precision::Exact) + } else { + num_rows.map(Precision::Inexact) + } + .unwrap_or(Precision::Absent), column_statistics: column_stats, - is_exact, - ..Default::default() + total_byte_size: Precision::Absent, } } @@ -1560,9 +1530,15 @@ mod tests { distinct_count: Option, ) -> ColumnStatistics { ColumnStatistics { - distinct_count, - min_value: min.map(|size| ScalarValue::Int64(Some(size))), - max_value: max.map(|size| ScalarValue::Int64(Some(size))), + distinct_count: distinct_count + .map(Precision::Inexact) + .unwrap_or(Precision::Absent), + min_value: min + .map(|size| Precision::Inexact(ScalarValue::from(size))) + .unwrap_or(Precision::Absent), + max_value: max + .map(|size| Precision::Inexact(ScalarValue::from(size))) + .unwrap_or(Precision::Absent), ..Default::default() } } @@ -1574,7 +1550,7 @@ mod tests { // over the expected output (since it depends on join type to join type). #[test] fn test_inner_join_cardinality_single_column() -> Result<()> { - let cases: Vec<(PartialStats, PartialStats, Option)> = vec![ + let cases: Vec<(PartialStats, PartialStats, Option>)> = vec![ // ----------------------------------------------------------------------------- // | left(rows, min, max, distinct), right(rows, min, max, distinct), expected | // ----------------------------------------------------------------------------- @@ -1586,70 +1562,70 @@ mod tests { ( (10, Some(1), Some(10), None), (10, Some(1), Some(10), None), - Some(10), + Some(Precision::Inexact(10)), ), // range(left) > range(right) ( (10, Some(6), Some(10), None), (10, Some(8), Some(10), None), - Some(20), + Some(Precision::Inexact(20)), ), // range(right) > range(left) ( (10, Some(8), Some(10), None), (10, Some(6), Some(10), None), - Some(20), + Some(Precision::Inexact(20)), ), // range(left) > len(left), range(right) > len(right) ( (10, Some(1), Some(15), None), (20, Some(1), Some(40), None), - Some(10), + Some(Precision::Inexact(10)), ), // When we have distinct count. ( (10, Some(1), Some(10), Some(10)), (10, Some(1), Some(10), Some(10)), - Some(10), + Some(Precision::Inexact(10)), ), // distinct(left) > distinct(right) ( (10, Some(1), Some(10), Some(5)), (10, Some(1), Some(10), Some(2)), - Some(20), + Some(Precision::Inexact(20)), ), // distinct(right) > distinct(left) ( (10, Some(1), Some(10), Some(2)), (10, Some(1), Some(10), Some(5)), - Some(20), + Some(Precision::Inexact(20)), ), // min(left) < 0 (range(left) > range(right)) ( (10, Some(-5), Some(5), None), (10, Some(1), Some(5), None), - Some(10), + Some(Precision::Inexact(10)), ), // min(right) < 0, max(right) < 0 (range(right) > range(left)) ( (10, Some(-25), Some(-20), None), (10, Some(-25), Some(-15), None), - Some(10), + Some(Precision::Inexact(10)), ), // range(left) < 0, range(right) >= 0 // (there isn't a case where both left and right ranges are negative // so one of them is always going to work, this just proves negative // ranges with bigger absolute values are not are not accidentally used). ( - (10, Some(10), Some(0), None), + (10, Some(-10), Some(0), None), (10, Some(0), Some(10), Some(5)), - Some(20), // It would have been ten if we have used abs(range(left)) + Some(Precision::Inexact(10)), ), // range(left) = 1, range(right) = 1 ( (10, Some(1), Some(1), None), (10, Some(1), Some(1), None), - Some(100), + Some(Precision::Inexact(100)), ), // // Edge cases @@ -1674,22 +1650,12 @@ mod tests { ( (10, Some(0), Some(10), None), (10, Some(11), Some(20), None), - None, + Some(Precision::Inexact(0)), ), ( (10, Some(11), Some(20), None), (10, Some(0), Some(10), None), - None, - ), - ( - (10, Some(5), Some(10), Some(10)), - (10, Some(11), Some(3), Some(10)), - None, - ), - ( - (10, Some(10), Some(5), Some(10)), - (10, Some(3), Some(7), Some(10)), - None, + Some(Precision::Inexact(0)), ), // distinct(left) = 0, distinct(right) = 0 ( @@ -1713,13 +1679,18 @@ mod tests { assert_eq!( estimate_inner_join_cardinality( - left_num_rows, - right_num_rows, - left_col_stats.clone(), - right_col_stats.clone(), - false, + Statistics { + num_rows: Precision::Inexact(left_num_rows), + total_byte_size: Precision::Absent, + column_statistics: left_col_stats.clone(), + }, + Statistics { + num_rows: Precision::Inexact(right_num_rows), + total_byte_size: Precision::Absent, + column_statistics: right_col_stats.clone(), + }, ), - expected_cardinality + expected_cardinality.clone() ); // We should also be able to use join_cardinality to get the same results @@ -1727,18 +1698,22 @@ mod tests { let join_on = vec![(Column::new("a", 0), Column::new("b", 0))]; let partial_join_stats = estimate_join_cardinality( &join_type, - create_stats(Some(left_num_rows), Some(left_col_stats.clone()), false), - create_stats(Some(right_num_rows), Some(right_col_stats.clone()), false), + create_stats(Some(left_num_rows), left_col_stats.clone(), false), + create_stats(Some(right_num_rows), right_col_stats.clone(), false), &join_on, ); assert_eq!( - partial_join_stats.clone().map(|s| s.num_rows), - expected_cardinality + partial_join_stats + .clone() + .map(|s| Precision::Inexact(s.num_rows)), + expected_cardinality.clone() ); assert_eq!( partial_join_stats.map(|s| s.column_statistics), - expected_cardinality.map(|_| [left_col_stats, right_col_stats].concat()) + expected_cardinality + .clone() + .map(|_| [left_col_stats, right_col_stats].concat()) ); } Ok(()) @@ -1760,13 +1735,18 @@ mod tests { // count is 200, so we are going to pick it. assert_eq!( estimate_inner_join_cardinality( - 400, - 400, - left_col_stats, - right_col_stats, - false + Statistics { + num_rows: Precision::Inexact(400), + total_byte_size: Precision::Absent, + column_statistics: left_col_stats, + }, + Statistics { + num_rows: Precision::Inexact(400), + total_byte_size: Precision::Absent, + column_statistics: right_col_stats, + }, ), - Some((400 * 400) / 200) + Some(Precision::Inexact((400 * 400) / 200)) ); Ok(()) } @@ -1774,28 +1754,33 @@ mod tests { #[test] fn test_inner_join_cardinality_decimal_range() -> Result<()> { let left_col_stats = vec![ColumnStatistics { - distinct_count: None, - min_value: Some(ScalarValue::Decimal128(Some(32500), 14, 4)), - max_value: Some(ScalarValue::Decimal128(Some(35000), 14, 4)), + distinct_count: Precision::Absent, + min_value: Precision::Inexact(ScalarValue::Decimal128(Some(32500), 14, 4)), + max_value: Precision::Inexact(ScalarValue::Decimal128(Some(35000), 14, 4)), ..Default::default() }]; let right_col_stats = vec![ColumnStatistics { - distinct_count: None, - min_value: Some(ScalarValue::Decimal128(Some(33500), 14, 4)), - max_value: Some(ScalarValue::Decimal128(Some(34000), 14, 4)), + distinct_count: Precision::Absent, + min_value: Precision::Inexact(ScalarValue::Decimal128(Some(33500), 14, 4)), + max_value: Precision::Inexact(ScalarValue::Decimal128(Some(34000), 14, 4)), ..Default::default() }]; assert_eq!( estimate_inner_join_cardinality( - 100, - 100, - left_col_stats, - right_col_stats, - false + Statistics { + num_rows: Precision::Inexact(100), + total_byte_size: Precision::Absent, + column_statistics: left_col_stats, + }, + Statistics { + num_rows: Precision::Inexact(100), + total_byte_size: Precision::Absent, + column_statistics: right_col_stats, + }, ), - None + Some(Precision::Inexact(100)) ); Ok(()) } @@ -1840,8 +1825,8 @@ mod tests { let partial_join_stats = estimate_join_cardinality( &join_type, - create_stats(Some(1000), Some(left_col_stats.clone()), false), - create_stats(Some(2000), Some(right_col_stats.clone()), false), + create_stats(Some(1000), left_col_stats.clone(), false), + create_stats(Some(2000), right_col_stats.clone(), false), &join_on, ) .unwrap(); @@ -1905,8 +1890,8 @@ mod tests { for (join_type, expected_num_rows) in cases { let partial_join_stats = estimate_join_cardinality( &join_type, - create_stats(Some(1000), Some(left_col_stats.clone()), true), - create_stats(Some(2000), Some(right_col_stats.clone()), true), + create_stats(Some(1000), left_col_stats.clone(), true), + create_stats(Some(2000), right_col_stats.clone(), true), &join_on, ) .unwrap(); @@ -1920,84 +1905,6 @@ mod tests { Ok(()) } - #[test] - fn test_get_updated_right_ordering_equivalence_properties() -> Result<()> { - let join_type = JoinType::Inner; - - let options = SortOptions::default(); - let right_oeq_class = OrderingEquivalentClass::new( - vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("x", 0)), - options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("y", 1)), - options, - }, - ], - vec![vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("z", 2)), - options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("w", 3)), - options, - }, - ]], - ); - - let left_columns_len = 4; - - let fields: Fields = ["a", "b", "c", "d", "x", "y", "z", "w"] - .into_iter() - .map(|name| Field::new(name, DataType::Int32, true)) - .collect(); - - let mut join_eq_properties = - EquivalenceProperties::new(Arc::new(Schema::new(fields))); - join_eq_properties - .add_equal_conditions((&Column::new("a", 0), &Column::new("x", 4))); - join_eq_properties - .add_equal_conditions((&Column::new("d", 3), &Column::new("w", 7))); - - let result = get_updated_right_ordering_equivalent_class( - &join_type, - &right_oeq_class, - left_columns_len, - &join_eq_properties, - )?; - - let expected = OrderingEquivalentClass::new( - vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("y", 5)), - options, - }, - ], - vec![vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("z", 6)), - options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("d", 3)), - options, - }, - ]], - ); - - assert_eq!(result.head(), expected.head()); - assert_eq!(result.others(), expected.others()); - - Ok(()) - } - #[test] fn test_calculate_join_output_ordering() -> Result<()> { let options = SortOptions::default(); @@ -2090,7 +1997,7 @@ mod tests { left_columns_len, maintains_input_order, probe_side - )?, + ), expected[i] ); } diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index 76adf7611d6f..1dd1392b9d86 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -17,57 +17,112 @@ //! Traits for physical query plan, supporting parallel execution for partitioned relations. -mod visitor; -pub use self::metrics::Metric; -use self::metrics::MetricsSet; -use self::{ - coalesce_partitions::CoalescePartitionsExec, display::DisplayableExecutionPlan, -}; -pub use datafusion_common::{internal_err, ColumnStatistics, Statistics}; -use datafusion_common::{plan_err, Result}; -use datafusion_physical_expr::PhysicalSortExpr; -pub use visitor::{accept, visit_execution_plan, ExecutionPlanVisitor}; +use std::any::Any; +use std::fmt::Debug; +use std::sync::Arc; + +use crate::coalesce_partitions::CoalescePartitionsExec; +use crate::display::DisplayableExecutionPlan; +use crate::metrics::MetricsSet; +use crate::repartition::RepartitionExec; +use crate::sorts::sort_preserving_merge::SortPreservingMergeExec; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; - +use datafusion_common::tree_node::Transformed; use datafusion_common::utils::DataPtr; -pub use datafusion_expr::Accumulator; -pub use datafusion_expr::ColumnarValue; -use datafusion_physical_expr::equivalence::OrderingEquivalenceProperties; -pub use display::{DefaultDisplay, DisplayAs, DisplayFormatType, VerboseDisplay}; +use datafusion_common::{plan_err, DataFusionError, Result}; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::{ + EquivalenceProperties, PhysicalSortExpr, PhysicalSortRequirement, +}; + use futures::stream::TryStreamExt; -use std::fmt::Debug; use tokio::task::JoinSet; -use datafusion_common::tree_node::Transformed; -use datafusion_common::DataFusionError; -use std::any::Any; -use std::sync::Arc; +mod topk; +mod visitor; + +pub mod aggregates; +pub mod analyze; +pub mod coalesce_batches; +pub mod coalesce_partitions; +pub mod common; +pub mod display; +pub mod empty; +pub mod explain; +pub mod filter; +pub mod insert; +pub mod joins; +pub mod limit; +pub mod memory; +pub mod metrics; +mod ordering; +pub mod placeholder_row; +pub mod projection; +pub mod repartition; +pub mod sorts; +pub mod stream; +pub mod streaming; +pub mod tree_node; +pub mod udaf; +pub mod union; +pub mod unnest; +pub mod values; +pub mod windows; + +pub use crate::display::{DefaultDisplay, DisplayAs, DisplayFormatType, VerboseDisplay}; +pub use crate::metrics::Metric; +pub use crate::ordering::InputOrderMode; +pub use crate::topk::TopK; +pub use crate::visitor::{accept, visit_execution_plan, ExecutionPlanVisitor}; + +use datafusion_common::config::ConfigOptions; +pub use datafusion_common::hash_utils; +pub use datafusion_common::utils::project_schema; +pub use datafusion_common::{internal_err, ColumnStatistics, Statistics}; +pub use datafusion_expr::{Accumulator, ColumnarValue}; +pub use datafusion_physical_expr::window::WindowExpr; +pub use datafusion_physical_expr::{ + expressions, functions, udf, AggregateExpr, Distribution, Partitioning, PhysicalExpr, +}; -// backwards compatibility +// Backwards compatibility +pub use crate::stream::EmptyRecordBatchStream; pub use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream}; -pub use stream::EmptyRecordBatchStream; -/// `ExecutionPlan` represent nodes in the DataFusion Physical Plan. +/// Represent nodes in the DataFusion Physical Plan. +/// +/// Calling [`execute`] produces an `async` [`SendableRecordBatchStream`] of +/// [`RecordBatch`] that incrementally computes a partition of the +/// `ExecutionPlan`'s output from its input. See [`Partitioning`] for more +/// details on partitioning. /// -/// Each `ExecutionPlan` is partition-aware and is responsible for -/// creating the actual `async` [`SendableRecordBatchStream`]s -/// of [`RecordBatch`] that incrementally compute the operator's -/// output from its input partition. +/// Methods such as [`schema`] and [`output_partitioning`] communicate +/// properties of this output to the DataFusion optimizer, and methods such as +/// [`required_input_distribution`] and [`required_input_ordering`] express +/// requirements of the `ExecutionPlan` from its input. /// /// [`ExecutionPlan`] can be displayed in a simplified form using the /// return value from [`displayable`] in addition to the (normally /// quite verbose) `Debug` output. +/// +/// [`execute`]: ExecutionPlan::execute +/// [`schema`]: ExecutionPlan::schema +/// [`output_partitioning`]: ExecutionPlan::output_partitioning +/// [`required_input_distribution`]: ExecutionPlan::required_input_distribution +/// [`required_input_ordering`]: ExecutionPlan::required_input_ordering pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { - /// Returns the execution plan as [`Any`](std::any::Any) so that it can be + /// Returns the execution plan as [`Any`] so that it can be /// downcast to a specific implementation. fn as_any(&self) -> &dyn Any; /// Get the schema for this execution plan fn schema(&self) -> SchemaRef; - /// Specifies the output partitioning scheme of this plan + /// Specifies how the output of this `ExecutionPlan` is split into + /// partitions. fn output_partitioning(&self) -> Partitioning; /// Specifies whether this plan generates an infinite stream of records. @@ -81,7 +136,7 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { } } - /// If the output of this operator within each partition is sorted, + /// If the output of this `ExecutionPlan` within each partition is sorted, /// returns `Some(keys)` with the description of how it was sorted. /// /// For example, Sort, (obviously) produces sorted output as does @@ -89,17 +144,19 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { /// produces sorted output if its input was sorted as it does not /// reorder the input rows, /// - /// It is safe to return `None` here if your operator does not + /// It is safe to return `None` here if your `ExecutionPlan` does not /// have any particular output order here fn output_ordering(&self) -> Option<&[PhysicalSortExpr]>; /// Specifies the data distribution requirements for all the - /// children for this operator, By default it's [[Distribution::UnspecifiedDistribution]] for each child, + /// children for this `ExecutionPlan`, By default it's [[Distribution::UnspecifiedDistribution]] for each child, fn required_input_distribution(&self) -> Vec { vec![Distribution::UnspecifiedDistribution; self.children().len()] } - /// Specifies the ordering requirements for all of the children + /// Specifies the ordering required for all of the children of this + /// `ExecutionPlan`. + /// /// For each child, it's the local ordering requirement within /// each partition rather than the global ordering /// @@ -110,7 +167,7 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { vec![None; self.children().len()] } - /// Returns `false` if this operator's implementation may reorder + /// Returns `false` if this `ExecutionPlan`'s implementation may reorder /// rows within or between partitions. /// /// For example, Projection, Filter, and Limit maintain the order @@ -124,19 +181,21 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { /// The default implementation returns `false` /// /// WARNING: if you override this default, you *MUST* ensure that - /// the operator's maintains the ordering invariant or else + /// the `ExecutionPlan`'s maintains the ordering invariant or else /// DataFusion may produce incorrect results. fn maintains_input_order(&self) -> Vec { vec![false; self.children().len()] } - /// Specifies whether the operator benefits from increased parallelization - /// at its input for each child. If set to `true`, this indicates that the - /// operator would benefit from partitioning its corresponding child - /// (and thus from more parallelism). For operators that do very little work - /// the overhead of extra parallelism may outweigh any benefits + /// Specifies whether the `ExecutionPlan` benefits from increased + /// parallelization at its input for each child. /// - /// The default implementation returns `true` unless this operator + /// If returns `true`, the `ExecutionPlan` would benefit from partitioning + /// its corresponding child (and thus from more parallelism). For + /// `ExecutionPlan` that do very little work the overhead of extra + /// parallelism may outweigh any benefits + /// + /// The default implementation returns `true` unless this `ExecutionPlan` /// has signalled it requires a single child input partition. fn benefits_from_input_partitioning(&self) -> Vec { // By default try to maximize parallelism with more CPUs if @@ -147,28 +206,215 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { .collect() } - /// Get the EquivalenceProperties within the plan + /// Get the [`EquivalenceProperties`] within the plan. + /// + /// Equivalence properties tell DataFusion what columns are known to be + /// equal, during various optimization passes. By default, this returns "no + /// known equivalences" which is always correct, but may cause DataFusion to + /// unnecessarily resort data. + /// + /// If this ExecutionPlan makes no changes to the schema of the rows flowing + /// through it or how columns within each row relate to each other, it + /// should return the equivalence properties of its input. For + /// example, since `FilterExec` may remove rows from its input, but does not + /// otherwise modify them, it preserves its input equivalence properties. + /// However, since `ProjectionExec` may calculate derived expressions, it + /// needs special handling. + /// + /// See also [`Self::maintains_input_order`] and [`Self::output_ordering`] + /// for related concepts. fn equivalence_properties(&self) -> EquivalenceProperties { EquivalenceProperties::new(self.schema()) } - /// Get the OrderingEquivalenceProperties within the plan - fn ordering_equivalence_properties(&self) -> OrderingEquivalenceProperties { - OrderingEquivalenceProperties::new(self.schema()) - } - - /// Get a list of child execution plans that provide the input for this plan. The returned list - /// will be empty for leaf nodes, will contain a single value for unary nodes, or two - /// values for binary nodes (such as joins). + /// Get a list of children `ExecutionPlan`s that act as inputs to this plan. + /// The returned list will be empty for leaf nodes such as scans, will contain + /// a single value for unary nodes, or two values for binary nodes (such as + /// joins). fn children(&self) -> Vec>; - /// Returns a new plan where all children were replaced by new plans. + /// Returns a new `ExecutionPlan` where all existing children were replaced + /// by the `children`, oi order fn with_new_children( self: Arc, children: Vec>, ) -> Result>; - /// creates an iterator + /// If supported, attempt to increase the partitioning of this `ExecutionPlan` to + /// produce `target_partitions` partitions. + /// + /// If the `ExecutionPlan` does not support changing its partitioning, + /// returns `Ok(None)` (the default). + /// + /// It is the `ExecutionPlan` can increase its partitioning, but not to the + /// `target_partitions`, it may return an ExecutionPlan with fewer + /// partitions. This might happen, for example, if each new partition would + /// be too small to be efficiently processed individually. + /// + /// The DataFusion optimizer attempts to use as many threads as possible by + /// repartitioning its inputs to match the target number of threads + /// available (`target_partitions`). Some data sources, such as the built in + /// CSV and Parquet readers, implement this method as they are able to read + /// from their input files in parallel, regardless of how the source data is + /// split amongst files. + fn repartitioned( + &self, + _target_partitions: usize, + _config: &ConfigOptions, + ) -> Result>> { + Ok(None) + } + + /// Begin execution of `partition`, returning a [`Stream`] of + /// [`RecordBatch`]es. + /// + /// # Notes + /// + /// The `execute` method itself is not `async` but it returns an `async` + /// [`futures::stream::Stream`]. This `Stream` should incrementally compute + /// the output, `RecordBatch` by `RecordBatch` (in a streaming fashion). + /// Most `ExecutionPlan`s should not do any work before the first + /// `RecordBatch` is requested from the stream. + /// + /// [`RecordBatchStreamAdapter`] can be used to convert an `async` + /// [`Stream`] into a [`SendableRecordBatchStream`]. + /// + /// Using `async` `Streams` allows for network I/O during execution and + /// takes advantage of Rust's built in support for `async` continuations and + /// crate ecosystem. + /// + /// [`Stream`]: futures::stream::Stream + /// [`StreamExt`]: futures::stream::StreamExt + /// [`TryStreamExt`]: futures::stream::TryStreamExt + /// [`RecordBatchStreamAdapter`]: crate::stream::RecordBatchStreamAdapter + /// + /// # Cancellation / Aborting Execution + /// + /// The [`Stream`] that is returned must ensure that any allocated resources + /// are freed when the stream itself is dropped. This is particularly + /// important for [`spawn`]ed tasks or threads. Unless care is taken to + /// "abort" such tasks, they may continue to consume resources even after + /// the plan is dropped, generating intermediate results that are never + /// used. + /// + /// See [`AbortOnDropSingle`], [`AbortOnDropMany`] and + /// [`RecordBatchReceiverStreamBuilder`] for structures to help ensure all + /// background tasks are cancelled. + /// + /// [`spawn`]: tokio::task::spawn + /// [`AbortOnDropSingle`]: crate::common::AbortOnDropSingle + /// [`AbortOnDropMany`]: crate::common::AbortOnDropMany + /// [`RecordBatchReceiverStreamBuilder`]: crate::stream::RecordBatchReceiverStreamBuilder + /// + /// # Implementation Examples + /// + /// While `async` `Stream`s have a non trivial learning curve, the + /// [`futures`] crate provides [`StreamExt`] and [`TryStreamExt`] + /// which help simplify many common operations. + /// + /// Here are some common patterns: + /// + /// ## Return Precomputed `RecordBatch` + /// + /// We can return a precomputed `RecordBatch` as a `Stream`: + /// + /// ``` + /// # use std::sync::Arc; + /// # use arrow_array::RecordBatch; + /// # use arrow_schema::SchemaRef; + /// # use datafusion_common::Result; + /// # use datafusion_execution::{SendableRecordBatchStream, TaskContext}; + /// # use datafusion_physical_plan::memory::MemoryStream; + /// # use datafusion_physical_plan::stream::RecordBatchStreamAdapter; + /// struct MyPlan { + /// batch: RecordBatch, + /// } + /// + /// impl MyPlan { + /// fn execute( + /// &self, + /// partition: usize, + /// context: Arc + /// ) -> Result { + /// // use functions from futures crate convert the batch into a stream + /// let fut = futures::future::ready(Ok(self.batch.clone())); + /// let stream = futures::stream::once(fut); + /// Ok(Box::pin(RecordBatchStreamAdapter::new(self.batch.schema(), stream))) + /// } + /// } + /// ``` + /// + /// ## Lazily (async) Compute `RecordBatch` + /// + /// We can also lazily compute a `RecordBatch` when the returned `Stream` is polled + /// + /// ``` + /// # use std::sync::Arc; + /// # use arrow_array::RecordBatch; + /// # use arrow_schema::SchemaRef; + /// # use datafusion_common::Result; + /// # use datafusion_execution::{SendableRecordBatchStream, TaskContext}; + /// # use datafusion_physical_plan::memory::MemoryStream; + /// # use datafusion_physical_plan::stream::RecordBatchStreamAdapter; + /// struct MyPlan { + /// schema: SchemaRef, + /// } + /// + /// /// Returns a single batch when the returned stream is polled + /// async fn get_batch() -> Result { + /// todo!() + /// } + /// + /// impl MyPlan { + /// fn execute( + /// &self, + /// partition: usize, + /// context: Arc + /// ) -> Result { + /// let fut = get_batch(); + /// let stream = futures::stream::once(fut); + /// Ok(Box::pin(RecordBatchStreamAdapter::new(self.schema.clone(), stream))) + /// } + /// } + /// ``` + /// + /// ## Lazily (async) create a Stream + /// + /// If you need to to create the return `Stream` using an `async` function, + /// you can do so by flattening the result: + /// + /// ``` + /// # use std::sync::Arc; + /// # use arrow_array::RecordBatch; + /// # use arrow_schema::SchemaRef; + /// # use futures::TryStreamExt; + /// # use datafusion_common::Result; + /// # use datafusion_execution::{SendableRecordBatchStream, TaskContext}; + /// # use datafusion_physical_plan::memory::MemoryStream; + /// # use datafusion_physical_plan::stream::RecordBatchStreamAdapter; + /// struct MyPlan { + /// schema: SchemaRef, + /// } + /// + /// /// async function that returns a stream + /// async fn get_batch_stream() -> Result { + /// todo!() + /// } + /// + /// impl MyPlan { + /// fn execute( + /// &self, + /// partition: usize, + /// context: Arc + /// ) -> Result { + /// // A future that yields a stream + /// let fut = get_batch_stream(); + /// // Use TryStreamExt::try_flatten to flatten the stream of streams + /// let stream = futures::stream::once(fut).try_flatten(); + /// Ok(Box::pin(RecordBatchStreamAdapter::new(self.schema.clone(), stream))) + /// } + /// } + /// ``` fn execute( &self, partition: usize, @@ -176,7 +422,7 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { ) -> Result; /// Return a snapshot of the set of [`Metric`]s for this - /// [`ExecutionPlan`]. + /// [`ExecutionPlan`]. If no `Metric`s are available, return None. /// /// While the values of the metrics in the returned /// [`MetricsSet`]s may change as execution progresses, the @@ -190,14 +436,18 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { None } - /// Returns the global output statistics for this `ExecutionPlan` node. - fn statistics(&self) -> Statistics; + /// Returns statistics for this `ExecutionPlan` node. If statistics are not + /// available, should return [`Statistics::new_unknown`] (the default), not + /// an error. + fn statistics(&self) -> Result { + Ok(Statistics::new_unknown(&self.schema())) + } } /// Indicate whether a data exchange is needed for the input of `plan`, which will be very helpful /// especially for the distributed engine to judge whether need to deal with shuffling. /// Currently there are 3 kinds of execution plan which needs data exchange -/// 1. RepartitionExec for changing the partition number between two operators +/// 1. RepartitionExec for changing the partition number between two `ExecutionPlan`s /// 2. CoalescePartitionsExec for collapsing all of the partitions into one without ordering guarantee /// 3. SortPreservingMergeExec for collapsing all of the sorted partitions into one with ordering guarantee pub fn need_data_exchange(plan: Arc) -> bool { @@ -259,7 +509,12 @@ pub async fn collect( common::collect(stream).await } -/// Execute the [ExecutionPlan] and return a single stream of results +/// Execute the [ExecutionPlan] and return a single stream of results. +/// +/// # Aborting Execution +/// +/// Dropping the stream will abort the execution of the query, and free up +/// any allocated resources pub fn execute_stream( plan: Arc, context: Arc, @@ -317,7 +572,13 @@ pub async fn collect_partitioned( Ok(batches) } -/// Execute the [ExecutionPlan] and return a vec with one stream per output partition +/// Execute the [ExecutionPlan] and return a vec with one stream per output +/// partition +/// +/// # Aborting Execution +/// +/// Dropping the stream will abort the execution of the query, and free up +/// any allocated resources pub fn execute_stream_partitioned( plan: Arc, context: Arc, @@ -341,45 +602,12 @@ pub fn unbounded_output(plan: &Arc) -> bool { .unwrap_or(true) } -use datafusion_physical_expr::expressions::Column; -pub use datafusion_physical_expr::window::WindowExpr; -pub use datafusion_physical_expr::{AggregateExpr, PhysicalExpr}; -pub use datafusion_physical_expr::{Distribution, Partitioning}; -use datafusion_physical_expr::{EquivalenceProperties, PhysicalSortRequirement}; - -pub mod aggregates; -pub mod analyze; -pub mod coalesce_batches; -pub mod coalesce_partitions; -pub mod common; -pub mod display; -pub mod empty; -pub mod explain; -pub mod filter; -pub mod insert; -pub mod joins; -pub mod limit; -pub mod memory; -pub mod metrics; -pub mod projection; -pub mod repartition; -pub mod sorts; -pub mod stream; -pub mod streaming; -pub mod tree_node; -pub mod udaf; -pub mod union; -pub mod unnest; -pub mod values; -pub mod windows; - -use crate::repartition::RepartitionExec; -use crate::sorts::sort_preserving_merge::SortPreservingMergeExec; -pub use datafusion_common::utils::project_schema; -use datafusion_execution::TaskContext; -pub use datafusion_physical_expr::{ - expressions, functions, hash_utils, ordering_equivalence_properties_helper, udf, -}; +/// Utility function yielding a string representation of the given [`ExecutionPlan`]. +pub fn get_plan_string(plan: &Arc) -> Vec { + let formatted = displayable(plan.as_ref()).indent(true).to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + actual.iter().map(|elem| elem.to_string()).collect() +} #[cfg(test)] pub mod test; diff --git a/datafusion/physical-plan/src/limit.rs b/datafusion/physical-plan/src/limit.rs index 922c3db0efc8..37e8ffd76159 100644 --- a/datafusion/physical-plan/src/limit.rs +++ b/datafusion/physical-plan/src/limit.rs @@ -22,23 +22,21 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use crate::{ - DisplayFormatType, Distribution, EquivalenceProperties, ExecutionPlan, Partitioning, -}; - use super::expressions::PhysicalSortExpr; use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use super::{DisplayAs, RecordBatchStream, SendableRecordBatchStream, Statistics}; +use crate::{ + DisplayFormatType, Distribution, EquivalenceProperties, ExecutionPlan, Partitioning, +}; use arrow::array::ArrayRef; use arrow::datatypes::SchemaRef; use arrow::record_batch::{RecordBatch, RecordBatchOptions}; +use datafusion_common::stats::Precision; use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_execution::TaskContext; -use datafusion_physical_expr::OrderingEquivalenceProperties; -use futures::stream::Stream; -use futures::stream::StreamExt; +use futures::stream::{Stream, StreamExt}; use log::trace; /// Limit execution plan @@ -139,10 +137,6 @@ impl ExecutionPlan for GlobalLimitExec { self.input.equivalence_properties() } - fn ordering_equivalence_properties(&self) -> OrderingEquivalenceProperties { - self.input.ordering_equivalence_properties() - } - fn with_new_children( self: Arc, children: Vec>, @@ -191,51 +185,81 @@ impl ExecutionPlan for GlobalLimitExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Statistics { - let input_stats = self.input.statistics(); + fn statistics(&self) -> Result { + let input_stats = self.input.statistics()?; let skip = self.skip; - // the maximum row number needs to be fetched - let max_row_num = self - .fetch - .map(|fetch| { - if fetch >= usize::MAX - skip { - usize::MAX - } else { - fetch + skip - } - }) - .unwrap_or(usize::MAX); - match input_stats { + let col_stats = Statistics::unknown_column(&self.schema()); + let fetch = self.fetch.unwrap_or(usize::MAX); + + let mut fetched_row_number_stats = Statistics { + num_rows: Precision::Exact(fetch), + column_statistics: col_stats.clone(), + total_byte_size: Precision::Absent, + }; + + let stats = match input_stats { Statistics { - num_rows: Some(nr), .. + num_rows: Precision::Exact(nr), + .. + } + | Statistics { + num_rows: Precision::Inexact(nr), + .. } => { if nr <= skip { // if all input data will be skipped, return 0 - Statistics { - num_rows: Some(0), - is_exact: input_stats.is_exact, - ..Default::default() + let mut skip_all_rows_stats = Statistics { + num_rows: Precision::Exact(0), + column_statistics: col_stats, + total_byte_size: Precision::Absent, + }; + if !input_stats.num_rows.is_exact().unwrap_or(false) { + // The input stats are inexact, so the output stats must be too. + skip_all_rows_stats = skip_all_rows_stats.into_inexact(); } - } else if nr <= max_row_num { - // if the input does not reach the "fetch" globally, return input stats + skip_all_rows_stats + } else if nr <= fetch && self.skip == 0 { + // if the input does not reach the "fetch" globally, and "skip" is zero + // (meaning the input and output are identical), return input stats. + // Can input_stats still be used, but adjusted, in the "skip != 0" case? input_stats + } else if nr - skip <= fetch { + // after "skip" input rows are skipped, the remaining rows are less than or equal to the + // "fetch" values, so `num_rows` must equal the remaining rows + let remaining_rows: usize = nr - skip; + let mut skip_some_rows_stats = Statistics { + num_rows: Precision::Exact(remaining_rows), + column_statistics: col_stats, + total_byte_size: Precision::Absent, + }; + if !input_stats.num_rows.is_exact().unwrap_or(false) { + // The input stats are inexact, so the output stats must be too. + skip_some_rows_stats = skip_some_rows_stats.into_inexact(); + } + skip_some_rows_stats } else { - // if the input is greater than the "fetch", the num_row will be the "fetch", + // if the input is greater than "fetch+skip", the num_rows will be the "fetch", // but we won't be able to predict the other statistics - Statistics { - num_rows: Some(max_row_num), - is_exact: input_stats.is_exact, - ..Default::default() + if !input_stats.num_rows.is_exact().unwrap_or(false) + || self.fetch.is_none() + { + // If the input stats are inexact, the output stats must be too. + // If the fetch value is `usize::MAX` because no LIMIT was specified, + // we also can't represent it as an exact value. + fetched_row_number_stats = + fetched_row_number_stats.into_inexact(); } + fetched_row_number_stats } } - _ => Statistics { - // the result output row number will always be no greater than the limit number - num_rows: Some(max_row_num), - is_exact: false, - ..Default::default() - }, - } + _ => { + // The result output `num_rows` will always be no greater than the limit number. + // Should `num_rows` be marked as `Absent` here when the `fetch` value is large, + // as the actual `num_rows` may be far away from the `fetch` value? + fetched_row_number_stats.into_inexact() + } + }; + Ok(stats) } } @@ -320,10 +344,6 @@ impl ExecutionPlan for LocalLimitExec { self.input.equivalence_properties() } - fn ordering_equivalence_properties(&self) -> OrderingEquivalenceProperties { - self.input.ordering_equivalence_properties() - } - fn unbounded_output(&self, _children: &[bool]) -> Result { Ok(false) } @@ -361,32 +381,53 @@ impl ExecutionPlan for LocalLimitExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Statistics { - let input_stats = self.input.statistics(); - match input_stats { + fn statistics(&self) -> Result { + let input_stats = self.input.statistics()?; + let col_stats = Statistics::unknown_column(&self.schema()); + let stats = match input_stats { // if the input does not reach the limit globally, return input stats Statistics { - num_rows: Some(nr), .. + num_rows: Precision::Exact(nr), + .. + } + | Statistics { + num_rows: Precision::Inexact(nr), + .. } if nr <= self.fetch => input_stats, // if the input is greater than the limit, the num_row will be greater // than the limit because the partitions will be limited separatly // the statistic Statistics { - num_rows: Some(nr), .. + num_rows: Precision::Exact(nr), + .. } if nr > self.fetch => Statistics { - num_rows: Some(self.fetch), + num_rows: Precision::Exact(self.fetch), // this is not actually exact, but will be when GlobalLimit is applied // TODO stats: find a more explicit way to vehiculate this information - is_exact: input_stats.is_exact, - ..Default::default() + column_statistics: col_stats, + total_byte_size: Precision::Absent, + }, + Statistics { + num_rows: Precision::Inexact(nr), + .. + } if nr > self.fetch => Statistics { + num_rows: Precision::Inexact(self.fetch), + // this is not actually exact, but will be when GlobalLimit is applied + // TODO stats: find a more explicit way to vehiculate this information + column_statistics: col_stats, + total_byte_size: Precision::Absent, }, _ => Statistics { // the result output row number will always be no greater than the limit number - num_rows: Some(self.fetch * self.output_partitioning().partition_count()), - is_exact: false, - ..Default::default() + num_rows: Precision::Inexact( + self.fetch * self.output_partitioning().partition_count(), + ), + + column_statistics: col_stats, + total_byte_size: Precision::Absent, }, - } + }; + Ok(stats) } } @@ -442,7 +483,7 @@ impl LimitStream { match &poll { Poll::Ready(Some(Ok(batch))) => { - if batch.num_rows() > 0 && self.skip == 0 { + if batch.num_rows() > 0 { break poll; } else { // continue to poll input stream @@ -528,14 +569,15 @@ impl RecordBatchStream for LimitStream { #[cfg(test)] mod tests { - - use arrow_schema::Schema; - use common::collect; - use super::*; use crate::coalesce_partitions::CoalescePartitionsExec; - use crate::common; - use crate::test; + use crate::common::collect; + use crate::{common, test}; + + use crate::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; + use arrow_schema::Schema; + use datafusion_physical_expr::expressions::col; + use datafusion_physical_expr::PhysicalExpr; #[tokio::test] async fn limit() -> Result<()> { @@ -695,7 +737,7 @@ mod tests { } #[tokio::test] - async fn skip_3_fetch_10() -> Result<()> { + async fn skip_3_fetch_10_stats() -> Result<()> { // there are total of 100 rows, we skipped 3 rows (offset = 3) let row_count = skip_and_fetch(3, Some(10)).await?; assert_eq!(row_count, 10); @@ -728,10 +770,61 @@ mod tests { #[tokio::test] async fn test_row_number_statistics_for_global_limit() -> Result<()> { let row_count = row_number_statistics_for_global_limit(0, Some(10)).await?; - assert_eq!(row_count, Some(10)); + assert_eq!(row_count, Precision::Exact(10)); let row_count = row_number_statistics_for_global_limit(5, Some(10)).await?; - assert_eq!(row_count, Some(15)); + assert_eq!(row_count, Precision::Exact(10)); + + let row_count = row_number_statistics_for_global_limit(400, Some(10)).await?; + assert_eq!(row_count, Precision::Exact(0)); + + let row_count = row_number_statistics_for_global_limit(398, Some(10)).await?; + assert_eq!(row_count, Precision::Exact(2)); + + let row_count = row_number_statistics_for_global_limit(398, Some(1)).await?; + assert_eq!(row_count, Precision::Exact(1)); + + let row_count = row_number_statistics_for_global_limit(398, None).await?; + assert_eq!(row_count, Precision::Exact(2)); + + let row_count = + row_number_statistics_for_global_limit(0, Some(usize::MAX)).await?; + assert_eq!(row_count, Precision::Exact(400)); + + let row_count = + row_number_statistics_for_global_limit(398, Some(usize::MAX)).await?; + assert_eq!(row_count, Precision::Exact(2)); + + let row_count = + row_number_inexact_statistics_for_global_limit(0, Some(10)).await?; + assert_eq!(row_count, Precision::Inexact(10)); + + let row_count = + row_number_inexact_statistics_for_global_limit(5, Some(10)).await?; + assert_eq!(row_count, Precision::Inexact(10)); + + let row_count = + row_number_inexact_statistics_for_global_limit(400, Some(10)).await?; + assert_eq!(row_count, Precision::Inexact(0)); + + let row_count = + row_number_inexact_statistics_for_global_limit(398, Some(10)).await?; + assert_eq!(row_count, Precision::Inexact(2)); + + let row_count = + row_number_inexact_statistics_for_global_limit(398, Some(1)).await?; + assert_eq!(row_count, Precision::Inexact(1)); + + let row_count = row_number_inexact_statistics_for_global_limit(398, None).await?; + assert_eq!(row_count, Precision::Inexact(2)); + + let row_count = + row_number_inexact_statistics_for_global_limit(0, Some(usize::MAX)).await?; + assert_eq!(row_count, Precision::Inexact(400)); + + let row_count = + row_number_inexact_statistics_for_global_limit(398, Some(usize::MAX)).await?; + assert_eq!(row_count, Precision::Inexact(2)); Ok(()) } @@ -739,7 +832,7 @@ mod tests { #[tokio::test] async fn test_row_number_statistics_for_local_limit() -> Result<()> { let row_count = row_number_statistics_for_local_limit(4, 10).await?; - assert_eq!(row_count, Some(10)); + assert_eq!(row_count, Precision::Exact(10)); Ok(()) } @@ -747,7 +840,7 @@ mod tests { async fn row_number_statistics_for_global_limit( skip: usize, fetch: Option, - ) -> Result> { + ) -> Result> { let num_partitions = 4; let csv = test::scan_partitioned(num_partitions); @@ -756,20 +849,60 @@ mod tests { let offset = GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(csv)), skip, fetch); - Ok(offset.statistics().num_rows) + Ok(offset.statistics()?.num_rows) + } + + pub fn build_group_by( + input_schema: &SchemaRef, + columns: Vec, + ) -> PhysicalGroupBy { + let mut group_by_expr: Vec<(Arc, String)> = vec![]; + for column in columns.iter() { + group_by_expr.push((col(column, input_schema).unwrap(), column.to_string())); + } + PhysicalGroupBy::new_single(group_by_expr.clone()) + } + + async fn row_number_inexact_statistics_for_global_limit( + skip: usize, + fetch: Option, + ) -> Result> { + let num_partitions = 4; + let csv = test::scan_partitioned(num_partitions); + + assert_eq!(csv.output_partitioning().partition_count(), num_partitions); + + // Adding a "GROUP BY i" changes the input stats from Exact to Inexact. + let agg = AggregateExec::try_new( + AggregateMode::Final, + build_group_by(&csv.schema().clone(), vec!["i".to_string()]), + vec![], + vec![None], + csv.clone(), + csv.schema().clone(), + )?; + let agg_exec: Arc = Arc::new(agg); + + let offset = GlobalLimitExec::new( + Arc::new(CoalescePartitionsExec::new(agg_exec)), + skip, + fetch, + ); + + Ok(offset.statistics()?.num_rows) } async fn row_number_statistics_for_local_limit( num_partitions: usize, fetch: usize, - ) -> Result> { + ) -> Result> { let csv = test::scan_partitioned(num_partitions); assert_eq!(csv.output_partitioning().partition_count(), num_partitions); let offset = LocalLimitExec::new(csv, fetch); - Ok(offset.statistics().num_rows) + Ok(offset.statistics()?.num_rows) } /// Return a RecordBatch with a single array with row_count sz diff --git a/datafusion/physical-plan/src/memory.rs b/datafusion/physical-plan/src/memory.rs index b29c8e9c7bd9..7de474fda11c 100644 --- a/datafusion/physical-plan/src/memory.rs +++ b/datafusion/physical-plan/src/memory.rs @@ -17,23 +17,23 @@ //! Execution plan for reading in-memory batches of data +use std::any::Any; +use std::fmt; +use std::sync::Arc; +use std::task::{Context, Poll}; + use super::expressions::PhysicalSortExpr; use super::{ common, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, Statistics, }; + use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; -use core::fmt; -use datafusion_common::{internal_err, project_schema, Result}; -use std::any::Any; -use std::sync::Arc; -use std::task::{Context, Poll}; - -use crate::ordering_equivalence_properties_helper; -use datafusion_common::DataFusionError; +use datafusion_common::{internal_err, project_schema, DataFusionError, Result}; use datafusion_execution::TaskContext; -use datafusion_physical_expr::{LexOrdering, OrderingEquivalenceProperties}; +use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; + use futures::Stream; /// Execution plan for reading in-memory batches of data @@ -55,7 +55,7 @@ impl fmt::Debug for MemoryExec { write!(f, "partitions: [...]")?; write!(f, "schema: {:?}", self.projected_schema)?; write!(f, "projection: {:?}", self.projection)?; - if let Some(sort_info) = &self.sort_information.get(0) { + if let Some(sort_info) = &self.sort_information.first() { write!(f, ", output_ordering: {:?}", sort_info)?; } Ok(()) @@ -77,11 +77,12 @@ impl DisplayAs for MemoryExec { .sort_information .first() .map(|output_ordering| { - let order_strings: Vec<_> = - output_ordering.iter().map(|e| e.to_string()).collect(); - format!(", output_ordering={}", order_strings.join(",")) + format!( + ", output_ordering={}", + PhysicalSortExpr::format_list(output_ordering) + ) }) - .unwrap_or_else(|| "".to_string()); + .unwrap_or_default(); write!( f, @@ -120,15 +121,20 @@ impl ExecutionPlan for MemoryExec { .map(|ordering| ordering.as_slice()) } - fn ordering_equivalence_properties(&self) -> OrderingEquivalenceProperties { - ordering_equivalence_properties_helper(self.schema(), &self.sort_information) + fn equivalence_properties(&self) -> EquivalenceProperties { + EquivalenceProperties::new_with_orderings(self.schema(), &self.sort_information) } fn with_new_children( self: Arc, - _: Vec>, + children: Vec>, ) -> Result> { - internal_err!("Children cannot be replaced in {self:?}") + // MemoryExec has no children + if children.is_empty() { + Ok(self) + } else { + internal_err!("Children cannot be replaced in {self:?}") + } } fn execute( @@ -144,12 +150,12 @@ impl ExecutionPlan for MemoryExec { } /// We recompute the statistics dynamically from the arrow metadata as it is pretty cheap to do so - fn statistics(&self) -> Statistics { - common::compute_record_batch_statistics( + fn statistics(&self) -> Result { + Ok(common::compute_record_batch_statistics( &self.partitions, &self.schema, self.projection.clone(), - ) + )) } } @@ -171,8 +177,16 @@ impl MemoryExec { }) } + pub fn partitions(&self) -> &[Vec] { + &self.partitions + } + + pub fn projection(&self) -> &Option> { + &self.projection + } + /// A memory table can be ordered by multiple expressions simultaneously. - /// `OrderingEquivalenceProperties` keeps track of expressions that describe the + /// [`EquivalenceProperties`] keeps track of expressions that describe the /// global ordering of the schema. These columns are not necessarily same; e.g. /// ```text /// ┌-------┐ @@ -185,14 +199,16 @@ impl MemoryExec { /// └---┴---┘ /// ``` /// where both `a ASC` and `b DESC` can describe the table ordering. With - /// `OrderingEquivalenceProperties`, we can keep track of these equivalences - /// and treat `a ASC` and `b DESC` as the same ordering requirement - /// by outputting the `a ASC` from output_ordering API - /// and add `b DESC` into `OrderingEquivalenceProperties` + /// [`EquivalenceProperties`], we can keep track of these equivalences + /// and treat `a ASC` and `b DESC` as the same ordering requirement. pub fn with_sort_information(mut self, sort_information: Vec) -> Self { self.sort_information = sort_information; self } + + pub fn original_schema(&self) -> SchemaRef { + self.schema.clone() + } } /// Iterator over batches @@ -260,12 +276,14 @@ impl RecordBatchStream for MemoryStream { #[cfg(test)] mod tests { + use std::sync::Arc; + use crate::memory::MemoryExec; use crate::ExecutionPlan; + use arrow_schema::{DataType, Field, Schema, SortOptions}; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::PhysicalSortExpr; - use std::sync::Arc; #[test] fn test_memory_order_eq() -> datafusion_common::Result<()> { @@ -294,11 +312,8 @@ mod tests { .with_sort_information(sort_information); assert_eq!(mem_exec.output_ordering().unwrap(), expected_output_order); - let order_eq = mem_exec.ordering_equivalence_properties(); - assert!(order_eq - .oeq_class() - .map(|class| class.contains(&expected_order_eq)) - .unwrap_or(false)); + let eq_properties = mem_exec.equivalence_properties(); + assert!(eq_properties.oeq_class().contains(&expected_order_eq)); Ok(()) } } diff --git a/datafusion/physical-plan/src/ordering.rs b/datafusion/physical-plan/src/ordering.rs new file mode 100644 index 000000000000..047f89eef193 --- /dev/null +++ b/datafusion/physical-plan/src/ordering.rs @@ -0,0 +1,51 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Specifies how the input to an aggregation or window operator is ordered +/// relative to their `GROUP BY` or `PARTITION BY` expressions. +/// +/// For example, if the existing ordering is `[a ASC, b ASC, c ASC]` +/// +/// ## Window Functions +/// - A `PARTITION BY b` clause can use `Linear` mode. +/// - A `PARTITION BY a, c` or a `PARTITION BY c, a` can use +/// `PartiallySorted([0])` or `PartiallySorted([1])` modes, respectively. +/// (The vector stores the index of `a` in the respective PARTITION BY expression.) +/// - A `PARTITION BY a, b` or a `PARTITION BY b, a` can use `Sorted` mode. +/// +/// ## Aggregations +/// - A `GROUP BY b` clause can use `Linear` mode. +/// - A `GROUP BY a, c` or a `GROUP BY BY c, a` can use +/// `PartiallySorted([0])` or `PartiallySorted([1])` modes, respectively. +/// (The vector stores the index of `a` in the respective PARTITION BY expression.) +/// - A `GROUP BY a, b` or a `GROUP BY b, a` can use `Sorted` mode. +/// +/// Note these are the same examples as above, but with `GROUP BY` instead of +/// `PARTITION BY` to make the examples easier to read. +#[derive(Debug, Clone, PartialEq)] +pub enum InputOrderMode { + /// There is no partial permutation of the expressions satisfying the + /// existing ordering. + Linear, + /// There is a partial permutation of the expressions satisfying the + /// existing ordering. Indices describing the longest partial permutation + /// are stored in the vector. + PartiallySorted(Vec), + /// There is a (full) permutation of the expressions satisfying the + /// existing ordering. + Sorted, +} diff --git a/datafusion/physical-plan/src/placeholder_row.rs b/datafusion/physical-plan/src/placeholder_row.rs new file mode 100644 index 000000000000..3ab3de62f37a --- /dev/null +++ b/datafusion/physical-plan/src/placeholder_row.rs @@ -0,0 +1,230 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! EmptyRelation produce_one_row=true execution plan + +use std::any::Any; +use std::sync::Arc; + +use super::expressions::PhysicalSortExpr; +use super::{common, DisplayAs, SendableRecordBatchStream, Statistics}; +use crate::{memory::MemoryStream, DisplayFormatType, ExecutionPlan, Partitioning}; + +use arrow::array::{ArrayRef, NullArray}; +use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; +use arrow::record_batch::RecordBatch; +use arrow_array::RecordBatchOptions; +use datafusion_common::{internal_err, DataFusionError, Result}; +use datafusion_execution::TaskContext; + +use log::trace; + +/// Execution plan for empty relation with produce_one_row=true +#[derive(Debug)] +pub struct PlaceholderRowExec { + /// The schema for the produced row + schema: SchemaRef, + /// Number of partitions + partitions: usize, +} + +impl PlaceholderRowExec { + /// Create a new PlaceholderRowExec + pub fn new(schema: SchemaRef) -> Self { + PlaceholderRowExec { + schema, + partitions: 1, + } + } + + /// Create a new PlaceholderRowExecPlaceholderRowExec with specified partition number + pub fn with_partitions(mut self, partitions: usize) -> Self { + self.partitions = partitions; + self + } + + fn data(&self) -> Result> { + Ok({ + let n_field = self.schema.fields.len(); + vec![RecordBatch::try_new_with_options( + Arc::new(Schema::new( + (0..n_field) + .map(|i| { + Field::new(format!("placeholder_{i}"), DataType::Null, true) + }) + .collect::(), + )), + (0..n_field) + .map(|_i| { + let ret: ArrayRef = Arc::new(NullArray::new(1)); + ret + }) + .collect(), + // Even if column number is empty we can generate single row. + &RecordBatchOptions::new().with_row_count(Some(1)), + )?] + }) + } +} + +impl DisplayAs for PlaceholderRowExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "PlaceholderRowExec") + } + } + } +} + +impl ExecutionPlan for PlaceholderRowExec { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn children(&self) -> Vec> { + vec![] + } + + /// Get the output partitioning of this plan + fn output_partitioning(&self) -> Partitioning { + Partitioning::UnknownPartitioning(self.partitions) + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + None + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> Result> { + Ok(Arc::new(PlaceholderRowExec::new(self.schema.clone()))) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + trace!("Start PlaceholderRowExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); + + if partition >= self.partitions { + return internal_err!( + "PlaceholderRowExec invalid partition {} (expected less than {})", + partition, + self.partitions + ); + } + + Ok(Box::pin(MemoryStream::try_new( + self.data()?, + self.schema.clone(), + None, + )?)) + } + + fn statistics(&self) -> Result { + let batch = self + .data() + .expect("Create single row placeholder RecordBatch should not fail"); + Ok(common::compute_record_batch_statistics( + &[batch], + &self.schema, + None, + )) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::with_new_children_if_necessary; + use crate::{common, test}; + + #[test] + fn with_new_children() -> Result<()> { + let schema = test::aggr_test_schema(); + + let placeholder = Arc::new(PlaceholderRowExec::new(schema)); + + let placeholder_2 = + with_new_children_if_necessary(placeholder.clone(), vec![])?.into(); + assert_eq!(placeholder.schema(), placeholder_2.schema()); + + let too_many_kids = vec![placeholder_2]; + assert!( + with_new_children_if_necessary(placeholder, too_many_kids).is_err(), + "expected error when providing list of kids" + ); + Ok(()) + } + + #[tokio::test] + async fn invalid_execute() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let schema = test::aggr_test_schema(); + let placeholder = PlaceholderRowExec::new(schema); + + // ask for the wrong partition + assert!(placeholder.execute(1, task_ctx.clone()).is_err()); + assert!(placeholder.execute(20, task_ctx).is_err()); + Ok(()) + } + + #[tokio::test] + async fn produce_one_row() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let schema = test::aggr_test_schema(); + let placeholder = PlaceholderRowExec::new(schema); + + let iter = placeholder.execute(0, task_ctx)?; + let batches = common::collect(iter).await?; + + // should have one item + assert_eq!(batches.len(), 1); + + Ok(()) + } + + #[tokio::test] + async fn produce_one_row_multiple_partition() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let schema = test::aggr_test_schema(); + let partitions = 3; + let placeholder = PlaceholderRowExec::new(schema).with_partitions(partitions); + + for n in 0..partitions { + let iter = placeholder.execute(n, task_ctx.clone())?; + let batches = common::collect(iter).await?; + + // should have one item + assert_eq!(batches.len(), 1); + } + + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index 4fc48e971ca9..cc2ab62049ed 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -30,26 +30,23 @@ use super::expressions::{Column, PhysicalSortExpr}; use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use super::{DisplayAs, RecordBatchStream, SendableRecordBatchStream, Statistics}; use crate::{ - ColumnStatistics, DisplayFormatType, EquivalenceProperties, ExecutionPlan, - Partitioning, PhysicalExpr, + ColumnStatistics, DisplayFormatType, ExecutionPlan, Partitioning, PhysicalExpr, }; use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::record_batch::{RecordBatch, RecordBatchOptions}; +use datafusion_common::stats::Precision; use datafusion_common::Result; use datafusion_execution::TaskContext; +use datafusion_physical_expr::equivalence::ProjectionMapping; use datafusion_physical_expr::expressions::{Literal, UnKnownColumn}; -use datafusion_physical_expr::{ - normalize_out_expr_with_columns_map, project_equivalence_properties, - project_ordering_equivalence_properties, OrderingEquivalenceProperties, -}; +use datafusion_physical_expr::EquivalenceProperties; -use datafusion_physical_expr::utils::find_orderings_of_exprs; use futures::stream::{Stream, StreamExt}; use log::trace; /// Execution plan for a projection -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct ProjectionExec { /// The projection expressions stored as tuples of (expression, output column name) pub(crate) expr: Vec<(Arc, String)>, @@ -59,15 +56,11 @@ pub struct ProjectionExec { input: Arc, /// The output ordering output_ordering: Option>, - /// The columns map used to normalize out expressions like Partitioning and PhysicalSortExpr - /// The key is the column from the input schema and the values are the columns from the output schema - columns_map: HashMap>, + /// The mapping used to normalize expressions like Partitioning and + /// PhysicalSortExpr that maps input to output + projection_mapping: ProjectionMapping, /// Execution metrics metrics: ExecutionPlanMetricsSet, - /// Expressions' normalized orderings (as given by the output ordering API - /// and normalized with respect to equivalence classes of input plan). The - /// projected expressions are mapped by their indices to this vector. - orderings: Vec>, } impl ProjectionExec { @@ -99,65 +92,20 @@ impl ProjectionExec { input_schema.metadata().clone(), )); - // construct a map from the input columns to the output columns of the Projection - let mut columns_map: HashMap> = HashMap::new(); - for (expr_idx, (expression, name)) in expr.iter().enumerate() { - if let Some(column) = expression.as_any().downcast_ref::() { - // For some executors, logical and physical plan schema fields - // are not the same. The information in a `Column` comes from - // the logical plan schema. Therefore, to produce correct results - // we use the field in the input schema with the same index. This - // corresponds to the physical plan `Column`. - let idx = column.index(); - let matching_input_field = input_schema.field(idx); - let matching_input_column = Column::new(matching_input_field.name(), idx); - let entry = columns_map - .entry(matching_input_column) - .or_insert_with(Vec::new); - entry.push(Column::new(name, expr_idx)); - }; - } - - // Output Ordering need to respect the alias - let child_output_ordering = input.output_ordering(); - let output_ordering = match child_output_ordering { - Some(sort_exprs) => { - let normalized_exprs = sort_exprs - .iter() - .map(|sort_expr| { - let expr = normalize_out_expr_with_columns_map( - sort_expr.expr.clone(), - &columns_map, - ); - PhysicalSortExpr { - expr, - options: sort_expr.options, - } - }) - .collect::>(); - Some(normalized_exprs) - } - None => None, - }; - - let orderings = find_orderings_of_exprs( - &expr, - input.output_ordering(), - input.equivalence_properties(), - input.ordering_equivalence_properties(), - )?; + // construct a map from the input expressions to the output expression of the Projection + let projection_mapping = ProjectionMapping::try_new(&expr, &input_schema)?; - let output_ordering = - validate_output_ordering(output_ordering, &orderings, &expr); + let input_eqs = input.equivalence_properties(); + let project_eqs = input_eqs.project(&projection_mapping, schema.clone()); + let output_ordering = project_eqs.oeq_class().output_ordering(); Ok(Self { expr, schema, input, output_ordering, - columns_map, + projection_mapping, metrics: ExecutionPlanMetricsSet::new(), - orderings, }) } @@ -225,18 +173,21 @@ impl ExecutionPlan for ProjectionExec { fn output_partitioning(&self) -> Partitioning { // Output partition need to respect the alias let input_partition = self.input.output_partitioning(); - match input_partition { - Partitioning::Hash(exprs, part) => { - let normalized_exprs = exprs - .into_iter() - .map(|expr| { - normalize_out_expr_with_columns_map(expr, &self.columns_map) - }) - .collect::>(); - - Partitioning::Hash(normalized_exprs, part) - } - _ => input_partition, + let input_eq_properties = self.input.equivalence_properties(); + if let Partitioning::Hash(exprs, part) = input_partition { + let normalized_exprs = exprs + .into_iter() + .map(|expr| { + input_eq_properties + .project_expr(&expr, &self.projection_mapping) + .unwrap_or_else(|| { + Arc::new(UnKnownColumn::new(&expr.to_string())) + }) + }) + .collect(); + Partitioning::Hash(normalized_exprs, part) + } else { + input_partition } } @@ -250,58 +201,17 @@ impl ExecutionPlan for ProjectionExec { } fn equivalence_properties(&self) -> EquivalenceProperties { - let mut new_properties = EquivalenceProperties::new(self.schema()); - project_equivalence_properties( - self.input.equivalence_properties(), - &self.columns_map, - &mut new_properties, - ); - new_properties - } - - fn ordering_equivalence_properties(&self) -> OrderingEquivalenceProperties { - let mut new_properties = OrderingEquivalenceProperties::new(self.schema()); - if self.output_ordering.is_none() { - // If there is no output ordering, return an "empty" equivalence set: - return new_properties; - } - - let input_oeq = self.input().ordering_equivalence_properties(); - - project_ordering_equivalence_properties( - input_oeq, - &self.columns_map, - &mut new_properties, - ); - - if let Some(leading_ordering) = self - .output_ordering - .as_ref() - .map(|output_ordering| &output_ordering[0]) - { - for order in self.orderings.iter().flatten() { - if !order.eq(leading_ordering) - && !new_properties.satisfies_leading_ordering(order) - { - new_properties.add_equal_conditions(( - &vec![leading_ordering.clone()], - &vec![order.clone()], - )); - } - } - } - - new_properties + self.input + .equivalence_properties() + .project(&self.projection_mapping, self.schema()) } fn with_new_children( self: Arc, - children: Vec>, + mut children: Vec>, ) -> Result> { - Ok(Arc::new(ProjectionExec::try_new( - self.expr.clone(), - children[0].clone(), - )?)) + ProjectionExec::try_new(self.expr.clone(), children.swap_remove(0)) + .map(|p| Arc::new(p) as _) } fn benefits_from_input_partitioning(&self) -> Vec { @@ -332,119 +242,61 @@ impl ExecutionPlan for ProjectionExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Statistics { - stats_projection( - self.input.statistics(), + fn statistics(&self) -> Result { + Ok(stats_projection( + self.input.statistics()?, self.expr.iter().map(|(e, _)| Arc::clone(e)), self.schema.clone(), - ) + )) } } -/// This function takes the current `output_ordering`, the `orderings` based on projected expressions, -/// and the `expr` representing the projected expressions themselves. It aims to ensure that the output -/// ordering is valid and correctly corresponds to the projected columns. -/// -/// If the leading expression in the `output_ordering` is an [`UnKnownColumn`], it indicates that the column -/// referenced in the ordering is not found among the projected expressions. In such cases, this function -/// attempts to create a new output ordering by referring to valid columns from the leftmost side of the -/// expressions that have an ordering specified. -fn validate_output_ordering( - output_ordering: Option>, - orderings: &[Option], - expr: &[(Arc, String)], -) -> Option> { - output_ordering.and_then(|ordering| { - // If the leading expression is invalid column, change output - // ordering of the projection so that it refers to valid columns if - // possible. - if ordering[0].expr.as_any().is::() { - for (idx, order) in orderings.iter().enumerate() { - if let Some(sort_expr) = order { - let (_, col_name) = &expr[idx]; - return Some(vec![PhysicalSortExpr { - expr: Arc::new(Column::new(col_name, idx)), - options: sort_expr.options, - }]); - } - } - None - } else { - Some(ordering) - } - }) -} - /// If e is a direct column reference, returns the field level /// metadata for that field, if any. Otherwise returns None fn get_field_metadata( e: &Arc, input_schema: &Schema, ) -> Option> { - let name = if let Some(column) = e.as_any().downcast_ref::() { - column.name() - } else { - return None; - }; - - input_schema - .field_with_name(name) - .ok() - .map(|f| f.metadata().clone()) + // Look up field by index in schema (not NAME as there can be more than one + // column with the same name) + e.as_any() + .downcast_ref::() + .map(|column| input_schema.field(column.index()).metadata()) + .cloned() } fn stats_projection( - stats: Statistics, + mut stats: Statistics, exprs: impl Iterator>, schema: SchemaRef, ) -> Statistics { - let inner_exprs = exprs.collect::>(); - let column_statistics = stats.column_statistics.map(|input_col_stats| { - inner_exprs - .clone() - .into_iter() - .map(|e| { - if let Some(col) = e.as_any().downcast_ref::() { - input_col_stats[col.index()].clone() - } else { - // TODO stats: estimate more statistics from expressions - // (expressions should compute their statistics themselves) - ColumnStatistics::default() - } - }) - .collect() - }); - - let primitive_row_size = inner_exprs - .into_iter() - .map(|e| match e.data_type(schema.as_ref()) { - Ok(data_type) => data_type.primitive_width(), - Err(_) => None, - }) - .try_fold(0usize, |init, v| v.map(|value| init + value)); - - match (primitive_row_size, stats.num_rows) { - (Some(row_size), Some(row_count)) => { - Statistics { - is_exact: stats.is_exact, - num_rows: stats.num_rows, - column_statistics, - // Use the row_size * row_count as the total byte size - total_byte_size: Some(row_size * row_count), - } - } - _ => { - Statistics { - is_exact: stats.is_exact, - num_rows: stats.num_rows, - column_statistics, - // TODO stats: knowing the type of the new columns we can guess the output size - // If we can't get the exact statistics for the project - // Before we get the exact result, we just use the child status - total_byte_size: stats.total_byte_size, + let mut primitive_row_size = 0; + let mut primitive_row_size_possible = true; + let mut column_statistics = vec![]; + for expr in exprs { + let col_stats = if let Some(col) = expr.as_any().downcast_ref::() { + stats.column_statistics[col.index()].clone() + } else { + // TODO stats: estimate more statistics from expressions + // (expressions should compute their statistics themselves) + ColumnStatistics::new_unknown() + }; + column_statistics.push(col_stats); + if let Ok(data_type) = expr.data_type(&schema) { + if let Some(value) = data_type.primitive_width() { + primitive_row_size += value; + continue; } } + primitive_row_size_possible = false; } + + if primitive_row_size_possible { + stats.total_byte_size = + Precision::Exact(primitive_row_size).multiply(&stats.num_rows); + } + stats.column_statistics = column_statistics; + stats } impl ProjectionStream { @@ -454,8 +306,10 @@ impl ProjectionStream { let arrays = self .expr .iter() - .map(|expr| expr.evaluate(batch)) - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .map(|expr| { + expr.evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) .collect::>>()?; if arrays.is_empty() { @@ -511,6 +365,7 @@ mod tests { use crate::common::collect; use crate::expressions; use crate::test; + use arrow_schema::DataType; use datafusion_common::ScalarValue; @@ -531,29 +386,28 @@ mod tests { fn get_stats() -> Statistics { Statistics { - is_exact: true, - num_rows: Some(5), - total_byte_size: Some(23), - column_statistics: Some(vec![ + num_rows: Precision::Exact(5), + total_byte_size: Precision::Exact(23), + column_statistics: vec![ ColumnStatistics { - distinct_count: Some(5), - max_value: Some(ScalarValue::Int64(Some(21))), - min_value: Some(ScalarValue::Int64(Some(-4))), - null_count: Some(0), + distinct_count: Precision::Exact(5), + max_value: Precision::Exact(ScalarValue::Int64(Some(21))), + min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), + null_count: Precision::Exact(0), }, ColumnStatistics { - distinct_count: Some(1), - max_value: Some(ScalarValue::Utf8(Some(String::from("x")))), - min_value: Some(ScalarValue::Utf8(Some(String::from("a")))), - null_count: Some(3), + distinct_count: Precision::Exact(1), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), + null_count: Precision::Exact(3), }, ColumnStatistics { - distinct_count: None, - max_value: Some(ScalarValue::Float32(Some(1.1))), - min_value: Some(ScalarValue::Float32(Some(0.1))), - null_count: None, + distinct_count: Precision::Absent, + max_value: Precision::Exact(ScalarValue::Float32(Some(1.1))), + min_value: Precision::Exact(ScalarValue::Float32(Some(0.1))), + null_count: Precision::Absent, }, - ]), + ], } } @@ -576,23 +430,22 @@ mod tests { let result = stats_projection(source, exprs.into_iter(), Arc::new(schema)); let expected = Statistics { - is_exact: true, - num_rows: Some(5), - total_byte_size: Some(23), - column_statistics: Some(vec![ + num_rows: Precision::Exact(5), + total_byte_size: Precision::Exact(23), + column_statistics: vec![ ColumnStatistics { - distinct_count: Some(1), - max_value: Some(ScalarValue::Utf8(Some(String::from("x")))), - min_value: Some(ScalarValue::Utf8(Some(String::from("a")))), - null_count: Some(3), + distinct_count: Precision::Exact(1), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), + null_count: Precision::Exact(3), }, ColumnStatistics { - distinct_count: Some(5), - max_value: Some(ScalarValue::Int64(Some(21))), - min_value: Some(ScalarValue::Int64(Some(-4))), - null_count: Some(0), + distinct_count: Precision::Exact(5), + max_value: Precision::Exact(ScalarValue::Int64(Some(21))), + min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), + null_count: Precision::Exact(0), }, - ]), + ], }; assert_eq!(result, expected); @@ -611,23 +464,22 @@ mod tests { let result = stats_projection(source, exprs.into_iter(), Arc::new(schema)); let expected = Statistics { - is_exact: true, - num_rows: Some(5), - total_byte_size: Some(60), - column_statistics: Some(vec![ + num_rows: Precision::Exact(5), + total_byte_size: Precision::Exact(60), + column_statistics: vec![ ColumnStatistics { - distinct_count: None, - max_value: Some(ScalarValue::Float32(Some(1.1))), - min_value: Some(ScalarValue::Float32(Some(0.1))), - null_count: None, + distinct_count: Precision::Absent, + max_value: Precision::Exact(ScalarValue::Float32(Some(1.1))), + min_value: Precision::Exact(ScalarValue::Float32(Some(0.1))), + null_count: Precision::Absent, }, ColumnStatistics { - distinct_count: Some(5), - max_value: Some(ScalarValue::Int64(Some(21))), - min_value: Some(ScalarValue::Int64(Some(-4))), - null_count: Some(0), + distinct_count: Precision::Exact(5), + max_value: Precision::Exact(ScalarValue::Int64(Some(21))), + min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), + null_count: Precision::Exact(0), }, - ]), + ], }; assert_eq!(result, expected); diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index 14b54dc0614d..07693f747fee 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -15,45 +15,43 @@ // specific language governing permissions and limitations // under the License. -//! The repartition operator maps N input partitions to M output partitions based on a -//! partitioning scheme (according to flag `preserve_order` ordering can be preserved during -//! repartitioning if its input is ordered). +//! This file implements the [`RepartitionExec`] operator, which maps N input +//! partitions to M output partitions based on a partitioning scheme, optionally +//! maintaining the order of the input rows in the output. use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; use std::{any::Any, vec}; +use arrow::array::{ArrayRef, UInt64Builder}; +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; +use futures::stream::Stream; +use futures::{FutureExt, StreamExt}; +use hashbrown::HashMap; +use log::trace; +use parking_lot::Mutex; +use tokio::task::JoinHandle; + +use datafusion_common::{arrow_datafusion_err, not_impl_err, DataFusionError, Result}; +use datafusion_execution::memory_pool::MemoryConsumer; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr}; + use crate::common::transpose; use crate::hash_utils::create_hashes; use crate::metrics::BaselineMetrics; use crate::repartition::distributor_channels::{channels, partition_aware_channels}; use crate::sorts::streaming_merge; -use crate::{ - DisplayFormatType, EquivalenceProperties, ExecutionPlan, Partitioning, Statistics, -}; - -use self::distributor_channels::{DistributionReceiver, DistributionSender}; +use crate::{DisplayFormatType, ExecutionPlan, Partitioning, Statistics}; use super::common::{AbortOnDropMany, AbortOnDropSingle, SharedMemoryReservation}; use super::expressions::PhysicalSortExpr; use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; use super::{DisplayAs, RecordBatchStream, SendableRecordBatchStream}; -use arrow::array::{ArrayRef, UInt64Builder}; -use arrow::datatypes::SchemaRef; -use arrow::record_batch::RecordBatch; -use datafusion_common::{not_impl_err, DataFusionError, Result}; -use datafusion_execution::memory_pool::MemoryConsumer; -use datafusion_execution::TaskContext; -use datafusion_physical_expr::{OrderingEquivalenceProperties, PhysicalExpr}; - -use futures::stream::Stream; -use futures::{FutureExt, StreamExt}; -use hashbrown::HashMap; -use log::trace; -use parking_lot::Mutex; -use tokio::task::JoinHandle; +use self::distributor_channels::{DistributionReceiver, DistributionSender}; mod distributor_channels; @@ -171,9 +169,7 @@ impl BatchPartitioner { let arrays = exprs .iter() - .map(|expr| { - Ok(expr.evaluate(&batch)?.into_array(batch.num_rows())) - }) + .map(|expr| expr.evaluate(&batch)?.into_array(batch.num_rows())) .collect::>>()?; hash_buffer.clear(); @@ -204,7 +200,7 @@ impl BatchPartitioner { .iter() .map(|c| { arrow::compute::take(c.as_ref(), &indices, None) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) }) .collect::>>()?; @@ -238,7 +234,7 @@ impl BatchPartitioner { /// /// # Background /// -/// DataFusion, like most other commercial systems, with the the +/// DataFusion, like most other commercial systems, with the /// notable exception of DuckDB, uses the "Exchange Operator" based /// approach to parallelism which works well in practice given /// sufficient care in implementation. @@ -283,8 +279,9 @@ impl BatchPartitioner { /// /// # Output Ordering /// -/// No guarantees are made about the order of the resulting -/// partitions unless `preserve_order` is set. +/// If more than one stream is being repartitioned, the output will be some +/// arbitrary interleaving (and thus unordered) unless +/// [`Self::with_preserve_order`] specifies otherwise. /// /// # Footnote /// @@ -308,7 +305,8 @@ pub struct RepartitionExec { /// Execution metrics metrics: ExecutionPlanMetricsSet, - /// Boolean flag to decide whether to preserve ordering + /// Boolean flag to decide whether to preserve ordering. If true means + /// `SortPreservingRepartitionExec`, false means `RepartitionExec`. preserve_order: bool, } @@ -370,13 +368,9 @@ impl RepartitionExec { self.preserve_order } - /// Get name of the Executor + /// Get name used to display this Exec pub fn name(&self) -> &str { - if self.preserve_order { - "SortPreservingRepartitionExec" - } else { - "RepartitionExec" - } + "RepartitionExec" } } @@ -394,7 +388,20 @@ impl DisplayAs for RepartitionExec { self.name(), self.partitioning, self.input.output_partitioning().partition_count() - ) + )?; + + if self.preserve_order { + write!(f, ", preserve_order=true")?; + } + + if let Some(sort_exprs) = self.sort_exprs() { + write!( + f, + ", sort_exprs={}", + PhysicalSortExpr::format_list(sort_exprs) + )?; + } + Ok(()) } } } @@ -417,11 +424,13 @@ impl ExecutionPlan for RepartitionExec { fn with_new_children( self: Arc, - children: Vec>, + mut children: Vec>, ) -> Result> { - let repartition = - RepartitionExec::try_new(children[0].clone(), self.partitioning.clone())? - .with_preserve_order(self.preserve_order); + let mut repartition = + RepartitionExec::try_new(children.swap_remove(0), self.partitioning.clone())?; + if self.preserve_order { + repartition = repartition.with_preserve_order(); + } Ok(Arc::new(repartition)) } @@ -458,11 +467,12 @@ impl ExecutionPlan for RepartitionExec { } fn equivalence_properties(&self) -> EquivalenceProperties { - self.input.equivalence_properties() - } - - fn ordering_equivalence_properties(&self) -> OrderingEquivalenceProperties { - self.input.ordering_equivalence_properties() + let mut result = self.input.equivalence_properties(); + // If the ordering is lost, reset the ordering equivalence class. + if !self.maintains_input_order()[0] { + result.clear_orderings(); + } + result } fn execute( @@ -576,8 +586,8 @@ impl ExecutionPlan for RepartitionExec { .collect::>(); // Note that receiver size (`rx.len()`) and `num_input_partitions` are same. - // Get existing ordering: - let sort_exprs = self.input.output_ordering().unwrap_or(&[]); + // Get existing ordering to use for merging + let sort_exprs = self.sort_exprs().unwrap_or(&[]); // Merge streams (while preserving ordering) coming from // input partitions to this partition: @@ -610,13 +620,15 @@ impl ExecutionPlan for RepartitionExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Statistics { + fn statistics(&self) -> Result { self.input.statistics() } } impl RepartitionExec { - /// Create a new RepartitionExec + /// Create a new RepartitionExec, that produces output `partitioning`, and + /// does not preserve the order of the input (see [`Self::with_preserve_order`] + /// for more details) pub fn try_new( input: Arc, partitioning: Partitioning, @@ -633,19 +645,32 @@ impl RepartitionExec { }) } - /// Set Order preserving flag - pub fn with_preserve_order(mut self, preserve_order: bool) -> Self { - // Set "preserve order" mode only if the input partition count is larger than 1 - // Because in these cases naive `RepartitionExec` cannot maintain ordering. Using - // `SortPreservingRepartitionExec` is necessity. However, when input partition number - // is 1, `RepartitionExec` can maintain ordering. In this case, we don't need to use - // `SortPreservingRepartitionExec` variant to maintain ordering. - if self.input.output_partitioning().partition_count() > 1 { - self.preserve_order = preserve_order - } + /// Specify if this reparititoning operation should preserve the order of + /// rows from its input when producing output. Preserving order is more + /// expensive at runtime, so should only be set if the output of this + /// operator can take advantage of it. + /// + /// If the input is not ordered, or has only one partition, this is a no op, + /// and the node remains a `RepartitionExec`. + pub fn with_preserve_order(mut self) -> Self { + self.preserve_order = + // If the input isn't ordered, there is no ordering to preserve + self.input.output_ordering().is_some() && + // if there is only one input partition, merging is not required + // to maintain order + self.input.output_partitioning().partition_count() > 1; self } + /// Return the sort expressions that are used to merge + fn sort_exprs(&self) -> Option<&[PhysicalSortExpr]> { + if self.preserve_order { + self.input.output_ordering() + } else { + None + } + } + /// Pulls data from the specified input plan, feeding it to the /// output partitions based on the desired partitioning /// @@ -893,7 +918,19 @@ impl RecordBatchStream for PerPartitionStream { #[cfg(test)] mod tests { - use super::*; + use std::collections::HashSet; + + use arrow::array::{ArrayRef, StringArray}; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use arrow_array::UInt32Array; + use futures::FutureExt; + use tokio::task::JoinHandle; + + use datafusion_common::cast::as_string_array; + use datafusion_common::{assert_batches_sorted_eq, exec_err}; + use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; + use crate::{ test::{ assert_is_pending, @@ -904,16 +941,8 @@ mod tests { }, {collect, expressions::col, memory::MemoryExec}, }; - use arrow::array::{ArrayRef, StringArray}; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow::record_batch::RecordBatch; - use arrow_array::UInt32Array; - use datafusion_common::cast::as_string_array; - use datafusion_common::{assert_batches_sorted_eq, exec_err}; - use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; - use futures::FutureExt; - use std::collections::HashSet; - use tokio::task::JoinHandle; + + use super::*; #[tokio::test] async fn one_to_many_round_robin() -> Result<()> { @@ -1385,9 +1414,8 @@ mod tests { // pull partitions for i in 0..exec.partitioning.partition_count() { let mut stream = exec.execute(i, task_ctx.clone())?; - let err = DataFusionError::ArrowError( - stream.next().await.unwrap().unwrap_err().into(), - ); + let err = + arrow_datafusion_err!(stream.next().await.unwrap().unwrap_err().into()); let err = err.find_root(); assert!( matches!(err, DataFusionError::ResourcesExhausted(_)), @@ -1414,3 +1442,129 @@ mod tests { .unwrap() } } + +#[cfg(test)] +mod test { + use arrow_schema::{DataType, Field, Schema, SortOptions}; + + use datafusion_physical_expr::expressions::col; + + use crate::memory::MemoryExec; + use crate::union::UnionExec; + + use super::*; + + /// Asserts that the plan is as expected + /// + /// `$EXPECTED_PLAN_LINES`: input plan + /// `$PLAN`: the plan to optimized + /// + macro_rules! assert_plan { + ($EXPECTED_PLAN_LINES: expr, $PLAN: expr) => { + let physical_plan = $PLAN; + let formatted = crate::displayable(&physical_plan).indent(true).to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + + let expected_plan_lines: Vec<&str> = $EXPECTED_PLAN_LINES + .iter().map(|s| *s).collect(); + + assert_eq!( + expected_plan_lines, actual, + "\n**Original Plan Mismatch\n\nexpected:\n\n{expected_plan_lines:#?}\nactual:\n\n{actual:#?}\n\n" + ); + }; + } + + #[tokio::test] + async fn test_preserve_order() -> Result<()> { + let schema = test_schema(); + let sort_exprs = sort_exprs(&schema); + let source1 = sorted_memory_exec(&schema, sort_exprs.clone()); + let source2 = sorted_memory_exec(&schema, sort_exprs); + // output has multiple partitions, and is sorted + let union = UnionExec::new(vec![source1, source2]); + let exec = + RepartitionExec::try_new(Arc::new(union), Partitioning::RoundRobinBatch(10)) + .unwrap() + .with_preserve_order(); + + // Repartition should preserve order + let expected_plan = [ + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true, sort_exprs=c0@0 ASC", + " UnionExec", + " MemoryExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC", + " MemoryExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC", + ]; + assert_plan!(expected_plan, exec); + Ok(()) + } + + #[tokio::test] + async fn test_preserve_order_one_partition() -> Result<()> { + let schema = test_schema(); + let sort_exprs = sort_exprs(&schema); + let source = sorted_memory_exec(&schema, sort_exprs); + // output is sorted, but has only a single partition, so no need to sort + let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(10)) + .unwrap() + .with_preserve_order(); + + // Repartition should not preserve order + let expected_plan = [ + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " MemoryExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC", + ]; + assert_plan!(expected_plan, exec); + Ok(()) + } + + #[tokio::test] + async fn test_preserve_order_input_not_sorted() -> Result<()> { + let schema = test_schema(); + let source1 = memory_exec(&schema); + let source2 = memory_exec(&schema); + // output has multiple partitions, but is not sorted + let union = UnionExec::new(vec![source1, source2]); + let exec = + RepartitionExec::try_new(Arc::new(union), Partitioning::RoundRobinBatch(10)) + .unwrap() + .with_preserve_order(); + + // Repartition should not preserve order, as there is no order to preserve + let expected_plan = [ + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", + " UnionExec", + " MemoryExec: partitions=1, partition_sizes=[0]", + " MemoryExec: partitions=1, partition_sizes=[0]", + ]; + assert_plan!(expected_plan, exec); + Ok(()) + } + + fn test_schema() -> Arc { + Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)])) + } + + fn sort_exprs(schema: &Schema) -> Vec { + let options = SortOptions::default(); + vec![PhysicalSortExpr { + expr: col("c0", schema).unwrap(), + options, + }] + } + + fn memory_exec(schema: &SchemaRef) -> Arc { + Arc::new(MemoryExec::try_new(&[vec![]], schema.clone(), None).unwrap()) + } + + fn sorted_memory_exec( + schema: &SchemaRef, + sort_exprs: Vec, + ) -> Arc { + Arc::new( + MemoryExec::try_new(&[vec![]], schema.clone(), None) + .unwrap() + .with_sort_information(vec![sort_exprs]), + ) + } +} diff --git a/datafusion/physical-plan/src/sorts/cursor.rs b/datafusion/physical-plan/src/sorts/cursor.rs index baa417649fb0..df90c97faf68 100644 --- a/datafusion/physical-plan/src/sorts/cursor.rs +++ b/datafusion/physical-plan/src/sorts/cursor.rs @@ -15,125 +15,160 @@ // specific language governing permissions and limitations // under the License. -use crate::sorts::sort::SortOptions; +use std::cmp::Ordering; + use arrow::buffer::ScalarBuffer; +use arrow::compute::SortOptions; use arrow::datatypes::ArrowNativeTypeOp; -use arrow::row::{Row, Rows}; +use arrow::row::Rows; use arrow_array::types::ByteArrayType; -use arrow_array::{Array, ArrowPrimitiveType, GenericByteArray, PrimitiveArray}; +use arrow_array::{ + Array, ArrowPrimitiveType, GenericByteArray, OffsetSizeTrait, PrimitiveArray, +}; +use arrow_buffer::{Buffer, OffsetBuffer}; use datafusion_execution::memory_pool::MemoryReservation; -use std::cmp::Ordering; -/// A [`Cursor`] for [`Rows`] -pub struct RowCursor { - cur_row: usize, - num_rows: usize, +/// A comparable collection of values for use with [`Cursor`] +/// +/// This is a trait as there are several specialized implementations, such as for +/// single columns or for normalized multi column keys ([`Rows`]) +pub trait CursorValues { + fn len(&self) -> usize; - rows: Rows, + /// Returns true if `l[l_idx] == r[r_idx]` + fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool; - /// Tracks for the memory used by in the `Rows` of this - /// cursor. Freed on drop - #[allow(dead_code)] - reservation: MemoryReservation, + /// Returns comparison of `l[l_idx]` and `r[r_idx]` + fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering; } -impl std::fmt::Debug for RowCursor { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("SortKeyCursor") - .field("cur_row", &self.cur_row) - .field("num_rows", &self.num_rows) - .finish() - } +/// A comparable cursor, used by sort operations +/// +/// A `Cursor` is a pointer into a collection of rows, stored in +/// [`CursorValues`] +/// +/// ```text +/// +/// ┌───────────────────────┐ +/// │ │ ┌──────────────────────┐ +/// │ ┌─────────┐ ┌─────┐ │ ─ ─ ─ ─│ Cursor │ +/// │ │ 1 │ │ A │ │ │ └──────────────────────┘ +/// │ ├─────────┤ ├─────┤ │ +/// │ │ 2 │ │ A │◀─ ┼ ─ ┘ Cursor tracks an +/// │ └─────────┘ └─────┘ │ offset within a +/// │ ... ... │ CursorValues +/// │ │ +/// │ ┌─────────┐ ┌─────┐ │ +/// │ │ 3 │ │ E │ │ +/// │ └─────────┘ └─────┘ │ +/// │ │ +/// │ CursorValues │ +/// └───────────────────────┘ +/// +/// +/// Store logical rows using +/// one of several formats, +/// with specialized +/// implementations +/// depending on the column +/// types +#[derive(Debug)] +pub struct Cursor { + offset: usize, + values: T, } -impl RowCursor { - /// Create a new SortKeyCursor from `rows` and a `reservation` - /// that tracks its memory. - /// - /// Panic's if the reservation is not for exactly `rows.size()` - /// bytes - pub fn new(rows: Rows, reservation: MemoryReservation) -> Self { - assert_eq!( - rows.size(), - reservation.size(), - "memory reservation mismatch" - ); - Self { - cur_row: 0, - num_rows: rows.num_rows(), - rows, - reservation, - } +impl Cursor { + /// Create a [`Cursor`] from the given [`CursorValues`] + pub fn new(values: T) -> Self { + Self { offset: 0, values } + } + + /// Returns true if there are no more rows in this cursor + pub fn is_finished(&self) -> bool { + self.offset == self.values.len() } - /// Returns the current row - fn current(&self) -> Row<'_> { - self.rows.row(self.cur_row) + /// Advance the cursor, returning the previous row index + pub fn advance(&mut self) -> usize { + let t = self.offset; + self.offset += 1; + t } } -impl PartialEq for RowCursor { +impl PartialEq for Cursor { fn eq(&self, other: &Self) -> bool { - self.current() == other.current() + T::eq(&self.values, self.offset, &other.values, other.offset) } } -impl Eq for RowCursor {} +impl Eq for Cursor {} -impl PartialOrd for RowCursor { +impl PartialOrd for Cursor { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } -impl Ord for RowCursor { +impl Ord for Cursor { fn cmp(&self, other: &Self) -> Ordering { - self.current().cmp(&other.current()) + T::compare(&self.values, self.offset, &other.values, other.offset) } } -/// A cursor into a sorted batch of rows -pub trait Cursor: Ord { - /// Returns true if there are no more rows in this cursor - fn is_finished(&self) -> bool; +/// Implements [`CursorValues`] for [`Rows`] +/// +/// Used for sorting when there are multiple columns in the sort key +#[derive(Debug)] +pub struct RowValues { + rows: Rows, - /// Advance the cursor, returning the previous row index - fn advance(&mut self) -> usize; + /// Tracks for the memory used by in the `Rows` of this + /// cursor. Freed on drop + #[allow(dead_code)] + reservation: MemoryReservation, } -impl Cursor for RowCursor { - #[inline] - fn is_finished(&self) -> bool { - self.num_rows == self.cur_row +impl RowValues { + /// Create a new [`RowValues`] from `rows` and a `reservation` + /// that tracks its memory. There must be at least one row + /// + /// Panics if the reservation is not for exactly `rows.size()` + /// bytes or if `rows` is empty. + pub fn new(rows: Rows, reservation: MemoryReservation) -> Self { + assert_eq!( + rows.size(), + reservation.size(), + "memory reservation mismatch" + ); + assert!(rows.num_rows() > 0); + Self { rows, reservation } } +} - #[inline] - fn advance(&mut self) -> usize { - let t = self.cur_row; - self.cur_row += 1; - t +impl CursorValues for RowValues { + fn len(&self) -> usize { + self.rows.num_rows() } -} -/// An [`Array`] that can be converted into [`FieldValues`] -pub trait FieldArray: Array + 'static { - type Values: FieldValues; + fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool { + l.rows.row(l_idx) == r.rows.row(r_idx) + } - fn values(&self) -> Self::Values; + fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering { + l.rows.row(l_idx).cmp(&r.rows.row(r_idx)) + } } -/// A comparable set of non-nullable values -pub trait FieldValues { - type Value: ?Sized; - - fn len(&self) -> usize; - - fn compare(a: &Self::Value, b: &Self::Value) -> Ordering; +/// An [`Array`] that can be converted into [`CursorValues`] +pub trait CursorArray: Array + 'static { + type Values: CursorValues; - fn value(&self, idx: usize) -> &Self::Value; + fn values(&self) -> Self::Values; } -impl FieldArray for PrimitiveArray { +impl CursorArray for PrimitiveArray { type Values = PrimitiveValues; fn values(&self) -> Self::Values { @@ -144,71 +179,81 @@ impl FieldArray for PrimitiveArray { #[derive(Debug)] pub struct PrimitiveValues(ScalarBuffer); -impl FieldValues for PrimitiveValues { - type Value = T; - +impl CursorValues for PrimitiveValues { fn len(&self) -> usize { self.0.len() } - #[inline] - fn compare(a: &Self::Value, b: &Self::Value) -> Ordering { - T::compare(*a, *b) + fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool { + l.0[l_idx].is_eq(r.0[r_idx]) } - #[inline] - fn value(&self, idx: usize) -> &Self::Value { - &self.0[idx] + fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering { + l.0[l_idx].compare(r.0[r_idx]) } } -impl FieldArray for GenericByteArray { - type Values = Self; +pub struct ByteArrayValues { + offsets: OffsetBuffer, + values: Buffer, +} - fn values(&self) -> Self::Values { - // Once https://github.com/apache/arrow-rs/pull/4048 is released - // Could potentially destructure array into buffers to reduce codegen, - // in a similar vein to what is done for PrimitiveArray - self.clone() +impl ByteArrayValues { + fn value(&self, idx: usize) -> &[u8] { + assert!(idx < self.len()); + // Safety: offsets are valid and checked bounds above + unsafe { + let start = self.offsets.get_unchecked(idx).as_usize(); + let end = self.offsets.get_unchecked(idx + 1).as_usize(); + self.values.get_unchecked(start..end) + } } } -impl FieldValues for GenericByteArray { - type Value = T::Native; - +impl CursorValues for ByteArrayValues { fn len(&self) -> usize { - Array::len(self) + self.offsets.len() - 1 + } + + fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool { + l.value(l_idx) == r.value(r_idx) } - #[inline] - fn compare(a: &Self::Value, b: &Self::Value) -> Ordering { - let a: &[u8] = a.as_ref(); - let b: &[u8] = b.as_ref(); - a.cmp(b) + fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering { + l.value(l_idx).cmp(r.value(r_idx)) } +} + +impl CursorArray for GenericByteArray { + type Values = ByteArrayValues; - #[inline] - fn value(&self, idx: usize) -> &Self::Value { - self.value(idx) + fn values(&self) -> Self::Values { + ByteArrayValues { + offsets: self.offsets().clone(), + values: self.values().clone(), + } } } -/// A cursor over sorted, nullable [`FieldValues`] +/// A collection of sorted, nullable [`CursorValues`] /// /// Note: comparing cursors with different `SortOptions` will yield an arbitrary ordering #[derive(Debug)] -pub struct FieldCursor { +pub struct ArrayValues { values: T, - offset: usize, // If nulls first, the first non-null index // Otherwise, the first null index null_threshold: usize, options: SortOptions, } -impl FieldCursor { - /// Create a new [`FieldCursor`] from the provided `values` sorted according to `options` - pub fn new>(options: SortOptions, array: &A) -> Self { +impl ArrayValues { + /// Create a new [`ArrayValues`] from the provided `values` sorted according + /// to `options`. + /// + /// Panics if the array is empty + pub fn new>(options: SortOptions, array: &A) -> Self { + assert!(array.len() > 0, "Empty array passed to FieldCursor"); let null_threshold = match options.nulls_first { true => array.null_count(), false => array.len() - array.null_count(), @@ -216,67 +261,48 @@ impl FieldCursor { Self { values: array.values(), - offset: 0, null_threshold, options, } } - fn is_null(&self) -> bool { - (self.offset < self.null_threshold) == self.options.nulls_first + fn is_null(&self, idx: usize) -> bool { + (idx < self.null_threshold) == self.options.nulls_first } } -impl PartialEq for FieldCursor { - fn eq(&self, other: &Self) -> bool { - self.cmp(other).is_eq() +impl CursorValues for ArrayValues { + fn len(&self) -> usize { + self.values.len() } -} -impl Eq for FieldCursor {} -impl PartialOrd for FieldCursor { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) + fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool { + match (l.is_null(l_idx), r.is_null(r_idx)) { + (true, true) => true, + (false, false) => T::eq(&l.values, l_idx, &r.values, r_idx), + _ => false, + } } -} -impl Ord for FieldCursor { - fn cmp(&self, other: &Self) -> Ordering { - match (self.is_null(), other.is_null()) { + fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering { + match (l.is_null(l_idx), r.is_null(r_idx)) { (true, true) => Ordering::Equal, - (true, false) => match self.options.nulls_first { + (true, false) => match l.options.nulls_first { true => Ordering::Less, false => Ordering::Greater, }, - (false, true) => match self.options.nulls_first { + (false, true) => match l.options.nulls_first { true => Ordering::Greater, false => Ordering::Less, }, - (false, false) => { - let s_v = self.values.value(self.offset); - let o_v = other.values.value(other.offset); - - match self.options.descending { - true => T::compare(o_v, s_v), - false => T::compare(s_v, o_v), - } - } + (false, false) => match l.options.descending { + true => T::compare(&r.values, r_idx, &l.values, l_idx), + false => T::compare(&l.values, l_idx, &r.values, r_idx), + }, } } } -impl Cursor for FieldCursor { - fn is_finished(&self) -> bool { - self.offset == self.values.len() - } - - fn advance(&mut self) -> usize { - let t = self.offset; - self.offset += 1; - t - } -} - #[cfg(test)] mod tests { use super::*; @@ -285,18 +311,19 @@ mod tests { options: SortOptions, values: ScalarBuffer, null_count: usize, - ) -> FieldCursor> { + ) -> Cursor>> { let null_threshold = match options.nulls_first { true => null_count, false => values.len() - null_count, }; - FieldCursor { - offset: 0, + let values = ArrayValues { values: PrimitiveValues(values), null_threshold, options, - } + }; + + Cursor::new(values) } #[test] diff --git a/datafusion/physical-plan/src/sorts/merge.rs b/datafusion/physical-plan/src/sorts/merge.rs index 67685509abe5..422ff3aebdb3 100644 --- a/datafusion/physical-plan/src/sorts/merge.rs +++ b/datafusion/physical-plan/src/sorts/merge.rs @@ -20,85 +20,22 @@ use crate::metrics::BaselineMetrics; use crate::sorts::builder::BatchBuilder; -use crate::sorts::cursor::Cursor; -use crate::sorts::stream::{FieldCursorStream, PartitionedStream, RowCursorStream}; -use crate::{PhysicalSortExpr, RecordBatchStream, SendableRecordBatchStream}; -use arrow::datatypes::{DataType, SchemaRef}; +use crate::sorts::cursor::{Cursor, CursorValues}; +use crate::sorts::stream::PartitionedStream; +use crate::RecordBatchStream; +use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; -use arrow_array::*; use datafusion_common::Result; use datafusion_execution::memory_pool::MemoryReservation; use futures::Stream; use std::pin::Pin; use std::task::{ready, Context, Poll}; -macro_rules! primitive_merge_helper { - ($t:ty, $($v:ident),+) => { - merge_helper!(PrimitiveArray<$t>, $($v),+) - }; -} - -macro_rules! merge_helper { - ($t:ty, $sort:ident, $streams:ident, $schema:ident, $tracking_metrics:ident, $batch_size:ident, $fetch:ident, $reservation:ident) => {{ - let streams = FieldCursorStream::<$t>::new($sort, $streams); - return Ok(Box::pin(SortPreservingMergeStream::new( - Box::new(streams), - $schema, - $tracking_metrics, - $batch_size, - $fetch, - $reservation, - ))); - }}; -} - -/// Perform a streaming merge of [`SendableRecordBatchStream`] based on provided sort expressions -/// while preserving order. -pub fn streaming_merge( - streams: Vec, - schema: SchemaRef, - expressions: &[PhysicalSortExpr], - metrics: BaselineMetrics, - batch_size: usize, - fetch: Option, - reservation: MemoryReservation, -) -> Result { - // Special case single column comparisons with optimized cursor implementations - if expressions.len() == 1 { - let sort = expressions[0].clone(); - let data_type = sort.expr.data_type(schema.as_ref())?; - downcast_primitive! { - data_type => (primitive_merge_helper, sort, streams, schema, metrics, batch_size, fetch, reservation), - DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, metrics, batch_size, fetch, reservation) - DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, metrics, batch_size, fetch, reservation) - DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation) - DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation) - _ => {} - } - } - - let streams = RowCursorStream::try_new( - schema.as_ref(), - expressions, - streams, - reservation.new_empty(), - )?; - - Ok(Box::pin(SortPreservingMergeStream::new( - Box::new(streams), - schema, - metrics, - batch_size, - fetch, - reservation, - ))) -} - /// A fallible [`PartitionedStream`] of [`Cursor`] and [`RecordBatch`] type CursorStream = Box>>; #[derive(Debug)] -struct SortPreservingMergeStream { +pub(crate) struct SortPreservingMergeStream { in_progress: BatchBuilder, /// The sorted input streams to merge together @@ -151,8 +88,8 @@ struct SortPreservingMergeStream { /// target batch size batch_size: usize, - /// Vector that holds cursors for each non-exhausted input partition - cursors: Vec>, + /// Cursors for each input partition. `None` means the input is exhausted + cursors: Vec>>, /// Optional number of rows to fetch fetch: Option, @@ -161,8 +98,8 @@ struct SortPreservingMergeStream { produced: usize, } -impl SortPreservingMergeStream { - fn new( +impl SortPreservingMergeStream { + pub(crate) fn new( streams: CursorStream, schema: SchemaRef, metrics: BaselineMetrics, @@ -203,7 +140,7 @@ impl SortPreservingMergeStream { None => Poll::Ready(Ok(())), Some(Err(e)) => Poll::Ready(Err(e)), Some(Ok((cursor, batch))) => { - self.cursors[idx] = Some(cursor); + self.cursors[idx] = Some(Cursor::new(cursor)); Poll::Ready(self.in_progress.push_batch(idx, batch)) } } @@ -373,7 +310,7 @@ impl SortPreservingMergeStream { } } -impl Stream for SortPreservingMergeStream { +impl Stream for SortPreservingMergeStream { type Item = Result; fn poll_next( @@ -385,7 +322,7 @@ impl Stream for SortPreservingMergeStream { } } -impl RecordBatchStream for SortPreservingMergeStream { +impl RecordBatchStream for SortPreservingMergeStream { fn schema(&self) -> SchemaRef { self.in_progress.schema().clone() } diff --git a/datafusion/physical-plan/src/sorts/mod.rs b/datafusion/physical-plan/src/sorts/mod.rs index dff39db423f0..8a1184d3c2b5 100644 --- a/datafusion/physical-plan/src/sorts/mod.rs +++ b/datafusion/physical-plan/src/sorts/mod.rs @@ -20,10 +20,11 @@ mod builder; mod cursor; mod index; -pub mod merge; +mod merge; pub mod sort; pub mod sort_preserving_merge; mod stream; +pub mod streaming_merge; pub use index::RowIndex; -pub(crate) use merge::streaming_merge; +pub(crate) use streaming_merge::streaming_merge; diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 7d260d42d9cd..2d8237011fff 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -19,18 +19,27 @@ //! It will do in-memory sorting if it has enough memory budget //! but spills to disk if needed. +use std::any::Any; +use std::fmt; +use std::fmt::{Debug, Formatter}; +use std::fs::File; +use std::io::BufReader; +use std::path::{Path, PathBuf}; +use std::sync::Arc; + use crate::common::{spawn_buffered, IPCWriter}; use crate::expressions::PhysicalSortExpr; use crate::metrics::{ BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet, }; -use crate::sorts::merge::streaming_merge; +use crate::sorts::streaming_merge::streaming_merge; use crate::stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter}; +use crate::topk::TopK; use crate::{ DisplayAs, DisplayFormatType, Distribution, EmptyRecordBatchStream, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, }; -pub use arrow::compute::SortOptions; + use arrow::compute::{concat_batches, lexsort_to_indices, take}; use arrow::datatypes::SchemaRef; use arrow::ipc::reader::FileReader; @@ -43,15 +52,9 @@ use datafusion_execution::memory_pool::{ use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; use datafusion_physical_expr::EquivalenceProperties; + use futures::{StreamExt, TryStreamExt}; use log::{debug, error, trace}; -use std::any::Any; -use std::fmt; -use std::fmt::{Debug, Formatter}; -use std::fs::File; -use std::io::BufReader; -use std::path::{Path, PathBuf}; -use std::sync::Arc; use tokio::sync::mpsc::Sender; use tokio::task; @@ -732,7 +735,13 @@ impl SortExec { self } - /// Whether this `SortExec` preserves partitioning of the children + /// Modify how many rows to include in the result + /// + /// If None, then all rows will be returned, in sorted order. + /// If Some, then only the top `fetch` rows will be returned. + /// This can reduce the memory pressure required by the sort + /// operation since rows that are not going to be included + /// can be dropped. pub fn with_fetch(mut self, fetch: Option) -> Self { self.fetch = fetch; self @@ -762,12 +771,12 @@ impl DisplayAs for SortExec { ) -> std::fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - let expr: Vec = self.expr.iter().map(|e| e.to_string()).collect(); + let expr = PhysicalSortExpr::format_list(&self.expr); match self.fetch { Some(fetch) => { - write!(f, "SortExec: fetch={fetch}, expr=[{}]", expr.join(",")) + write!(f, "SortExec: TopK(fetch={fetch}), expr=[{expr}]",) } - None => write!(f, "SortExec: expr=[{}]", expr.join(",")), + None => write!(f, "SortExec: expr=[{expr}]"), } } } @@ -826,7 +835,10 @@ impl ExecutionPlan for SortExec { } fn equivalence_properties(&self) -> EquivalenceProperties { - self.input.equivalence_properties() + // Reset the ordering equivalence class with the new ordering: + self.input + .equivalence_properties() + .with_reorder(self.expr.to_vec()) } fn with_new_children( @@ -853,42 +865,69 @@ impl ExecutionPlan for SortExec { trace!("End SortExec's input.execute for partition: {}", partition); - let mut sorter = ExternalSorter::new( - partition, - input.schema(), - self.expr.clone(), - context.session_config().batch_size(), - self.fetch, - execution_options.sort_spill_reservation_bytes, - execution_options.sort_in_place_threshold_bytes, - &self.metrics_set, - context.runtime_env(), - ); + if let Some(fetch) = self.fetch.as_ref() { + let mut topk = TopK::try_new( + partition, + input.schema(), + self.expr.clone(), + *fetch, + context.session_config().batch_size(), + context.runtime_env(), + &self.metrics_set, + partition, + )?; + + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + futures::stream::once(async move { + while let Some(batch) = input.next().await { + let batch = batch?; + topk.insert_batch(batch)?; + } + topk.emit() + }) + .try_flatten(), + ))) + } else { + let mut sorter = ExternalSorter::new( + partition, + input.schema(), + self.expr.clone(), + context.session_config().batch_size(), + self.fetch, + execution_options.sort_spill_reservation_bytes, + execution_options.sort_in_place_threshold_bytes, + &self.metrics_set, + context.runtime_env(), + ); - Ok(Box::pin(RecordBatchStreamAdapter::new( - self.schema(), - futures::stream::once(async move { - while let Some(batch) = input.next().await { - let batch = batch?; - sorter.insert_batch(batch).await?; - } - sorter.sort() - }) - .try_flatten(), - ))) + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + futures::stream::once(async move { + while let Some(batch) = input.next().await { + let batch = batch?; + sorter.insert_batch(batch).await?; + } + sorter.sort() + }) + .try_flatten(), + ))) + } } fn metrics(&self) -> Option { Some(self.metrics_set.clone_inner()) } - fn statistics(&self) -> Statistics { + fn statistics(&self) -> Result { self.input.statistics() } } #[cfg(test)] mod tests { + use std::collections::HashMap; + use super::*; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::collect; @@ -897,14 +936,15 @@ mod tests { use crate::test; use crate::test::assert_is_pending; use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; + use arrow::array::*; use arrow::compute::SortOptions; use arrow::datatypes::*; use datafusion_common::cast::as_primitive_array; use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeConfig; + use futures::FutureExt; - use std::collections::HashMap; #[tokio::test] async fn test_in_mem_sort() -> Result<()> { @@ -1043,7 +1083,7 @@ mod tests { assert_eq!(result.len(), 1); let metrics = sort_exec.metrics().unwrap(); - let did_it_spill = metrics.spill_count().unwrap() > 0; + let did_it_spill = metrics.spill_count().unwrap_or(0) > 0; assert_eq!(did_it_spill, expect_spillage, "with fetch: {fetch:?}"); } Ok(()) diff --git a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs index 5b485e0b68e4..f4b57e8bfb45 100644 --- a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs +++ b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs @@ -28,14 +28,12 @@ use crate::{ DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, }; -use datafusion_execution::memory_pool::MemoryConsumer; use arrow::datatypes::SchemaRef; use datafusion_common::{internal_err, DataFusionError, Result}; +use datafusion_execution::memory_pool::MemoryConsumer; use datafusion_execution::TaskContext; -use datafusion_physical_expr::{ - EquivalenceProperties, OrderingEquivalenceProperties, PhysicalSortRequirement, -}; +use datafusion_physical_expr::{EquivalenceProperties, PhysicalSortRequirement}; use log::{debug, trace}; @@ -118,8 +116,11 @@ impl DisplayAs for SortPreservingMergeExec { ) -> std::fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - let expr: Vec = self.expr.iter().map(|e| e.to_string()).collect(); - write!(f, "SortPreservingMergeExec: [{}]", expr.join(","))?; + write!( + f, + "SortPreservingMergeExec: [{}]", + PhysicalSortExpr::format_list(&self.expr) + )?; if let Some(fetch) = self.fetch { write!(f, ", fetch={fetch}")?; }; @@ -176,10 +177,6 @@ impl ExecutionPlan for SortPreservingMergeExec { self.input.equivalence_properties() } - fn ordering_equivalence_properties(&self) -> OrderingEquivalenceProperties { - self.input.ordering_equivalence_properties() - } - fn children(&self) -> Vec> { vec![self.input.clone()] } @@ -261,7 +258,7 @@ impl ExecutionPlan for SortPreservingMergeExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Statistics { + fn statistics(&self) -> Result { self.input.statistics() } } @@ -270,13 +267,7 @@ impl ExecutionPlan for SortPreservingMergeExec { mod tests { use std::iter::FromIterator; - use arrow::array::ArrayRef; - use arrow::compute::SortOptions; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow::record_batch::RecordBatch; - use datafusion_execution::config::SessionConfig; - use futures::{FutureExt, StreamExt}; - + use super::*; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::expressions::col; use crate::memory::MemoryExec; @@ -286,10 +277,15 @@ mod tests { use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; use crate::test::{self, assert_is_pending, make_partition}; use crate::{collect, common}; - use arrow::array::{Int32Array, StringArray, TimestampNanosecondArray}; - use datafusion_common::assert_batches_eq; - use super::*; + use arrow::array::{ArrayRef, Int32Array, StringArray, TimestampNanosecondArray}; + use arrow::compute::SortOptions; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use datafusion_common::{assert_batches_eq, assert_contains}; + use datafusion_execution::config::SessionConfig; + + use futures::{FutureExt, StreamExt}; #[tokio::test] async fn test_merge_interleave() { @@ -339,6 +335,25 @@ mod tests { .await; } + #[tokio::test] + async fn test_merge_no_exprs() { + let task_ctx = Arc::new(TaskContext::default()); + let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); + let batch = RecordBatch::try_from_iter(vec![("a", a)]).unwrap(); + + let schema = batch.schema(); + let sort = vec![]; // no sort expressions + let exec = MemoryExec::try_new(&[vec![batch.clone()], vec![batch]], schema, None) + .unwrap(); + let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); + + let res = collect(merge, task_ctx).await.unwrap_err(); + assert_contains!( + res.to_string(), + "Internal error: Sort expressions cannot be empty for streaming merge" + ); + } + #[tokio::test] async fn test_merge_some_overlap() { let task_ctx = Arc::new(TaskContext::default()); diff --git a/datafusion/physical-plan/src/sorts/stream.rs b/datafusion/physical-plan/src/sorts/stream.rs index a7f9e7380c47..135b4fbdece4 100644 --- a/datafusion/physical-plan/src/sorts/stream.rs +++ b/datafusion/physical-plan/src/sorts/stream.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::sorts::cursor::{FieldArray, FieldCursor, RowCursor}; +use crate::sorts::cursor::{ArrayValues, CursorArray, RowValues}; use crate::SendableRecordBatchStream; use crate::{PhysicalExpr, PhysicalSortExpr}; use arrow::array::Array; @@ -76,7 +76,7 @@ impl FusedStreams { } /// A [`PartitionedStream`] that wraps a set of [`SendableRecordBatchStream`] -/// and computes [`RowCursor`] based on the provided [`PhysicalSortExpr`] +/// and computes [`RowValues`] based on the provided [`PhysicalSortExpr`] #[derive(Debug)] pub struct RowCursorStream { /// Converter to convert output of physical expressions @@ -114,11 +114,11 @@ impl RowCursorStream { }) } - fn convert_batch(&mut self, batch: &RecordBatch) -> Result { + fn convert_batch(&mut self, batch: &RecordBatch) -> Result { let cols = self .column_expressions .iter() - .map(|expr| Ok(expr.evaluate(batch)?.into_array(batch.num_rows()))) + .map(|expr| expr.evaluate(batch)?.into_array(batch.num_rows())) .collect::>>()?; let rows = self.converter.convert_columns(&cols)?; @@ -127,12 +127,12 @@ impl RowCursorStream { // track the memory in the newly created Rows. let mut rows_reservation = self.reservation.new_empty(); rows_reservation.try_grow(rows.size())?; - Ok(RowCursor::new(rows, rows_reservation)) + Ok(RowValues::new(rows, rows_reservation)) } } impl PartitionedStream for RowCursorStream { - type Output = Result<(RowCursor, RecordBatch)>; + type Output = Result<(RowValues, RecordBatch)>; fn partitions(&self) -> usize { self.streams.0.len() @@ -153,7 +153,7 @@ impl PartitionedStream for RowCursorStream { } /// Specialized stream for sorts on single primitive columns -pub struct FieldCursorStream { +pub struct FieldCursorStream { /// The physical expressions to sort by sort: PhysicalSortExpr, /// Input streams @@ -161,7 +161,7 @@ pub struct FieldCursorStream { phantom: PhantomData T>, } -impl std::fmt::Debug for FieldCursorStream { +impl std::fmt::Debug for FieldCursorStream { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("PrimitiveCursorStream") .field("num_streams", &self.streams) @@ -169,7 +169,7 @@ impl std::fmt::Debug for FieldCursorStream { } } -impl FieldCursorStream { +impl FieldCursorStream { pub fn new(sort: PhysicalSortExpr, streams: Vec) -> Self { let streams = streams.into_iter().map(|s| s.fuse()).collect(); Self { @@ -179,16 +179,16 @@ impl FieldCursorStream { } } - fn convert_batch(&mut self, batch: &RecordBatch) -> Result> { + fn convert_batch(&mut self, batch: &RecordBatch) -> Result> { let value = self.sort.expr.evaluate(batch)?; - let array = value.into_array(batch.num_rows()); + let array = value.into_array(batch.num_rows())?; let array = array.as_any().downcast_ref::().expect("field values"); - Ok(FieldCursor::new(self.sort.options, array)) + Ok(ArrayValues::new(self.sort.options, array)) } } -impl PartitionedStream for FieldCursorStream { - type Output = Result<(FieldCursor, RecordBatch)>; +impl PartitionedStream for FieldCursorStream { + type Output = Result<(ArrayValues, RecordBatch)>; fn partitions(&self) -> usize { self.streams.0.len() diff --git a/datafusion/physical-plan/src/sorts/streaming_merge.rs b/datafusion/physical-plan/src/sorts/streaming_merge.rs new file mode 100644 index 000000000000..4f8d8063853b --- /dev/null +++ b/datafusion/physical-plan/src/sorts/streaming_merge.rs @@ -0,0 +1,97 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Merge that deals with an arbitrary size of streaming inputs. +//! This is an order-preserving merge. + +use crate::metrics::BaselineMetrics; +use crate::sorts::{ + merge::SortPreservingMergeStream, + stream::{FieldCursorStream, RowCursorStream}, +}; +use crate::{PhysicalSortExpr, SendableRecordBatchStream}; +use arrow::datatypes::{DataType, SchemaRef}; +use arrow_array::*; +use datafusion_common::{internal_err, DataFusionError, Result}; +use datafusion_execution::memory_pool::MemoryReservation; + +macro_rules! primitive_merge_helper { + ($t:ty, $($v:ident),+) => { + merge_helper!(PrimitiveArray<$t>, $($v),+) + }; +} + +macro_rules! merge_helper { + ($t:ty, $sort:ident, $streams:ident, $schema:ident, $tracking_metrics:ident, $batch_size:ident, $fetch:ident, $reservation:ident) => {{ + let streams = FieldCursorStream::<$t>::new($sort, $streams); + return Ok(Box::pin(SortPreservingMergeStream::new( + Box::new(streams), + $schema, + $tracking_metrics, + $batch_size, + $fetch, + $reservation, + ))); + }}; +} + +/// Perform a streaming merge of [`SendableRecordBatchStream`] based on provided sort expressions +/// while preserving order. +pub fn streaming_merge( + streams: Vec, + schema: SchemaRef, + expressions: &[PhysicalSortExpr], + metrics: BaselineMetrics, + batch_size: usize, + fetch: Option, + reservation: MemoryReservation, +) -> Result { + // If there are no sort expressions, preserving the order + // doesn't mean anything (and result in infinite loops) + if expressions.is_empty() { + return internal_err!("Sort expressions cannot be empty for streaming merge"); + } + // Special case single column comparisons with optimized cursor implementations + if expressions.len() == 1 { + let sort = expressions[0].clone(); + let data_type = sort.expr.data_type(schema.as_ref())?; + downcast_primitive! { + data_type => (primitive_merge_helper, sort, streams, schema, metrics, batch_size, fetch, reservation), + DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, metrics, batch_size, fetch, reservation) + DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, metrics, batch_size, fetch, reservation) + DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation) + DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation) + _ => {} + } + } + + let streams = RowCursorStream::try_new( + schema.as_ref(), + expressions, + streams, + reservation.new_empty(), + )?; + + Ok(Box::pin(SortPreservingMergeStream::new( + Box::new(streams), + schema, + metrics, + batch_size, + fetch, + reservation, + ))) +} diff --git a/datafusion/physical-plan/src/stream.rs b/datafusion/physical-plan/src/stream.rs index a3fb856c326d..fdf32620ca50 100644 --- a/datafusion/physical-plan/src/stream.rs +++ b/datafusion/physical-plan/src/stream.rs @@ -38,6 +38,124 @@ use tokio::task::JoinSet; use super::metrics::BaselineMetrics; use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream}; +/// Creates a stream from a collection of producing tasks, routing panics to the stream. +/// +/// Note that this is similar to [`ReceiverStream` from tokio-stream], with the differences being: +/// +/// 1. Methods to bound and "detach" tasks (`spawn()` and `spawn_blocking()`). +/// +/// 2. Propagates panics, whereas the `tokio` version doesn't propagate panics to the receiver. +/// +/// 3. Automatically cancels any outstanding tasks when the receiver stream is dropped. +/// +/// [`ReceiverStream` from tokio-stream]: https://docs.rs/tokio-stream/latest/tokio_stream/wrappers/struct.ReceiverStream.html + +pub(crate) struct ReceiverStreamBuilder { + tx: Sender>, + rx: Receiver>, + join_set: JoinSet>, +} + +impl ReceiverStreamBuilder { + /// create new channels with the specified buffer size + pub fn new(capacity: usize) -> Self { + let (tx, rx) = tokio::sync::mpsc::channel(capacity); + + Self { + tx, + rx, + join_set: JoinSet::new(), + } + } + + /// Get a handle for sending data to the output + pub fn tx(&self) -> Sender> { + self.tx.clone() + } + + /// Spawn task that will be aborted if this builder (or the stream + /// built from it) are dropped + pub fn spawn(&mut self, task: F) + where + F: Future>, + F: Send + 'static, + { + self.join_set.spawn(task); + } + + /// Spawn a blocking task that will be aborted if this builder (or the stream + /// built from it) are dropped + /// + /// this is often used to spawn tasks that write to the sender + /// retrieved from `Self::tx` + pub fn spawn_blocking(&mut self, f: F) + where + F: FnOnce() -> Result<()>, + F: Send + 'static, + { + self.join_set.spawn_blocking(f); + } + + /// Create a stream of all data written to `tx` + pub fn build(self) -> BoxStream<'static, Result> { + let Self { + tx, + rx, + mut join_set, + } = self; + + // don't need tx + drop(tx); + + // future that checks the result of the join set, and propagates panic if seen + let check = async move { + while let Some(result) = join_set.join_next().await { + match result { + Ok(task_result) => { + match task_result { + // nothing to report + Ok(_) => continue, + // This means a blocking task error + Err(e) => { + return Some(exec_err!("Spawned Task error: {e}")); + } + } + } + // This means a tokio task error, likely a panic + Err(e) => { + if e.is_panic() { + // resume on the main thread + std::panic::resume_unwind(e.into_panic()); + } else { + // This should only occur if the task is + // cancelled, which would only occur if + // the JoinSet were aborted, which in turn + // would imply that the receiver has been + // dropped and this code is not running + return Some(internal_err!("Non Panic Task error: {e}")); + } + } + } + } + None + }; + + let check_stream = futures::stream::once(check) + // unwrap Option / only return the error + .filter_map(|item| async move { item }); + + // Convert the receiver into a stream + let rx_stream = futures::stream::unfold(rx, |mut rx| async move { + let next_item = rx.recv().await; + next_item.map(|next_item| (next_item, rx)) + }); + + // Merge the streams together so whichever is ready first + // produces the batch + futures::stream::select(rx_stream, check_stream).boxed() + } +} + /// Builder for [`RecordBatchReceiverStream`] that propagates errors /// and panic's correctly. /// @@ -47,28 +165,22 @@ use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream}; /// /// This also handles propagating panic`s and canceling the tasks. pub struct RecordBatchReceiverStreamBuilder { - tx: Sender>, - rx: Receiver>, schema: SchemaRef, - join_set: JoinSet>, + inner: ReceiverStreamBuilder, } impl RecordBatchReceiverStreamBuilder { /// create new channels with the specified buffer size pub fn new(schema: SchemaRef, capacity: usize) -> Self { - let (tx, rx) = tokio::sync::mpsc::channel(capacity); - Self { - tx, - rx, schema, - join_set: JoinSet::new(), + inner: ReceiverStreamBuilder::new(capacity), } } - /// Get a handle for sending [`RecordBatch`]es to the output + /// Get a handle for sending [`RecordBatch`] to the output pub fn tx(&self) -> Sender> { - self.tx.clone() + self.inner.tx() } /// Spawn task that will be aborted if this builder (or the stream @@ -81,7 +193,7 @@ impl RecordBatchReceiverStreamBuilder { F: Future>, F: Send + 'static, { - self.join_set.spawn(task); + self.inner.spawn(task) } /// Spawn a blocking task that will be aborted if this builder (or the stream @@ -94,7 +206,7 @@ impl RecordBatchReceiverStreamBuilder { F: FnOnce() -> Result<()>, F: Send + 'static, { - self.join_set.spawn_blocking(f); + self.inner.spawn_blocking(f) } /// runs the input_partition of the `input` ExecutionPlan on the @@ -110,7 +222,7 @@ impl RecordBatchReceiverStreamBuilder { ) { let output = self.tx(); - self.spawn(async move { + self.inner.spawn(async move { let mut stream = match input.execute(partition, context) { Err(e) => { // If send fails, the plan being torn down, there @@ -155,80 +267,17 @@ impl RecordBatchReceiverStreamBuilder { }); } - /// Create a stream of all `RecordBatch`es written to `tx` + /// Create a stream of all [`RecordBatch`] written to `tx` pub fn build(self) -> SendableRecordBatchStream { - let Self { - tx, - rx, - schema, - mut join_set, - } = self; - - // don't need tx - drop(tx); - - // future that checks the result of the join set, and propagates panic if seen - let check = async move { - while let Some(result) = join_set.join_next().await { - match result { - Ok(task_result) => { - match task_result { - // nothing to report - Ok(_) => continue, - // This means a blocking task error - Err(e) => { - return Some(exec_err!("Spawned Task error: {e}")); - } - } - } - // This means a tokio task error, likely a panic - Err(e) => { - if e.is_panic() { - // resume on the main thread - std::panic::resume_unwind(e.into_panic()); - } else { - // This should only occur if the task is - // cancelled, which would only occur if - // the JoinSet were aborted, which in turn - // would imply that the receiver has been - // dropped and this code is not running - return Some(internal_err!("Non Panic Task error: {e}")); - } - } - } - } - None - }; - - let check_stream = futures::stream::once(check) - // unwrap Option / only return the error - .filter_map(|item| async move { item }); - - // Convert the receiver into a stream - let rx_stream = futures::stream::unfold(rx, |mut rx| async move { - let next_item = rx.recv().await; - next_item.map(|next_item| (next_item, rx)) - }); - - // Merge the streams together so whichever is ready first - // produces the batch - let inner = futures::stream::select(rx_stream, check_stream).boxed(); - - Box::pin(RecordBatchReceiverStream { schema, inner }) + Box::pin(RecordBatchStreamAdapter::new( + self.schema, + self.inner.build(), + )) } } -/// A [`SendableRecordBatchStream`] that combines [`RecordBatch`]es from multiple inputs, -/// on new tokio Tasks, increasing the potential parallelism. -/// -/// This structure also handles propagating panics and cancelling the -/// underlying tasks correctly. -/// -/// Use [`Self::builder`] to construct one. -pub struct RecordBatchReceiverStream { - schema: SchemaRef, - inner: BoxStream<'static, Result>, -} +#[doc(hidden)] +pub struct RecordBatchReceiverStream {} impl RecordBatchReceiverStream { /// Create a builder with an internal buffer of capacity batches. @@ -240,23 +289,6 @@ impl RecordBatchReceiverStream { } } -impl Stream for RecordBatchReceiverStream { - type Item = Result; - - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - self.inner.poll_next_unpin(cx) - } -} - -impl RecordBatchStream for RecordBatchReceiverStream { - fn schema(&self) -> SchemaRef { - self.schema.clone() - } -} - pin_project! { /// Combines a [`Stream`] with a [`SchemaRef`] implementing /// [`RecordBatchStream`] for the combination diff --git a/datafusion/physical-plan/src/streaming.rs b/datafusion/physical-plan/src/streaming.rs index 00809b71e443..59819c6921fb 100644 --- a/datafusion/physical-plan/src/streaming.rs +++ b/datafusion/physical-plan/src/streaming.rs @@ -15,27 +15,31 @@ // specific language governing permissions and limitations // under the License. -//! Execution plan for streaming [`PartitionStream`] +//! Generic plans for deferred execution: [`StreamingTableExec`] and [`PartitionStream`] use std::any::Any; use std::sync::Arc; -use arrow::datatypes::SchemaRef; -use async_trait::async_trait; -use futures::stream::StreamExt; - -use datafusion_common::{internal_err, plan_err, DataFusionError, Result, Statistics}; -use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr}; -use log::debug; - +use super::{DisplayAs, DisplayFormatType}; use crate::display::{OutputOrderingDisplay, ProjectSchemaDisplay}; use crate::stream::RecordBatchStreamAdapter; use crate::{ExecutionPlan, Partitioning, SendableRecordBatchStream}; + +use arrow::datatypes::SchemaRef; +use arrow_schema::Schema; +use datafusion_common::{internal_err, plan_err, DataFusionError, Result}; use datafusion_execution::TaskContext; +use datafusion_physical_expr::{EquivalenceProperties, LexOrdering, PhysicalSortExpr}; -use super::{DisplayAs, DisplayFormatType}; +use async_trait::async_trait; +use futures::stream::StreamExt; +use log::debug; /// A partition that can be converted into a [`SendableRecordBatchStream`] +/// +/// Combined with [`StreamingTableExec`], you can use this trait to implement +/// [`ExecutionPlan`] for a custom source with less boiler plate than +/// implementing `ExecutionPlan` directly for many use cases. pub trait PartitionStream: Send + Sync { /// Returns the schema of this partition fn schema(&self) -> &SchemaRef; @@ -44,12 +48,15 @@ pub trait PartitionStream: Send + Sync { fn execute(&self, ctx: Arc) -> SendableRecordBatchStream; } -/// An [`ExecutionPlan`] for [`PartitionStream`] +/// An [`ExecutionPlan`] for one or more [`PartitionStream`]s. +/// +/// If your source can be represented as one or more [`PartitionStream`]s, you can +/// use this struct to implement [`ExecutionPlan`]. pub struct StreamingTableExec { partitions: Vec>, projection: Option>, projected_schema: SchemaRef, - projected_output_ordering: Option, + projected_output_ordering: Vec, infinite: bool, } @@ -59,14 +66,14 @@ impl StreamingTableExec { schema: SchemaRef, partitions: Vec>, projection: Option<&Vec>, - projected_output_ordering: Option, + projected_output_ordering: impl IntoIterator, infinite: bool, ) -> Result { for x in partitions.iter() { let partition_schema = x.schema(); - if !schema.contains(partition_schema) { + if !schema.eq(partition_schema) { debug!( - "target schema does not contain partition schema. \ + "Target schema does not match with partition schema. \ Target_schema: {schema:?}. Partiton Schema: {partition_schema:?}" ); return plan_err!("Mismatch between schema and batches"); @@ -82,10 +89,34 @@ impl StreamingTableExec { partitions, projected_schema, projection: projection.cloned().map(Into::into), - projected_output_ordering, + projected_output_ordering: projected_output_ordering.into_iter().collect(), infinite, }) } + + pub fn partitions(&self) -> &Vec> { + &self.partitions + } + + pub fn partition_schema(&self) -> &SchemaRef { + self.partitions[0].schema() + } + + pub fn projection(&self) -> &Option> { + &self.projection + } + + pub fn projected_schema(&self) -> &Schema { + &self.projected_schema + } + + pub fn projected_output_ordering(&self) -> impl IntoIterator { + self.projected_output_ordering.clone() + } + + pub fn is_infinite(&self) -> bool { + self.infinite + } } impl std::fmt::Debug for StreamingTableExec { @@ -119,7 +150,7 @@ impl DisplayAs for StreamingTableExec { } self.projected_output_ordering - .as_deref() + .first() .map_or(Ok(()), |ordering| { if !ordering.is_empty() { write!( @@ -154,7 +185,16 @@ impl ExecutionPlan for StreamingTableExec { } fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { - self.projected_output_ordering.as_deref() + self.projected_output_ordering + .first() + .map(|ordering| ordering.as_slice()) + } + + fn equivalence_properties(&self) -> EquivalenceProperties { + EquivalenceProperties::new_with_orderings( + self.schema(), + &self.projected_output_ordering, + ) } fn children(&self) -> Vec> { @@ -163,9 +203,13 @@ impl ExecutionPlan for StreamingTableExec { fn with_new_children( self: Arc, - _children: Vec>, + children: Vec>, ) -> Result> { - internal_err!("Children cannot be replaced in {self:?}") + if children.is_empty() { + Ok(self) + } else { + internal_err!("Children cannot be replaced in {self:?}") + } } fn execute( @@ -184,8 +228,4 @@ impl ExecutionPlan for StreamingTableExec { None => stream, }) } - - fn statistics(&self) -> Statistics { - Default::default() - } } diff --git a/datafusion/physical-plan/src/test/exec.rs b/datafusion/physical-plan/src/test/exec.rs index a1f40c7ba909..1f6ee1f117aa 100644 --- a/datafusion/physical-plan/src/test/exec.rs +++ b/datafusion/physical-plan/src/test/exec.rs @@ -23,23 +23,21 @@ use std::{ sync::{Arc, Weak}, task::{Context, Poll}, }; -use tokio::sync::Barrier; - -use arrow::{ - datatypes::{DataType, Field, Schema, SchemaRef}, - record_batch::RecordBatch, -}; -use futures::Stream; +use crate::stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter}; use crate::{ - common, stream::RecordBatchReceiverStream, stream::RecordBatchStreamAdapter, - DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, + common, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, Statistics, }; -use datafusion_physical_expr::PhysicalSortExpr; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::record_batch::RecordBatch; use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_execution::TaskContext; +use datafusion_physical_expr::PhysicalSortExpr; + +use futures::Stream; +use tokio::sync::Barrier; /// Index into the data that has been returned so far #[derive(Debug, Default, Clone)] @@ -239,7 +237,7 @@ impl ExecutionPlan for MockExec { } // Panics if one of the batches is an error - fn statistics(&self) -> Statistics { + fn statistics(&self) -> Result { let data: Result> = self .data .iter() @@ -249,9 +247,13 @@ impl ExecutionPlan for MockExec { }) .collect(); - let data = data.unwrap(); + let data = data?; - common::compute_record_batch_statistics(&[data], &self.schema, None) + Ok(common::compute_record_batch_statistics( + &[data], + &self.schema, + None, + )) } } @@ -369,8 +371,12 @@ impl ExecutionPlan for BarrierExec { Ok(builder.build()) } - fn statistics(&self) -> Statistics { - common::compute_record_batch_statistics(&self.data, &self.schema, None) + fn statistics(&self) -> Result { + Ok(common::compute_record_batch_statistics( + &self.data, + &self.schema, + None, + )) } } @@ -447,10 +453,6 @@ impl ExecutionPlan for ErrorExec { ) -> Result { internal_err!("ErrorExec, unsurprisingly, errored in partition {partition}") } - - fn statistics(&self) -> Statistics { - Statistics::default() - } } /// A mock execution plan that simply returns the provided statistics @@ -461,12 +463,9 @@ pub struct StatisticsExec { } impl StatisticsExec { pub fn new(stats: Statistics, schema: Schema) -> Self { - assert!( + assert_eq!( stats - .column_statistics - .as_ref() - .map(|cols| cols.len() == schema.fields().len()) - .unwrap_or(true), + .column_statistics.len(), schema.fields().len(), "if defined, the column statistics vector length should be the number of fields" ); Self { @@ -531,8 +530,8 @@ impl ExecutionPlan for StatisticsExec { unimplemented!("This plan only serves for testing statistics") } - fn statistics(&self) -> Statistics { - self.stats.clone() + fn statistics(&self) -> Result { + Ok(self.stats.clone()) } } @@ -624,10 +623,6 @@ impl ExecutionPlan for BlockingExec { _refs: Arc::clone(&self.refs), })) } - - fn statistics(&self) -> Statistics { - unimplemented!() - } } /// A [`RecordBatchStream`] that is pending forever. @@ -761,10 +756,6 @@ impl ExecutionPlan for PanicExec { ready: false, })) } - - fn statistics(&self) -> Statistics { - unimplemented!() - } } /// A [`RecordBatchStream`] that yields every other batch and panics @@ -799,7 +790,7 @@ impl Stream for PanicStream { } else { self.ready = true; // get called again - cx.waker().clone().wake(); + cx.waker().wake_by_ref(); return Poll::Pending; } } diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs new file mode 100644 index 000000000000..9120566273d3 --- /dev/null +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -0,0 +1,644 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! TopK: Combination of Sort / LIMIT + +use arrow::{ + compute::interleave, + row::{RowConverter, Rows, SortField}, +}; +use std::{cmp::Ordering, collections::BinaryHeap, sync::Arc}; + +use arrow_array::{Array, ArrayRef, RecordBatch}; +use arrow_schema::SchemaRef; +use datafusion_common::Result; +use datafusion_execution::{ + memory_pool::{MemoryConsumer, MemoryReservation}, + runtime_env::RuntimeEnv, +}; +use datafusion_physical_expr::PhysicalSortExpr; +use hashbrown::HashMap; + +use crate::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream}; + +use super::metrics::{BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder}; + +/// Global TopK +/// +/// # Background +/// +/// "Top K" is a common query optimization used for queries such as +/// "find the top 3 customers by revenue". The (simplified) SQL for +/// such a query might be: +/// +/// ```sql +/// SELECT customer_id, revenue FROM 'sales.csv' ORDER BY revenue DESC limit 3; +/// ``` +/// +/// The simple plan would be: +/// +/// ```sql +/// > explain SELECT customer_id, revenue FROM sales ORDER BY revenue DESC limit 3; +/// +--------------+----------------------------------------+ +/// | plan_type | plan | +/// +--------------+----------------------------------------+ +/// | logical_plan | Limit: 3 | +/// | | Sort: revenue DESC NULLS FIRST | +/// | | Projection: customer_id, revenue | +/// | | TableScan: sales | +/// +--------------+----------------------------------------+ +/// ``` +/// +/// While this plan produces the correct answer, it will fully sorts the +/// input before discarding everything other than the top 3 elements. +/// +/// The same answer can be produced by simply keeping track of the top +/// K=3 elements, reducing the total amount of required buffer memory. +/// +/// # Structure +/// +/// This operator tracks the top K items using a `TopKHeap`. +pub struct TopK { + /// schema of the output (and the input) + schema: SchemaRef, + /// Runtime metrics + metrics: TopKMetrics, + /// Reservation + reservation: MemoryReservation, + /// The target number of rows for output batches + batch_size: usize, + /// sort expressions + expr: Arc<[PhysicalSortExpr]>, + /// row converter, for sort keys + row_converter: RowConverter, + /// scratch space for converting rows + scratch_rows: Rows, + /// stores the top k values and their sort key values, in order + heap: TopKHeap, +} + +impl TopK { + /// Create a new [`TopK`] that stores the top `k` values, as + /// defined by the sort expressions in `expr`. + // TOOD: make a builder or some other nicer API to avoid the + // clippy warning + #[allow(clippy::too_many_arguments)] + pub fn try_new( + partition_id: usize, + schema: SchemaRef, + expr: Vec, + k: usize, + batch_size: usize, + runtime: Arc, + metrics: &ExecutionPlanMetricsSet, + partition: usize, + ) -> Result { + let reservation = MemoryConsumer::new(format!("TopK[{partition_id}]")) + .register(&runtime.memory_pool); + + let expr: Arc<[PhysicalSortExpr]> = expr.into(); + + let sort_fields: Vec<_> = expr + .iter() + .map(|e| { + Ok(SortField::new_with_options( + e.expr.data_type(&schema)?, + e.options, + )) + }) + .collect::>()?; + + // TODO there is potential to add special cases for single column sort fields + // to improve performance + let row_converter = RowConverter::new(sort_fields)?; + let scratch_rows = row_converter.empty_rows( + batch_size, + 20 * batch_size, // guestimate 20 bytes per row + ); + + Ok(Self { + schema: schema.clone(), + metrics: TopKMetrics::new(metrics, partition), + reservation, + batch_size, + expr, + row_converter, + scratch_rows, + heap: TopKHeap::new(k, batch_size, schema), + }) + } + + /// Insert `batch`, remembering if any of its values are among + /// the top k seen so far. + pub fn insert_batch(&mut self, batch: RecordBatch) -> Result<()> { + // Updates on drop + let _timer = self.metrics.baseline.elapsed_compute().timer(); + + let sort_keys: Vec = self + .expr + .iter() + .map(|expr| { + let value = expr.expr.evaluate(&batch)?; + value.into_array(batch.num_rows()) + }) + .collect::>>()?; + + // reuse existing `Rows` to avoid reallocations + let rows = &mut self.scratch_rows; + rows.clear(); + self.row_converter.append(rows, &sort_keys)?; + + // TODO make this algorithmically better?: + // Idea: filter out rows >= self.heap.max() early (before passing to `RowConverter`) + // this avoids some work and also might be better vectorizable. + let mut batch_entry = self.heap.register_batch(batch); + for (index, row) in rows.iter().enumerate() { + match self.heap.max() { + // heap has k items, and the new row is greater than the + // current max in the heap ==> it is not a new topk + Some(max_row) if row.as_ref() >= max_row.row() => {} + // don't yet have k items or new item is lower than the currently k low values + None | Some(_) => { + self.heap.add(&mut batch_entry, row, index); + self.metrics.row_replacements.add(1); + } + } + } + self.heap.insert_batch_entry(batch_entry); + + // conserve memory + self.heap.maybe_compact()?; + + // update memory reservation + self.reservation.try_resize(self.size())?; + Ok(()) + } + + /// Returns the top k results broken into `batch_size` [`RecordBatch`]es, consuming the heap + pub fn emit(self) -> Result { + let Self { + schema, + metrics, + reservation: _, + batch_size, + expr: _, + row_converter: _, + scratch_rows: _, + mut heap, + } = self; + let _timer = metrics.baseline.elapsed_compute().timer(); // time updated on drop + + let mut batch = heap.emit()?; + metrics.baseline.output_rows().add(batch.num_rows()); + + // break into record batches as needed + let mut batches = vec![]; + loop { + if batch.num_rows() < batch_size { + batches.push(Ok(batch)); + break; + } else { + batches.push(Ok(batch.slice(0, batch_size))); + let remaining_length = batch.num_rows() - batch_size; + batch = batch.slice(batch_size, remaining_length); + } + } + Ok(Box::pin(RecordBatchStreamAdapter::new( + schema, + futures::stream::iter(batches), + ))) + } + + /// return the size of memory used by this operator, in bytes + fn size(&self) -> usize { + std::mem::size_of::() + + self.row_converter.size() + + self.scratch_rows.size() + + self.heap.size() + } +} + +struct TopKMetrics { + /// metrics + pub baseline: BaselineMetrics, + + /// count of how many rows were replaced in the heap + pub row_replacements: Count, +} + +impl TopKMetrics { + fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self { + Self { + baseline: BaselineMetrics::new(metrics, partition), + row_replacements: MetricBuilder::new(metrics) + .counter("row_replacements", partition), + } + } +} + +/// This structure keeps at most the *smallest* k items, using the +/// [arrow::row] format for sort keys. While it is called "topK" for +/// values like `1, 2, 3, 4, 5` the "top 3" really means the +/// *smallest* 3 , `1, 2, 3`, not the *largest* 3 `3, 4, 5`. +/// +/// Using the `Row` format handles things such as ascending vs +/// descending and nulls first vs nulls last. +struct TopKHeap { + /// The maximum number of elemenents to store in this heap. + k: usize, + /// The target number of rows for output batches + batch_size: usize, + /// Storage for up at most `k` items using a BinaryHeap. Reverserd + /// so that the smallest k so far is on the top + inner: BinaryHeap, + /// Storage the original row values (TopKRow only has the sort key) + store: RecordBatchStore, + /// The size of all owned data held by this heap + owned_bytes: usize, +} + +impl TopKHeap { + fn new(k: usize, batch_size: usize, schema: SchemaRef) -> Self { + assert!(k > 0); + Self { + k, + batch_size, + inner: BinaryHeap::new(), + store: RecordBatchStore::new(schema), + owned_bytes: 0, + } + } + + /// Register a [`RecordBatch`] with the heap, returning the + /// appropriate entry + pub fn register_batch(&mut self, batch: RecordBatch) -> RecordBatchEntry { + self.store.register(batch) + } + + /// Insert a [`RecordBatchEntry`] created by a previous call to + /// [`Self::register_batch`] into storage. + pub fn insert_batch_entry(&mut self, entry: RecordBatchEntry) { + self.store.insert(entry) + } + + /// Returns the largest value stored by the heap if there are k + /// items, otherwise returns None. Remember this structure is + /// keeping the "smallest" k values + fn max(&self) -> Option<&TopKRow> { + if self.inner.len() < self.k { + None + } else { + self.inner.peek() + } + } + + /// Adds `row` to this heap. If inserting this new item would + /// increase the size past `k`, removes the previously smallest + /// item. + fn add( + &mut self, + batch_entry: &mut RecordBatchEntry, + row: impl AsRef<[u8]>, + index: usize, + ) { + let batch_id = batch_entry.id; + batch_entry.uses += 1; + + assert!(self.inner.len() <= self.k); + let row = row.as_ref(); + + // Reuse storage for evicted item if possible + let new_top_k = if self.inner.len() == self.k { + let prev_min = self.inner.pop().unwrap(); + + // Update batch use + if prev_min.batch_id == batch_entry.id { + batch_entry.uses -= 1; + } else { + self.store.unuse(prev_min.batch_id); + } + + // update memory accounting + self.owned_bytes -= prev_min.owned_size(); + prev_min.with_new_row(row, batch_id, index) + } else { + TopKRow::new(row, batch_id, index) + }; + + self.owned_bytes += new_top_k.owned_size(); + + // put the new row into the heap + self.inner.push(new_top_k) + } + + /// Returns the values stored in this heap, from values low to + /// high, as a single [`RecordBatch`], resetting the inner heap + pub fn emit(&mut self) -> Result { + Ok(self.emit_with_state()?.0) + } + + /// Returns the values stored in this heap, from values low to + /// high, as a single [`RecordBatch`], and a sorted vec of the + /// current heap's contents + pub fn emit_with_state(&mut self) -> Result<(RecordBatch, Vec)> { + let schema = self.store.schema().clone(); + + // generate sorted rows + let topk_rows = std::mem::take(&mut self.inner).into_sorted_vec(); + + if self.store.is_empty() { + return Ok((RecordBatch::new_empty(schema), topk_rows)); + } + + // Indices for each row within its respective RecordBatch + let indices: Vec<_> = topk_rows + .iter() + .enumerate() + .map(|(i, k)| (i, k.index)) + .collect(); + + let num_columns = schema.fields().len(); + + // build the output columns one at time, using the + // `interleave` kernel to pick rows from different arrays + let output_columns: Vec<_> = (0..num_columns) + .map(|col| { + let input_arrays: Vec<_> = topk_rows + .iter() + .map(|k| { + let entry = + self.store.get(k.batch_id).expect("invalid stored batch id"); + entry.batch.column(col) as &dyn Array + }) + .collect(); + + // at this point `indices` contains indexes within the + // rows and `input_arrays` contains a reference to the + // relevant Array for that index. `interleave` pulls + // them together into a single new array + Ok(interleave(&input_arrays, &indices)?) + }) + .collect::>()?; + + let new_batch = RecordBatch::try_new(schema, output_columns)?; + Ok((new_batch, topk_rows)) + } + + /// Compact this heap, rewriting all stored batches into a single + /// input batch + pub fn maybe_compact(&mut self) -> Result<()> { + // we compact if the number of "unused" rows in the store is + // past some pre-defined threshold. Target holding up to + // around 20 batches, but handle cases of large k where some + // batches might be partially full + let max_unused_rows = (20 * self.batch_size) + self.k; + let unused_rows = self.store.unused_rows(); + + // don't compact if the store has one extra batch or + // unused rows is under the threshold + if self.store.len() <= 2 || unused_rows < max_unused_rows { + return Ok(()); + } + // at first, compact the entire thing always into a new batch + // (maybe we can get fancier in the future about ignoring + // batches that have a high usage ratio already + + // Note: new batch is in the same order as inner + let num_rows = self.inner.len(); + let (new_batch, mut topk_rows) = self.emit_with_state()?; + + // clear all old entires in store (this invalidates all + // store_ids in `inner`) + self.store.clear(); + + let mut batch_entry = self.register_batch(new_batch); + batch_entry.uses = num_rows; + + // rewrite all existing entries to use the new batch, and + // remove old entries. The sortedness and their relative + // position do not change + for (i, topk_row) in topk_rows.iter_mut().enumerate() { + topk_row.batch_id = batch_entry.id; + topk_row.index = i; + } + self.insert_batch_entry(batch_entry); + // restore the heap + self.inner = BinaryHeap::from(topk_rows); + + Ok(()) + } + + /// return the size of memory used by this heap, in bytes + fn size(&self) -> usize { + std::mem::size_of::() + + (self.inner.capacity() * std::mem::size_of::()) + + self.store.size() + + self.owned_bytes + } +} + +/// Represents one of the top K rows held in this heap. Orders +/// according to memcmp of row (e.g. the arrow Row format, but could +/// also be primtive values) +/// +/// Reuses allocations to minimize runtime overhead of creating new Vecs +#[derive(Debug, PartialEq)] +struct TopKRow { + /// the value of the sort key for this row. This contains the + /// bytes that could be stored in `OwnedRow` but uses `Vec` to + /// reuse allocations. + row: Vec, + /// the RecordBatch this row came from: an id into a [`RecordBatchStore`] + batch_id: u32, + /// the index in this record batch the row came from + index: usize, +} + +impl TopKRow { + /// Create a new TopKRow with new allocation + fn new(row: impl AsRef<[u8]>, batch_id: u32, index: usize) -> Self { + Self { + row: row.as_ref().to_vec(), + batch_id, + index, + } + } + + /// Create a new TopKRow reusing the existing allocation + fn with_new_row( + self, + new_row: impl AsRef<[u8]>, + batch_id: u32, + index: usize, + ) -> Self { + let Self { + mut row, + batch_id: _, + index: _, + } = self; + row.clear(); + row.extend_from_slice(new_row.as_ref()); + + Self { + row, + batch_id, + index, + } + } + + /// Returns the number of bytes owned by this row in the heap (not + /// including itself) + fn owned_size(&self) -> usize { + self.row.capacity() + } + + /// Returns a slice to the owned row value + fn row(&self) -> &[u8] { + self.row.as_slice() + } +} + +impl Eq for TopKRow {} + +impl PartialOrd for TopKRow { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for TopKRow { + fn cmp(&self, other: &Self) -> Ordering { + self.row.cmp(&other.row) + } +} + +#[derive(Debug)] +struct RecordBatchEntry { + id: u32, + batch: RecordBatch, + // for this batch, how many times has it been used + uses: usize, +} + +/// This structure tracks [`RecordBatch`] by an id so that: +/// +/// 1. The baches can be tracked via an id that can be copied cheaply +/// 2. The total memory held by all batches is tracked +#[derive(Debug)] +struct RecordBatchStore { + /// id generator + next_id: u32, + /// storage + batches: HashMap, + /// total size of all record batches tracked by this store + batches_size: usize, + /// schema of the batches + schema: SchemaRef, +} + +impl RecordBatchStore { + fn new(schema: SchemaRef) -> Self { + Self { + next_id: 0, + batches: HashMap::new(), + batches_size: 0, + schema, + } + } + + /// Register this batch with the store and assign an ID. No + /// attempt is made to compare this batch to other batches + pub fn register(&mut self, batch: RecordBatch) -> RecordBatchEntry { + let id = self.next_id; + self.next_id += 1; + RecordBatchEntry { id, batch, uses: 0 } + } + + /// Insert a record batch entry into this store, tracking its + /// memory use, if it has any uses + pub fn insert(&mut self, entry: RecordBatchEntry) { + // uses of 0 means that none of the rows in the batch were stored in the topk + if entry.uses > 0 { + self.batches_size += entry.batch.get_array_memory_size(); + self.batches.insert(entry.id, entry); + } + } + + /// Clear all values in this store, invalidating all previous batch ids + fn clear(&mut self) { + self.batches.clear(); + self.batches_size = 0; + } + + fn get(&self, id: u32) -> Option<&RecordBatchEntry> { + self.batches.get(&id) + } + + /// returns the total number of batches stored in this store + fn len(&self) -> usize { + self.batches.len() + } + + /// Returns the total number of rows in batches minus the number + /// which are in use + fn unused_rows(&self) -> usize { + self.batches + .values() + .map(|batch_entry| batch_entry.batch.num_rows() - batch_entry.uses) + .sum() + } + + /// returns true if the store has nothing stored + fn is_empty(&self) -> bool { + self.batches.is_empty() + } + + /// return the schema of batches stored + fn schema(&self) -> &SchemaRef { + &self.schema + } + + /// remove a use from the specified batch id. If the use count + /// reaches zero the batch entry is removed from the store + /// + /// panics if there were no remaining uses of id + pub fn unuse(&mut self, id: u32) { + let remove = if let Some(batch_entry) = self.batches.get_mut(&id) { + batch_entry.uses = batch_entry.uses.checked_sub(1).expect("underflow"); + batch_entry.uses == 0 + } else { + panic!("No entry for id {id}"); + }; + + if remove { + let old_entry = self.batches.remove(&id).unwrap(); + self.batches_size = self + .batches_size + .checked_sub(old_entry.batch.get_array_memory_size()) + .unwrap(); + } + } + + /// returns the size of memory used by this store, including all + /// referenced `RecordBatch`es, in bytes + pub fn size(&self) -> usize { + std::mem::size_of::() + + self.batches.capacity() + * (std::mem::size_of::() + std::mem::size_of::()) + + self.batches_size + } +} diff --git a/datafusion/physical-plan/src/udaf.rs b/datafusion/physical-plan/src/udaf.rs index 7cc3cc7d59fe..94017efe97aa 100644 --- a/datafusion/physical-plan/src/udaf.rs +++ b/datafusion/physical-plan/src/udaf.rs @@ -50,7 +50,7 @@ pub fn create_aggregate_expr( Ok(Arc::new(AggregateFunctionExpr { fun: fun.clone(), args: input_phy_exprs.to_vec(), - data_type: (fun.return_type)(&input_exprs_types)?.as_ref().clone(), + data_type: fun.return_type(&input_exprs_types)?, name: name.into(), })) } @@ -83,7 +83,9 @@ impl AggregateExpr for AggregateFunctionExpr { } fn state_fields(&self) -> Result> { - let fields = (self.fun.state_type)(&self.data_type)? + let fields = self + .fun + .state_type(&self.data_type)? .iter() .enumerate() .map(|(i, data_type)| { @@ -103,11 +105,11 @@ impl AggregateExpr for AggregateFunctionExpr { } fn create_accumulator(&self) -> Result> { - (self.fun.accumulator)(&self.data_type) + self.fun.accumulator(&self.data_type) } fn create_sliding_accumulator(&self) -> Result> { - let accumulator = (self.fun.accumulator)(&self.data_type)?; + let accumulator = self.fun.accumulator(&self.data_type)?; // Accumulators that have window frame startings different // than `UNBOUNDED PRECEDING`, such as `1 PRECEEDING`, need to diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index af765e257db2..d01ea5507449 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -21,31 +21,31 @@ //! The Union operator combines multiple inputs with the same schema +use std::borrow::Borrow; use std::pin::Pin; use std::task::{Context, Poll}; use std::{any::Any, sync::Arc}; -use arrow::{ - datatypes::{Field, Schema, SchemaRef}, - record_batch::RecordBatch, -}; -use datafusion_common::{exec_err, internal_err, DFSchemaRef, DataFusionError}; -use futures::Stream; -use itertools::Itertools; -use log::{debug, trace, warn}; - -use super::DisplayAs; use super::{ expressions::PhysicalSortExpr, metrics::{ExecutionPlanMetricsSet, MetricsSet}, - ColumnStatistics, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, - SendableRecordBatchStream, Statistics, + ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, + RecordBatchStream, SendableRecordBatchStream, Statistics, }; use crate::common::get_meet_of_orderings; +use crate::metrics::BaselineMetrics; use crate::stream::ObservedStream; -use crate::{expressions, metrics::BaselineMetrics}; -use datafusion_common::Result; + +use arrow::datatypes::{Field, Schema, SchemaRef}; +use arrow::record_batch::RecordBatch; +use datafusion_common::stats::Precision; +use datafusion_common::{exec_err, internal_err, DataFusionError, Result}; use datafusion_execution::TaskContext; +use datafusion_physical_expr::EquivalenceProperties; + +use futures::Stream; +use itertools::Itertools; +use log::{debug, trace, warn}; use tokio::macros::support::thread_rng_n; /// `UnionExec`: `UNION ALL` execution plan. @@ -96,38 +96,6 @@ pub struct UnionExec { } impl UnionExec { - /// Create a new UnionExec with specified schema. - /// The `schema` should always be a subset of the schema of `inputs`, - /// otherwise, an error will be returned. - pub fn try_new_with_schema( - inputs: Vec>, - schema: DFSchemaRef, - ) -> Result { - let mut exec = Self::new(inputs); - let exec_schema = exec.schema(); - let fields = schema - .fields() - .iter() - .map(|dff| { - exec_schema - .field_with_name(dff.name()) - .cloned() - .map_err(|_| { - DataFusionError::Internal(format!( - "Cannot find the field {:?} in child schema", - dff.name() - )) - }) - }) - .collect::>>()?; - let schema = Arc::new(Schema::new_with_metadata( - fields, - exec.schema().metadata().clone(), - )); - exec.schema = schema; - Ok(exec) - } - /// Create a new UnionExec pub fn new(inputs: Vec>) -> Self { let schema = union_schema(&inputs); @@ -224,6 +192,46 @@ impl ExecutionPlan for UnionExec { } } + fn equivalence_properties(&self) -> EquivalenceProperties { + // TODO: In some cases, we should be able to preserve some equivalence + // classes and constants. Add support for such cases. + let children_eqs = self + .inputs + .iter() + .map(|child| child.equivalence_properties()) + .collect::>(); + let mut result = EquivalenceProperties::new(self.schema()); + // Use the ordering equivalence class of the first child as the seed: + let mut meets = children_eqs[0] + .oeq_class() + .iter() + .map(|item| item.to_vec()) + .collect::>(); + // Iterate over all the children: + for child_eqs in &children_eqs[1..] { + // Compute meet orderings of the current meets and the new ordering + // equivalence class. + let mut idx = 0; + while idx < meets.len() { + // Find all the meets of `current_meet` with this child's orderings: + let valid_meets = child_eqs.oeq_class().iter().filter_map(|ordering| { + child_eqs.get_meet_ordering(ordering, &meets[idx]) + }); + // Use the longest of these meets as others are redundant: + if let Some(next_meet) = valid_meets.max_by_key(|m| m.len()) { + meets[idx] = next_meet; + idx += 1; + } else { + meets.swap_remove(idx); + } + } + } + // We know have all the valid orderings after union, remove redundant + // entries (implicitly) and return: + result.add_new_orderings(meets); + result + } + fn with_new_children( self: Arc, children: Vec>, @@ -264,12 +272,17 @@ impl ExecutionPlan for UnionExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Statistics { - self.inputs + fn statistics(&self) -> Result { + let stats = self + .inputs .iter() - .map(|ep| ep.statistics()) + .map(|stat| stat.statistics()) + .collect::>>()?; + + Ok(stats + .into_iter() .reduce(stats_union) - .unwrap_or_default() + .unwrap_or_else(|| Statistics::new_unknown(&self.schema()))) } fn benefits_from_input_partitioning(&self) -> Vec { @@ -324,7 +337,7 @@ impl InterleaveExec { pub fn try_new(inputs: Vec>) -> Result { let schema = union_schema(&inputs); - if !can_interleave(&inputs) { + if !can_interleave(inputs.iter()) { return internal_err!( "Not all InterleaveExec children have a consistent hash partitioning" ); @@ -438,12 +451,17 @@ impl ExecutionPlan for InterleaveExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Statistics { - self.inputs + fn statistics(&self) -> Result { + let stats = self + .inputs .iter() - .map(|ep| ep.statistics()) + .map(|stat| stat.statistics()) + .collect::>>()?; + + Ok(stats + .into_iter() .reduce(stats_union) - .unwrap_or_default() + .unwrap_or_else(|| Statistics::new_unknown(&self.schema()))) } fn benefits_from_input_partitioning(&self) -> Vec { @@ -457,17 +475,18 @@ impl ExecutionPlan for InterleaveExec { /// It might be too strict here in the case that the input partition specs are compatible but not exactly the same. /// For example one input partition has the partition spec Hash('a','b','c') and /// other has the partition spec Hash('a'), It is safe to derive the out partition with the spec Hash('a','b','c'). -pub fn can_interleave(inputs: &[Arc]) -> bool { - if inputs.is_empty() { +pub fn can_interleave>>( + mut inputs: impl Iterator, +) -> bool { + let Some(first) = inputs.next() else { return false; - } + }; - let first_input_partition = inputs[0].output_partitioning(); - matches!(first_input_partition, Partitioning::Hash(_, _)) + let reference = first.borrow().output_partitioning(); + matches!(reference, Partitioning::Hash(_, _)) && inputs - .iter() - .map(|plan| plan.output_partitioning()) - .all(|partition| partition == first_input_partition) + .map(|plan| plan.borrow().output_partitioning()) + .all(|partition| partition == reference) } fn union_schema(inputs: &[Arc]) -> SchemaRef { @@ -564,49 +583,65 @@ fn col_stats_union( mut left: ColumnStatistics, right: ColumnStatistics, ) -> ColumnStatistics { - left.distinct_count = None; - left.min_value = left - .min_value - .zip(right.min_value) - .map(|(a, b)| expressions::helpers::min(&a, &b)) - .and_then(Result::ok); - left.max_value = left - .max_value - .zip(right.max_value) - .map(|(a, b)| expressions::helpers::max(&a, &b)) - .and_then(Result::ok); - left.null_count = left.null_count.zip(right.null_count).map(|(a, b)| a + b); + left.distinct_count = Precision::Absent; + left.min_value = left.min_value.min(&right.min_value); + left.max_value = left.max_value.max(&right.max_value); + left.null_count = left.null_count.add(&right.null_count); left } fn stats_union(mut left: Statistics, right: Statistics) -> Statistics { - left.is_exact = left.is_exact && right.is_exact; - left.num_rows = left.num_rows.zip(right.num_rows).map(|(a, b)| a + b); - left.total_byte_size = left - .total_byte_size - .zip(right.total_byte_size) - .map(|(a, b)| a + b); - left.column_statistics = - left.column_statistics - .zip(right.column_statistics) - .map(|(a, b)| { - a.into_iter() - .zip(b) - .map(|(ca, cb)| col_stats_union(ca, cb)) - .collect() - }); + left.num_rows = left.num_rows.add(&right.num_rows); + left.total_byte_size = left.total_byte_size.add(&right.total_byte_size); + left.column_statistics = left + .column_statistics + .into_iter() + .zip(right.column_statistics) + .map(|(a, b)| col_stats_union(a, b)) + .collect::>(); left } #[cfg(test)] mod tests { use super::*; + use crate::collect; + use crate::memory::MemoryExec; use crate::test; - use crate::collect; use arrow::record_batch::RecordBatch; + use arrow_schema::{DataType, SortOptions}; use datafusion_common::ScalarValue; + use datafusion_physical_expr::expressions::col; + use datafusion_physical_expr::PhysicalExpr; + + // Generate a schema which consists of 7 columns (a, b, c, d, e, f, g) + fn create_test_schema() -> Result { + let a = Field::new("a", DataType::Int32, true); + let b = Field::new("b", DataType::Int32, true); + let c = Field::new("c", DataType::Int32, true); + let d = Field::new("d", DataType::Int32, true); + let e = Field::new("e", DataType::Int32, true); + let f = Field::new("f", DataType::Int32, true); + let g = Field::new("g", DataType::Int32, true); + let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f, g])); + + Ok(schema) + } + + // Convert each tuple to PhysicalSortExpr + fn convert_to_sort_exprs( + in_data: &[(&Arc, SortOptions)], + ) -> Vec { + in_data + .iter() + .map(|(expr, options)| PhysicalSortExpr { + expr: (*expr).clone(), + options: *options, + }) + .collect::>() + } #[tokio::test] async fn test_union_partitions() -> Result<()> { @@ -630,84 +665,182 @@ mod tests { #[tokio::test] async fn test_stats_union() { let left = Statistics { - is_exact: true, - num_rows: Some(5), - total_byte_size: Some(23), - column_statistics: Some(vec![ + num_rows: Precision::Exact(5), + total_byte_size: Precision::Exact(23), + column_statistics: vec![ ColumnStatistics { - distinct_count: Some(5), - max_value: Some(ScalarValue::Int64(Some(21))), - min_value: Some(ScalarValue::Int64(Some(-4))), - null_count: Some(0), + distinct_count: Precision::Exact(5), + max_value: Precision::Exact(ScalarValue::Int64(Some(21))), + min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), + null_count: Precision::Exact(0), }, ColumnStatistics { - distinct_count: Some(1), - max_value: Some(ScalarValue::Utf8(Some(String::from("x")))), - min_value: Some(ScalarValue::Utf8(Some(String::from("a")))), - null_count: Some(3), + distinct_count: Precision::Exact(1), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), + null_count: Precision::Exact(3), }, ColumnStatistics { - distinct_count: None, - max_value: Some(ScalarValue::Float32(Some(1.1))), - min_value: Some(ScalarValue::Float32(Some(0.1))), - null_count: None, + distinct_count: Precision::Absent, + max_value: Precision::Exact(ScalarValue::Float32(Some(1.1))), + min_value: Precision::Exact(ScalarValue::Float32(Some(0.1))), + null_count: Precision::Absent, }, - ]), + ], }; let right = Statistics { - is_exact: true, - num_rows: Some(7), - total_byte_size: Some(29), - column_statistics: Some(vec![ + num_rows: Precision::Exact(7), + total_byte_size: Precision::Exact(29), + column_statistics: vec![ ColumnStatistics { - distinct_count: Some(3), - max_value: Some(ScalarValue::Int64(Some(34))), - min_value: Some(ScalarValue::Int64(Some(1))), - null_count: Some(1), + distinct_count: Precision::Exact(3), + max_value: Precision::Exact(ScalarValue::Int64(Some(34))), + min_value: Precision::Exact(ScalarValue::Int64(Some(1))), + null_count: Precision::Exact(1), }, ColumnStatistics { - distinct_count: None, - max_value: Some(ScalarValue::Utf8(Some(String::from("c")))), - min_value: Some(ScalarValue::Utf8(Some(String::from("b")))), - null_count: None, + distinct_count: Precision::Absent, + max_value: Precision::Exact(ScalarValue::from("c")), + min_value: Precision::Exact(ScalarValue::from("b")), + null_count: Precision::Absent, }, ColumnStatistics { - distinct_count: None, - max_value: None, - min_value: None, - null_count: None, + distinct_count: Precision::Absent, + max_value: Precision::Absent, + min_value: Precision::Absent, + null_count: Precision::Absent, }, - ]), + ], }; let result = stats_union(left, right); let expected = Statistics { - is_exact: true, - num_rows: Some(12), - total_byte_size: Some(52), - column_statistics: Some(vec![ + num_rows: Precision::Exact(12), + total_byte_size: Precision::Exact(52), + column_statistics: vec![ ColumnStatistics { - distinct_count: None, - max_value: Some(ScalarValue::Int64(Some(34))), - min_value: Some(ScalarValue::Int64(Some(-4))), - null_count: Some(1), + distinct_count: Precision::Absent, + max_value: Precision::Exact(ScalarValue::Int64(Some(34))), + min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), + null_count: Precision::Exact(1), }, ColumnStatistics { - distinct_count: None, - max_value: Some(ScalarValue::Utf8(Some(String::from("x")))), - min_value: Some(ScalarValue::Utf8(Some(String::from("a")))), - null_count: None, + distinct_count: Precision::Absent, + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), + null_count: Precision::Absent, }, ColumnStatistics { - distinct_count: None, - max_value: None, - min_value: None, - null_count: None, + distinct_count: Precision::Absent, + max_value: Precision::Absent, + min_value: Precision::Absent, + null_count: Precision::Absent, }, - ]), + ], }; assert_eq!(result, expected); } + + #[tokio::test] + async fn test_union_equivalence_properties() -> Result<()> { + let schema = create_test_schema()?; + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let col_d = &col("d", &schema)?; + let col_e = &col("e", &schema)?; + let col_f = &col("f", &schema)?; + let options = SortOptions::default(); + let test_cases = vec![ + //-----------TEST CASE 1----------// + ( + // First child orderings + vec![ + // [a ASC, b ASC, f ASC] + vec![(col_a, options), (col_b, options), (col_f, options)], + ], + // Second child orderings + vec![ + // [a ASC, b ASC, c ASC] + vec![(col_a, options), (col_b, options), (col_c, options)], + // [a ASC, b ASC, f ASC] + vec![(col_a, options), (col_b, options), (col_f, options)], + ], + // Union output orderings + vec![ + // [a ASC, b ASC, f ASC] + vec![(col_a, options), (col_b, options), (col_f, options)], + ], + ), + //-----------TEST CASE 2----------// + ( + // First child orderings + vec![ + // [a ASC, b ASC, f ASC] + vec![(col_a, options), (col_b, options), (col_f, options)], + // d ASC + vec![(col_d, options)], + ], + // Second child orderings + vec![ + // [a ASC, b ASC, c ASC] + vec![(col_a, options), (col_b, options), (col_c, options)], + // [e ASC] + vec![(col_e, options)], + ], + // Union output orderings + vec![ + // [a ASC, b ASC] + vec![(col_a, options), (col_b, options)], + ], + ), + ]; + + for ( + test_idx, + (first_child_orderings, second_child_orderings, union_orderings), + ) in test_cases.iter().enumerate() + { + let first_orderings = first_child_orderings + .iter() + .map(|ordering| convert_to_sort_exprs(ordering)) + .collect::>(); + let second_orderings = second_child_orderings + .iter() + .map(|ordering| convert_to_sort_exprs(ordering)) + .collect::>(); + let union_expected_orderings = union_orderings + .iter() + .map(|ordering| convert_to_sort_exprs(ordering)) + .collect::>(); + let child1 = Arc::new( + MemoryExec::try_new(&[], schema.clone(), None)? + .with_sort_information(first_orderings), + ); + let child2 = Arc::new( + MemoryExec::try_new(&[], schema.clone(), None)? + .with_sort_information(second_orderings), + ); + + let union = UnionExec::new(vec![child1, child2]); + let union_eq_properties = union.equivalence_properties(); + let union_actual_orderings = union_eq_properties.oeq_class(); + let err_msg = format!( + "Error in test id: {:?}, test case: {:?}", + test_idx, test_cases[test_idx] + ); + assert_eq!( + union_actual_orderings.len(), + union_expected_orderings.len(), + "{}", + err_msg + ); + for expected in &union_expected_orderings { + assert!(union_actual_orderings.contains(expected), "{}", err_msg); + } + } + Ok(()) + } } diff --git a/datafusion/physical-plan/src/unnest.rs b/datafusion/physical-plan/src/unnest.rs index 410ea97887e0..b9e732c317af 100644 --- a/datafusion/physical-plan/src/unnest.rs +++ b/datafusion/physical-plan/src/unnest.rs @@ -17,6 +17,14 @@ //! Defines the unnest column plan for unnesting values in a column that contains a list //! type, conceptually is like joining each row with all the values in the list column. +use std::{any::Any, sync::Arc}; + +use super::DisplayAs; +use crate::{ + expressions::Column, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, + PhysicalExpr, PhysicalSortExpr, RecordBatchStream, SendableRecordBatchStream, +}; + use arrow::array::{ Array, ArrayRef, ArrowPrimitiveType, FixedSizeListArray, LargeListArray, ListArray, PrimitiveArray, @@ -27,23 +35,14 @@ use arrow::datatypes::{ }; use arrow::record_batch::RecordBatch; use arrow_array::{GenericListArray, OffsetSizeTrait}; -use async_trait::async_trait; -use datafusion_common::UnnestOptions; -use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_common::{exec_err, DataFusionError, Result, UnnestOptions}; use datafusion_execution::TaskContext; -use futures::Stream; -use futures::StreamExt; -use log::trace; -use std::time::Instant; -use std::{any::Any, sync::Arc}; -use crate::{ - expressions::Column, DisplayFormatType, Distribution, EquivalenceProperties, - ExecutionPlan, Partitioning, PhysicalExpr, PhysicalSortExpr, RecordBatchStream, - SendableRecordBatchStream, Statistics, -}; +use async_trait::async_trait; +use futures::{Stream, StreamExt}; +use log::trace; -use super::DisplayAs; +use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; /// Unnest the given column by joining the row with each value in the /// nested type. @@ -59,6 +58,8 @@ pub struct UnnestExec { column: Column, /// Options options: UnnestOptions, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, } impl UnnestExec { @@ -74,6 +75,7 @@ impl UnnestExec { schema, column, options, + metrics: Default::default(), } } } @@ -136,32 +138,63 @@ impl ExecutionPlan for UnnestExec { None } - fn equivalence_properties(&self) -> EquivalenceProperties { - self.input.equivalence_properties() - } - fn execute( &self, partition: usize, context: Arc, ) -> Result { let input = self.input.execute(partition, context)?; + let metrics = UnnestMetrics::new(partition, &self.metrics); Ok(Box::pin(UnnestStream { input, schema: self.schema.clone(), column: self.column.clone(), options: self.options.clone(), - num_input_batches: 0, - num_input_rows: 0, - num_output_batches: 0, - num_output_rows: 0, - unnest_time: 0, + metrics, })) } - fn statistics(&self) -> Statistics { - Default::default() + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } +} + +#[derive(Clone, Debug)] +struct UnnestMetrics { + /// total time for column unnesting + elapsed_compute: metrics::Time, + /// Number of batches consumed + input_batches: metrics::Count, + /// Number of rows consumed + input_rows: metrics::Count, + /// Number of batches produced + output_batches: metrics::Count, + /// Number of rows produced by this operator + output_rows: metrics::Count, +} + +impl UnnestMetrics { + fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self { + let elapsed_compute = MetricBuilder::new(metrics).elapsed_compute(partition); + + let input_batches = + MetricBuilder::new(metrics).counter("input_batches", partition); + + let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); + + let output_batches = + MetricBuilder::new(metrics).counter("output_batches", partition); + + let output_rows = MetricBuilder::new(metrics).output_rows(partition); + + Self { + input_batches, + input_rows, + output_batches, + output_rows, + elapsed_compute, + } } } @@ -175,16 +208,8 @@ struct UnnestStream { column: Column, /// Options options: UnnestOptions, - /// number of input batches - num_input_batches: usize, - /// number of input rows - num_input_rows: usize, - /// number of batches produced - num_output_batches: usize, - /// number of rows produced - num_output_rows: usize, - /// total time for column unnesting, in ms - unnest_time: usize, + /// Metrics + metrics: UnnestMetrics, } impl RecordBatchStream for UnnestStream { @@ -216,15 +241,15 @@ impl UnnestStream { .poll_next_unpin(cx) .map(|maybe_batch| match maybe_batch { Some(Ok(batch)) => { - let start = Instant::now(); + let timer = self.metrics.elapsed_compute.timer(); let result = build_batch(&batch, &self.schema, &self.column, &self.options); - self.num_input_batches += 1; - self.num_input_rows += batch.num_rows(); + self.metrics.input_batches.add(1); + self.metrics.input_rows.add(batch.num_rows()); if let Ok(ref batch) = result { - self.unnest_time += start.elapsed().as_millis() as usize; - self.num_output_batches += 1; - self.num_output_rows += batch.num_rows(); + timer.done(); + self.metrics.output_batches.add(1); + self.metrics.output_rows.add(batch.num_rows()); } Some(result) @@ -232,12 +257,12 @@ impl UnnestStream { other => { trace!( "Processed {} probe-side input batches containing {} rows and \ - produced {} output batches containing {} rows in {} ms", - self.num_input_batches, - self.num_input_rows, - self.num_output_batches, - self.num_output_rows, - self.unnest_time, + produced {} output batches containing {} rows in {}", + self.metrics.input_batches, + self.metrics.input_rows, + self.metrics.output_batches, + self.metrics.output_rows, + self.metrics.elapsed_compute, ); other } @@ -251,7 +276,7 @@ fn build_batch( column: &Column, options: &UnnestOptions, ) -> Result { - let list_array = column.evaluate(batch)?.into_array(batch.num_rows()); + let list_array = column.evaluate(batch)?.into_array(batch.num_rows())?; match list_array.data_type() { DataType::List(_) => { let list_array = list_array.as_any().downcast_ref::().unwrap(); diff --git a/datafusion/physical-plan/src/values.rs b/datafusion/physical-plan/src/values.rs index 2cf341d1fe60..b624fb362e65 100644 --- a/datafusion/physical-plan/src/values.rs +++ b/datafusion/physical-plan/src/values.rs @@ -17,20 +17,21 @@ //! Values execution plan +use std::any::Any; +use std::sync::Arc; + use super::expressions::PhysicalSortExpr; use super::{common, DisplayAs, SendableRecordBatchStream, Statistics}; use crate::{ memory::MemoryStream, ColumnarValue, DisplayFormatType, ExecutionPlan, Partitioning, PhysicalExpr, }; + use arrow::array::new_null_array; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; -use datafusion_common::{internal_err, plan_err, ScalarValue}; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{internal_err, plan_err, DataFusionError, Result, ScalarValue}; use datafusion_execution::TaskContext; -use std::any::Any; -use std::sync::Arc; /// Execution plan for values list based relation (produces constant rows) #[derive(Debug)] @@ -66,10 +67,11 @@ impl ValuesExec { (0..n_row) .map(|i| { let r = data[i][j].evaluate(&batch); + match r { Ok(ColumnarValue::Scalar(scalar)) => Ok(scalar), Ok(ColumnarValue::Array(a)) if a.len() == 1 => { - ScalarValue::try_from_array(&a, 0) + Ok(ScalarValue::List(a)) } Ok(ColumnarValue::Array(a)) => { plan_err!( @@ -186,9 +188,13 @@ impl ExecutionPlan for ValuesExec { )?)) } - fn statistics(&self) -> Statistics { + fn statistics(&self) -> Result { let batch = self.data(); - common::compute_record_batch_statistics(&[batch], &self.schema, None) + Ok(common::compute_record_batch_statistics( + &[batch], + &self.schema, + None, + )) } } @@ -196,6 +202,7 @@ impl ExecutionPlan for ValuesExec { mod tests { use super::*; use crate::test::{self, make_partition}; + use arrow_schema::{DataType, Field, Schema}; #[tokio::test] diff --git a/datafusion/physical-plan/src/visitor.rs b/datafusion/physical-plan/src/visitor.rs index 573e4f8b02be..ca826c50022d 100644 --- a/datafusion/physical-plan/src/visitor.rs +++ b/datafusion/physical-plan/src/visitor.rs @@ -38,7 +38,7 @@ pub fn accept( /// depth first walk of `ExecutionPlan` nodes. `pre_visit` is called /// before any children are visited, and then `post_visit` is called /// after all children have been visited. -//// +/// /// To use, define a struct that implements this trait and then invoke /// ['accept']. /// diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index 4108b4220599..0871ec0d7ff3 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -30,34 +30,36 @@ use std::task::{Context, Poll}; use crate::expressions::PhysicalSortExpr; use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use crate::windows::{ - calc_requirements, get_ordered_partition_by_indices, window_ordering_equivalence, + calc_requirements, get_ordered_partition_by_indices, get_partition_by_sort_exprs, + window_equivalence_properties, }; use crate::{ ColumnStatistics, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, - Partitioning, RecordBatchStream, SendableRecordBatchStream, Statistics, WindowExpr, + InputOrderMode, Partitioning, RecordBatchStream, SendableRecordBatchStream, + Statistics, WindowExpr, }; use arrow::{ - array::{Array, ArrayRef, UInt32Builder}, + array::{Array, ArrayRef, RecordBatchOptions, UInt32Builder}, compute::{concat, concat_batches, sort_to_indices}, datatypes::{Schema, SchemaBuilder, SchemaRef}, record_batch::RecordBatch, }; +use datafusion_common::hash_utils::create_hashes; +use datafusion_common::stats::Precision; use datafusion_common::utils::{ evaluate_partition_ranges, get_arrayref_at_indices, get_at_indices, get_record_batch_at_indices, get_row_at_idx, }; -use datafusion_common::{exec_err, plan_err, DataFusionError, Result}; +use datafusion_common::{arrow_datafusion_err, exec_err, DataFusionError, Result}; use datafusion_execution::TaskContext; use datafusion_expr::window_state::{PartitionBatchState, WindowAggState}; use datafusion_expr::ColumnarValue; -use datafusion_physical_expr::hash_utils::create_hashes; use datafusion_physical_expr::window::{ PartitionBatches, PartitionKey, PartitionWindowAggStates, WindowState, }; use datafusion_physical_expr::{ - EquivalenceProperties, OrderingEquivalenceProperties, PhysicalExpr, - PhysicalSortRequirement, + EquivalenceProperties, PhysicalExpr, PhysicalSortRequirement, }; use ahash::RandomState; @@ -67,17 +69,6 @@ use hashbrown::raw::RawTable; use indexmap::IndexMap; use log::debug; -#[derive(Debug, Clone, PartialEq)] -/// Specifies partition column properties in terms of input ordering -pub enum PartitionSearchMode { - /// None of the columns among the partition columns is ordered. - Linear, - /// Some columns of the partition columns are ordered but not all - PartiallySorted(Vec), - /// All Partition columns are ordered (Also empty case) - Sorted, -} - /// Window execution plan #[derive(Debug)] pub struct BoundedWindowAggExec { @@ -87,14 +78,12 @@ pub struct BoundedWindowAggExec { window_expr: Vec>, /// Schema after the window is run schema: SchemaRef, - /// Schema before the window - input_schema: SchemaRef, /// Partition Keys pub partition_keys: Vec>, /// Execution metrics metrics: ExecutionPlanMetricsSet, - /// Partition by search mode - pub partition_search_mode: PartitionSearchMode, + /// Describes how the input is ordered relative to the partition keys + pub input_order_mode: InputOrderMode, /// Partition by indices that define ordering // For example, if input ordering is ORDER BY a, b and window expression // contains PARTITION BY b, a; `ordered_partition_by_indices` would be 1, 0. @@ -109,15 +98,14 @@ impl BoundedWindowAggExec { pub fn try_new( window_expr: Vec>, input: Arc, - input_schema: SchemaRef, partition_keys: Vec>, - partition_search_mode: PartitionSearchMode, + input_order_mode: InputOrderMode, ) -> Result { - let schema = create_schema(&input_schema, &window_expr)?; + let schema = create_schema(&input.schema(), &window_expr)?; let schema = Arc::new(schema); let partition_by_exprs = window_expr[0].partition_by(); - let ordered_partition_by_indices = match &partition_search_mode { - PartitionSearchMode::Sorted => { + let ordered_partition_by_indices = match &input_order_mode { + InputOrderMode::Sorted => { let indices = get_ordered_partition_by_indices( window_expr[0].partition_by(), &input, @@ -128,10 +116,8 @@ impl BoundedWindowAggExec { (0..partition_by_exprs.len()).collect::>() } } - PartitionSearchMode::PartiallySorted(ordered_indices) => { - ordered_indices.clone() - } - PartitionSearchMode::Linear => { + InputOrderMode::PartiallySorted(ordered_indices) => ordered_indices.clone(), + InputOrderMode::Linear => { vec![] } }; @@ -139,10 +125,9 @@ impl BoundedWindowAggExec { input, window_expr, schema, - input_schema, partition_keys, metrics: ExecutionPlanMetricsSet::new(), - partition_search_mode, + input_order_mode, ordered_partition_by_indices, }) } @@ -157,20 +142,18 @@ impl BoundedWindowAggExec { &self.input } - /// Get the input schema before any window functions are applied - pub fn input_schema(&self) -> SchemaRef { - self.input_schema.clone() - } - /// Return the output sort order of partition keys: For example /// OVER(PARTITION BY a, ORDER BY b) -> would give sorting of the column a // We are sure that partition by columns are always at the beginning of sort_keys // Hence returned `PhysicalSortExpr` corresponding to `PARTITION BY` columns can be used safely // to calculate partition separation points pub fn partition_by_sort_keys(&self) -> Result> { - // Partition by sort keys indices are stored in self.ordered_partition_by_indices. - let sort_keys = self.input.output_ordering().unwrap_or(&[]); - get_at_indices(sort_keys, &self.ordered_partition_by_indices) + let partition_by = self.window_expr()[0].partition_by(); + get_partition_by_sort_exprs( + &self.input, + partition_by, + &self.ordered_partition_by_indices, + ) } /// Initializes the appropriate [`PartitionSearcher`] implementation from @@ -178,8 +161,8 @@ impl BoundedWindowAggExec { fn get_search_algo(&self) -> Result> { let partition_by_sort_keys = self.partition_by_sort_keys()?; let ordered_partition_by_indices = self.ordered_partition_by_indices.clone(); - Ok(match &self.partition_search_mode { - PartitionSearchMode::Sorted => { + Ok(match &self.input_order_mode { + InputOrderMode::Sorted => { // In Sorted mode, all partition by columns should be ordered. if self.window_expr()[0].partition_by().len() != ordered_partition_by_indices.len() @@ -191,7 +174,7 @@ impl BoundedWindowAggExec { ordered_partition_by_indices, }) } - PartitionSearchMode::Linear | PartitionSearchMode::PartiallySorted(_) => { + InputOrderMode::Linear | InputOrderMode::PartiallySorted(_) => { Box::new(LinearSearch::new(ordered_partition_by_indices)) } }) @@ -219,7 +202,7 @@ impl DisplayAs for BoundedWindowAggExec { ) }) .collect(); - let mode = &self.partition_search_mode; + let mode = &self.input_order_mode; write!(f, "wdw=[{}], mode=[{:?}]", g.join(", "), mode)?; } } @@ -260,7 +243,7 @@ impl ExecutionPlan for BoundedWindowAggExec { fn required_input_ordering(&self) -> Vec>> { let partition_bys = self.window_expr()[0].partition_by(); let order_keys = self.window_expr()[0].order_by(); - if self.partition_search_mode != PartitionSearchMode::Sorted + if self.input_order_mode != InputOrderMode::Sorted || self.ordered_partition_by_indices.len() >= partition_bys.len() { let partition_bys = self @@ -282,13 +265,9 @@ impl ExecutionPlan for BoundedWindowAggExec { } } + /// Get the [`EquivalenceProperties`] within the plan fn equivalence_properties(&self) -> EquivalenceProperties { - self.input().equivalence_properties() - } - - /// Get the OrderingEquivalenceProperties within the plan - fn ordering_equivalence_properties(&self) -> OrderingEquivalenceProperties { - window_ordering_equivalence(&self.schema, &self.input, &self.window_expr) + window_equivalence_properties(&self.schema, &self.input, &self.window_expr) } fn maintains_input_order(&self) -> Vec { @@ -302,9 +281,8 @@ impl ExecutionPlan for BoundedWindowAggExec { Ok(Arc::new(BoundedWindowAggExec::try_new( self.window_expr.clone(), children[0].clone(), - self.input_schema.clone(), self.partition_keys.clone(), - self.partition_search_mode.clone(), + self.input_order_mode.clone(), )?)) } @@ -329,24 +307,22 @@ impl ExecutionPlan for BoundedWindowAggExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Statistics { - let input_stat = self.input.statistics(); + fn statistics(&self) -> Result { + let input_stat = self.input.statistics()?; let win_cols = self.window_expr.len(); - let input_cols = self.input_schema.fields().len(); + let input_cols = self.input.schema().fields().len(); // TODO stats: some windowing function will maintain invariants such as min, max... let mut column_statistics = Vec::with_capacity(win_cols + input_cols); - if let Some(input_col_stats) = input_stat.column_statistics { - column_statistics.extend(input_col_stats); - } else { - column_statistics.extend(vec![ColumnStatistics::default(); input_cols]); + // copy stats of the input to the beginning of the schema. + column_statistics.extend(input_stat.column_statistics); + for _ in 0..win_cols { + column_statistics.push(ColumnStatistics::new_unknown()) } - column_statistics.extend(vec![ColumnStatistics::default(); win_cols]); - Statistics { - is_exact: input_stat.is_exact, + Ok(Statistics { num_rows: input_stat.num_rows, - column_statistics: Some(column_statistics), - total_byte_size: None, - } + column_statistics, + total_byte_size: Precision::Absent, + }) } } @@ -523,7 +499,7 @@ impl PartitionSearcher for LinearSearch { .iter() .map(|items| { concat(&items.iter().map(|e| e.as_ref()).collect::>()) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) }) .collect::>>()?; // We should emit columns according to row index ordering. @@ -609,7 +585,7 @@ impl LinearSearch { .map(|item| match item.evaluate(record_batch)? { ColumnarValue::Array(array) => Ok(array), ColumnarValue::Scalar(scalar) => { - plan_err!("Sort operation is not applicable to scalar value {scalar}") + scalar.to_array_of_size(record_batch.num_rows()) } }) .collect() @@ -1050,8 +1026,11 @@ impl BoundedWindowAggStream { .iter() .map(|elem| elem.slice(n_out, n_to_keep)) .collect::>(); - self.input_buffer = - RecordBatch::try_new(self.input_buffer.schema(), batch_to_keep)?; + self.input_buffer = RecordBatch::try_new_with_options( + self.input_buffer.schema(), + batch_to_keep, + &RecordBatchOptions::new().with_row_count(Some(n_to_keep)), + )?; Ok(()) } @@ -1132,3 +1111,131 @@ fn get_aggregate_result_out_column( result .ok_or_else(|| DataFusionError::Execution("Should contain something".to_string())) } + +#[cfg(test)] +mod tests { + use crate::common::collect; + use crate::memory::MemoryExec; + use crate::windows::{BoundedWindowAggExec, InputOrderMode}; + use crate::{get_plan_string, ExecutionPlan}; + use arrow_array::RecordBatch; + use arrow_schema::{DataType, Field, Schema}; + use datafusion_common::{assert_batches_eq, Result, ScalarValue}; + use datafusion_execution::config::SessionConfig; + use datafusion_execution::TaskContext; + use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits}; + use datafusion_physical_expr::expressions::col; + use datafusion_physical_expr::expressions::NthValue; + use datafusion_physical_expr::window::BuiltInWindowExpr; + use datafusion_physical_expr::window::BuiltInWindowFunctionExpr; + use std::sync::Arc; + + // Tests NTH_VALUE(negative index) with memoize feature. + // To be able to trigger memoize feature for NTH_VALUE we need to + // - feed BoundedWindowAggExec with batch stream data. + // - Window frame should contain UNBOUNDED PRECEDING. + // It hard to ensure these conditions are met, from the sql query. + #[tokio::test] + async fn test_window_nth_value_bounded_memoize() -> Result<()> { + let config = SessionConfig::new().with_target_partitions(1); + let task_ctx = Arc::new(TaskContext::default().with_session_config(config)); + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + // Create a new batch of data to insert into the table + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(arrow_array::Int32Array::from(vec![1, 2, 3]))], + )?; + + let memory_exec = MemoryExec::try_new( + &[vec![batch.clone(), batch.clone(), batch.clone()]], + schema.clone(), + None, + ) + .map(|e| Arc::new(e) as Arc)?; + let col_a = col("a", &schema)?; + let nth_value_func1 = + NthValue::nth("nth_value(-1)", col_a.clone(), DataType::Int32, 1)? + .reverse_expr() + .unwrap(); + let nth_value_func2 = + NthValue::nth("nth_value(-2)", col_a.clone(), DataType::Int32, 2)? + .reverse_expr() + .unwrap(); + let last_value_func = + Arc::new(NthValue::last("last", col_a.clone(), DataType::Int32)) as _; + let window_exprs = vec![ + // LAST_VALUE(a) + Arc::new(BuiltInWindowExpr::new( + last_value_func, + &[], + &[], + Arc::new(WindowFrame { + units: WindowFrameUnits::Rows, + start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(None)), + end_bound: WindowFrameBound::CurrentRow, + }), + )) as _, + // NTH_VALUE(a, -1) + Arc::new(BuiltInWindowExpr::new( + nth_value_func1, + &[], + &[], + Arc::new(WindowFrame { + units: WindowFrameUnits::Rows, + start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(None)), + end_bound: WindowFrameBound::CurrentRow, + }), + )) as _, + // NTH_VALUE(a, -2) + Arc::new(BuiltInWindowExpr::new( + nth_value_func2, + &[], + &[], + Arc::new(WindowFrame { + units: WindowFrameUnits::Rows, + start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(None)), + end_bound: WindowFrameBound::CurrentRow, + }), + )) as _, + ]; + let physical_plan = BoundedWindowAggExec::try_new( + window_exprs, + memory_exec, + vec![], + InputOrderMode::Sorted, + ) + .map(|e| Arc::new(e) as Arc)?; + + let batches = collect(physical_plan.execute(0, task_ctx)?).await?; + + let expected = vec![ + "BoundedWindowAggExec: wdw=[last: Ok(Field { name: \"last\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }, nth_value(-1): Ok(Field { name: \"nth_value(-1)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }, nth_value(-2): Ok(Field { name: \"nth_value(-2)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }], mode=[Sorted]", + " MemoryExec: partitions=1, partition_sizes=[3]", + ]; + // Get string representation of the plan + let actual = get_plan_string(&physical_plan); + assert_eq!( + expected, actual, + "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let expected = [ + "+---+------+---------------+---------------+", + "| a | last | nth_value(-1) | nth_value(-2) |", + "+---+------+---------------+---------------+", + "| 1 | 1 | 1 | |", + "| 2 | 2 | 2 | 1 |", + "| 3 | 3 | 3 | 2 |", + "| 1 | 1 | 1 | 3 |", + "| 2 | 2 | 2 | 1 |", + "| 3 | 3 | 3 | 2 |", + "| 1 | 1 | 1 | 3 |", + "| 2 | 2 | 2 | 1 |", + "| 3 | 3 | 3 | 2 |", + "+---+------+---------------+---------------+", + ]; + assert_batches_eq!(expected, &batches); + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 0f165f79354e..fec168fabf48 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -27,34 +27,27 @@ use crate::{ cume_dist, dense_rank, lag, lead, percent_rank, rank, Literal, NthValue, Ntile, PhysicalSortExpr, RowNumber, }, - udaf, unbounded_output, ExecutionPlan, PhysicalExpr, + udaf, unbounded_output, ExecutionPlan, InputOrderMode, PhysicalExpr, }; use arrow::datatypes::Schema; use arrow_schema::{DataType, Field, SchemaRef}; -use datafusion_common::utils::{ - find_indices, get_at_indices, is_sorted, longest_consecutive_prefix, - merge_and_order_indices, set_difference, -}; -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::{ - window_function::{BuiltInWindowFunction, WindowFunction}, - PartitionEvaluator, WindowFrame, WindowUDF, + BuiltInWindowFunction, PartitionEvaluator, WindowFrame, WindowFunctionDefinition, + WindowUDF, }; +use datafusion_physical_expr::equivalence::collapse_lex_req; use datafusion_physical_expr::{ - equivalence::OrderingEquivalenceBuilder, - utils::{convert_to_expr, get_indices_of_matching_exprs}, + reverse_order_bys, window::{BuiltInWindowFunctionExpr, SlidingAggregateWindowExpr}, - AggregateExpr, OrderingEquivalenceProperties, PhysicalSortRequirement, + AggregateExpr, EquivalenceProperties, LexOrdering, PhysicalSortRequirement, }; -use itertools::{izip, Itertools}; - mod bounded_window_agg_exec; mod window_agg_exec; pub use bounded_window_agg_exec::BoundedWindowAggExec; -pub use bounded_window_agg_exec::PartitionSearchMode; pub use window_agg_exec::WindowAggExec; pub use datafusion_physical_expr::window::{ @@ -63,7 +56,7 @@ pub use datafusion_physical_expr::window::{ /// Create a physical expression for window function pub fn create_window_expr( - fun: &WindowFunction, + fun: &WindowFunctionDefinition, name: String, args: &[Arc], partition_by: &[Arc], @@ -72,7 +65,7 @@ pub fn create_window_expr( input_schema: &Schema, ) -> Result> { Ok(match fun { - WindowFunction::AggregateFunction(fun) => { + WindowFunctionDefinition::AggregateFunction(fun) => { let aggregate = aggregates::create_aggregate_expr( fun, false, @@ -88,13 +81,15 @@ pub fn create_window_expr( aggregate, ) } - WindowFunction::BuiltInWindowFunction(fun) => Arc::new(BuiltInWindowExpr::new( - create_built_in_window_expr(fun, args, input_schema, name)?, - partition_by, - order_by, - window_frame, - )), - WindowFunction::AggregateUDF(fun) => { + WindowFunctionDefinition::BuiltInWindowFunction(fun) => { + Arc::new(BuiltInWindowExpr::new( + create_built_in_window_expr(fun, args, input_schema, name)?, + partition_by, + order_by, + window_frame, + )) + } + WindowFunctionDefinition::AggregateUDF(fun) => { let aggregate = udaf::create_aggregate_expr(fun.as_ref(), args, input_schema, name)?; window_expr_from_aggregate_expr( @@ -104,7 +99,7 @@ pub fn create_window_expr( aggregate, ) } - WindowFunction::WindowUDF(fun) => Arc::new(BuiltInWindowExpr::new( + WindowFunctionDefinition::WindowUDF(fun) => Arc::new(BuiltInWindowExpr::new( create_udwf_window_expr(fun, args, input_schema, name)?, partition_by, order_by, @@ -172,15 +167,26 @@ fn create_built_in_window_expr( BuiltInWindowFunction::PercentRank => Arc::new(percent_rank(name)), BuiltInWindowFunction::CumeDist => Arc::new(cume_dist(name)), BuiltInWindowFunction::Ntile => { - let n: i64 = get_scalar_value_from_args(args, 0)? - .ok_or_else(|| { - DataFusionError::Execution( - "NTILE requires at least 1 argument".to_string(), - ) - })? - .try_into()?; - let n: u64 = n as u64; - Arc::new(Ntile::new(name, n)) + let n = get_scalar_value_from_args(args, 0)?.ok_or_else(|| { + DataFusionError::Execution( + "NTILE requires a positive integer".to_string(), + ) + })?; + + if n.is_null() { + return exec_err!("NTILE requires a positive integer, but finds NULL"); + } + + if n.is_unsigned() { + let n: u64 = n.try_into()?; + Arc::new(Ntile::new(name, n)) + } else { + let n: i64 = n.try_into()?; + if n <= 0 { + return exec_err!("NTILE requires a positive integer"); + } + Arc::new(Ntile::new(name, n as u64)) + } } BuiltInWindowFunction::Lag => { let arg = args[0].clone(); @@ -238,7 +244,7 @@ fn create_udwf_window_expr( .collect::>()?; // figure out the output type - let data_type = (fun.return_type)(&input_types)?; + let data_type = fun.return_type(&input_types)?; Ok(Arc::new(WindowUDFExpr { fun: Arc::clone(fun), args: args.to_vec(), @@ -255,7 +261,7 @@ struct WindowUDFExpr { /// Display name name: String, /// result type - data_type: Arc, + data_type: DataType, } impl BuiltInWindowFunctionExpr for WindowUDFExpr { @@ -265,11 +271,7 @@ impl BuiltInWindowFunctionExpr for WindowUDFExpr { fn field(&self) -> Result { let nullable = true; - Ok(Field::new( - &self.name, - self.data_type.as_ref().clone(), - nullable, - )) + Ok(Field::new(&self.name, self.data_type.clone(), nullable)) } fn expressions(&self) -> Vec> { @@ -277,7 +279,7 @@ impl BuiltInWindowFunctionExpr for WindowUDFExpr { } fn create_evaluator(&self) -> Result> { - (self.fun.partition_evaluator_factory)() + self.fun.partition_evaluator_factory() } fn name(&self) -> &str { @@ -321,45 +323,51 @@ pub(crate) fn get_ordered_partition_by_indices( partition_by_exprs: &[Arc], input: &Arc, ) -> Vec { - let input_ordering = input.output_ordering().unwrap_or(&[]); - let input_ordering_exprs = convert_to_expr(input_ordering); - let equal_properties = || input.equivalence_properties(); - let input_places = get_indices_of_matching_exprs( - &input_ordering_exprs, - partition_by_exprs, - equal_properties, - ); - let mut partition_places = get_indices_of_matching_exprs( - partition_by_exprs, - &input_ordering_exprs, - equal_properties, - ); - partition_places.sort(); - let first_n = longest_consecutive_prefix(partition_places); - input_places[0..first_n].to_vec() + let (_, indices) = input + .equivalence_properties() + .find_longest_permutation(partition_by_exprs); + indices +} + +pub(crate) fn get_partition_by_sort_exprs( + input: &Arc, + partition_by_exprs: &[Arc], + ordered_partition_by_indices: &[usize], +) -> Result { + let ordered_partition_exprs = ordered_partition_by_indices + .iter() + .map(|idx| partition_by_exprs[*idx].clone()) + .collect::>(); + // Make sure ordered section doesn't move over the partition by expression + assert!(ordered_partition_by_indices.len() <= partition_by_exprs.len()); + let (ordering, _) = input + .equivalence_properties() + .find_longest_permutation(&ordered_partition_exprs); + if ordering.len() == ordered_partition_exprs.len() { + Ok(ordering) + } else { + exec_err!("Expects PARTITION BY expression to be ordered") + } } -pub(crate) fn window_ordering_equivalence( +pub(crate) fn window_equivalence_properties( schema: &SchemaRef, input: &Arc, window_expr: &[Arc], -) -> OrderingEquivalenceProperties { +) -> EquivalenceProperties { // We need to update the schema, so we can not directly use - // `input.ordering_equivalence_properties()`. - let mut builder = OrderingEquivalenceBuilder::new(schema.clone()) - .with_equivalences(input.equivalence_properties()) - .with_existing_ordering(input.output_ordering().map(|elem| elem.to_vec())) - .extend(input.ordering_equivalence_properties()); + // `input.equivalence_properties()`. + let mut window_eq_properties = + EquivalenceProperties::new(schema.clone()).extend(input.equivalence_properties()); for expr in window_expr { if let Some(builtin_window_expr) = expr.as_any().downcast_ref::() { - builtin_window_expr - .add_equal_orderings(&mut builder, || input.equivalence_properties()); + builtin_window_expr.add_equal_orderings(&mut window_eq_properties); } } - builder.build() + window_eq_properties } /// Constructs the best-fitting windowing operator (a `WindowAggExec` or a @@ -384,17 +392,17 @@ pub fn get_best_fitting_window( // of the window_exprs are same. let partitionby_exprs = window_exprs[0].partition_by(); let orderby_keys = window_exprs[0].order_by(); - let (should_reverse, partition_search_mode) = - if let Some((should_reverse, partition_search_mode)) = - can_skip_sort(partitionby_exprs, orderby_keys, input)? + let (should_reverse, input_order_mode) = + if let Some((should_reverse, input_order_mode)) = + get_window_mode(partitionby_exprs, orderby_keys, input) { - (should_reverse, partition_search_mode) + (should_reverse, input_order_mode) } else { return Ok(None); }; let is_unbounded = unbounded_output(input); - if !is_unbounded && partition_search_mode != PartitionSearchMode::Sorted { - // Executor has bounded input and `partition_search_mode` is not `PartitionSearchMode::Sorted` + if !is_unbounded && input_order_mode != InputOrderMode::Sorted { + // Executor has bounded input and `input_order_mode` is not `InputOrderMode::Sorted` // in this case removing the sort is not helpful, return: return Ok(None); }; @@ -421,21 +429,19 @@ pub fn get_best_fitting_window( Ok(Some(Arc::new(BoundedWindowAggExec::try_new( window_expr, input.clone(), - input.schema(), physical_partition_keys.to_vec(), - partition_search_mode, + input_order_mode, )?) as _)) - } else if partition_search_mode != PartitionSearchMode::Sorted { + } else if input_order_mode != InputOrderMode::Sorted { // For `WindowAggExec` to work correctly PARTITION BY columns should be sorted. - // Hence, if `partition_search_mode` is not `PartitionSearchMode::Sorted` we should convert - // input ordering such that it can work with PartitionSearchMode::Sorted (add `SortExec`). - // Effectively `WindowAggExec` works only in PartitionSearchMode::Sorted mode. + // Hence, if `input_order_mode` is not `Sorted` we should convert + // input ordering such that it can work with `Sorted` (add `SortExec`). + // Effectively `WindowAggExec` works only in `Sorted` mode. Ok(None) } else { Ok(Some(Arc::new(WindowAggExec::try_new( window_expr, input.clone(), - input.schema(), physical_partition_keys.to_vec(), )?) as _)) } @@ -446,154 +452,46 @@ pub fn get_best_fitting_window( /// is sufficient to run the current window operator. /// - A `None` return value indicates that we can not remove the sort in question /// (input ordering is not sufficient to run current window executor). -/// - A `Some((bool, PartitionSearchMode))` value indicates that the window operator +/// - A `Some((bool, InputOrderMode))` value indicates that the window operator /// can run with existing input ordering, so we can remove `SortExec` before it. /// The `bool` field in the return value represents whether we should reverse window -/// operator to remove `SortExec` before it. The `PartitionSearchMode` field represents -/// the mode this window operator should work in to accomodate the existing ordering. -fn can_skip_sort( +/// operator to remove `SortExec` before it. The `InputOrderMode` field represents +/// the mode this window operator should work in to accommodate the existing ordering. +pub fn get_window_mode( partitionby_exprs: &[Arc], orderby_keys: &[PhysicalSortExpr], input: &Arc, -) -> Result> { - let physical_ordering = if let Some(physical_ordering) = input.output_ordering() { - physical_ordering - } else { - // If there is no physical ordering, there is no way to remove a - // sort, so immediately return. - return Ok(None); - }; - let orderby_exprs = convert_to_expr(orderby_keys); - let physical_ordering_exprs = convert_to_expr(physical_ordering); - let equal_properties = || input.equivalence_properties(); - // Get the indices of the ORDER BY expressions among input ordering expressions: - let ob_indices = get_indices_of_matching_exprs( - &orderby_exprs, - &physical_ordering_exprs, - equal_properties, - ); - if ob_indices.len() != orderby_exprs.len() { - // If all order by expressions are not in the input ordering, - // there is no way to remove a sort -- immediately return: - return Ok(None); - } - // Get the indices of the PARTITION BY expressions among input ordering expressions: - let pb_indices = get_indices_of_matching_exprs( - partitionby_exprs, - &physical_ordering_exprs, - equal_properties, - ); - let ordered_merged_indices = merge_and_order_indices(&pb_indices, &ob_indices); - // Get the indices of the ORDER BY columns that don't appear in the - // PARTITION BY clause; i.e. calculate (ORDER BY columns) ∖ (PARTITION - // BY columns) where `∖` represents set difference. - let unique_ob_indices = set_difference(&ob_indices, &pb_indices); - if !is_sorted(&unique_ob_indices) { - // ORDER BY indices should be ascending ordered - return Ok(None); - } - let first_n = longest_consecutive_prefix(ordered_merged_indices); - let furthest_ob_index = *unique_ob_indices.last().unwrap_or(&0); - // Cannot skip sort if last order by index is not within consecutive prefix. - // For instance, if input is ordered by a, b, c, d for the expression - // `PARTITION BY a, ORDER BY b, d`, then `first_n` would be 2 (meaning a, b defines a - // prefix for input ordering). However, `furthest_ob_index` would be 3 as column d - // occurs at the 3rd index of the existing ordering. Hence, existing ordering would - // not be sufficient to run the current operator. - // However, for expression `PARTITION BY a, ORDER BY b, c, d`, `first_n` would be 4 (meaning - // a, b, c, d defines a prefix for input ordering). Similarly, `furthest_ob_index` would be - // 3 as column d occurs at the 3rd index of the existing ordering. Therefore, the existing - // ordering would be sufficient to run the current operator. - if first_n <= furthest_ob_index { - return Ok(None); - } - let input_orderby_columns = get_at_indices(physical_ordering, &unique_ob_indices)?; - let expected_orderby_columns = - get_at_indices(orderby_keys, find_indices(&ob_indices, &unique_ob_indices)?)?; - let should_reverse = if let Some(should_reverse) = check_alignments( - &input.schema(), - &input_orderby_columns, - &expected_orderby_columns, - )? { - should_reverse - } else { - // If ordering directions are not aligned, we cannot calculate the - // result without changing existing ordering. - return Ok(None); - }; - - let ordered_pb_indices = pb_indices.iter().copied().sorted().collect::>(); - // Determine how many elements in the PARTITION BY columns defines a consecutive range from zero. - let first_n = longest_consecutive_prefix(&ordered_pb_indices); - let mode = if first_n == partitionby_exprs.len() { - // All of the PARTITION BY columns defines a consecutive range from zero. - PartitionSearchMode::Sorted - } else if first_n > 0 { - // All of the PARTITION BY columns defines a consecutive range from zero. - let ordered_range = &ordered_pb_indices[0..first_n]; - let input_pb_exprs = get_at_indices(&physical_ordering_exprs, ordered_range)?; - let partially_ordered_indices = get_indices_of_matching_exprs( - &input_pb_exprs, - partitionby_exprs, - equal_properties, - ); - PartitionSearchMode::PartiallySorted(partially_ordered_indices) - } else { - // None of the PARTITION BY columns defines a consecutive range from zero. - PartitionSearchMode::Linear - }; - - Ok(Some((should_reverse, mode))) -} - -/// Compares all the orderings in `physical_ordering` and `required`, decides -/// whether alignments match. A `None` return value indicates that current -/// column is not aligned. A `Some(bool)` value indicates otherwise, and signals -/// whether we should reverse the window expression in order to avoid sorting. -fn check_alignments( - schema: &SchemaRef, - physical_ordering: &[PhysicalSortExpr], - required: &[PhysicalSortExpr], -) -> Result> { - let result = izip!(physical_ordering, required) - .map(|(lhs, rhs)| check_alignment(schema, lhs, rhs)) - .collect::>>>()?; - Ok(if let Some(res) = result { - if !res.is_empty() { - let first = res[0]; - let all_same = res.into_iter().all(|elem| elem == first); - all_same.then_some(first) - } else { - Some(false) - } - } else { - // Cannot skip some of the requirements in the input. - None - }) -} - -/// Compares `physical_ordering` and `required` ordering, decides whether -/// alignments match. A `None` return value indicates that current column is -/// not aligned. A `Some(bool)` value indicates otherwise, and signals whether -/// we should reverse the window expression in order to avoid sorting. -fn check_alignment( - input_schema: &SchemaRef, - physical_ordering: &PhysicalSortExpr, - required: &PhysicalSortExpr, -) -> Result> { - Ok(if required.expr.eq(&physical_ordering.expr) { - let physical_opts = physical_ordering.options; - let required_opts = required.options; - if required.expr.nullable(input_schema)? { - let reverse = physical_opts == !required_opts; - (reverse || physical_opts == required_opts).then_some(reverse) - } else { - // If the column is not nullable, NULLS FIRST/LAST is not important. - Some(physical_opts.descending != required_opts.descending) +) -> Option<(bool, InputOrderMode)> { + let input_eqs = input.equivalence_properties(); + let mut partition_by_reqs: Vec = vec![]; + let (_, indices) = input_eqs.find_longest_permutation(partitionby_exprs); + partition_by_reqs.extend(indices.iter().map(|&idx| PhysicalSortRequirement { + expr: partitionby_exprs[idx].clone(), + options: None, + })); + // Treat partition by exprs as constant. During analysis of requirements are satisfied. + let partition_by_eqs = input_eqs.add_constants(partitionby_exprs.iter().cloned()); + let order_by_reqs = PhysicalSortRequirement::from_sort_exprs(orderby_keys); + let reverse_order_by_reqs = + PhysicalSortRequirement::from_sort_exprs(&reverse_order_bys(orderby_keys)); + for (should_swap, order_by_reqs) in + [(false, order_by_reqs), (true, reverse_order_by_reqs)] + { + let req = [partition_by_reqs.clone(), order_by_reqs].concat(); + let req = collapse_lex_req(req); + if partition_by_eqs.ordering_satisfy_requirement(&req) { + // Window can be run with existing ordering + let mode = if indices.len() == partitionby_exprs.len() { + InputOrderMode::Sorted + } else if indices.is_empty() { + InputOrderMode::Linear + } else { + InputOrderMode::PartiallySorted(indices) + }; + return Some((should_swap, mode)); } - } else { - None - }) + } + None } #[cfg(test)] @@ -605,7 +503,6 @@ mod tests { use crate::streaming::StreamingTableExec; use crate::test::assert_is_pending; use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; - use crate::windows::PartitionSearchMode::{Linear, PartiallySorted, Sorted}; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, SchemaRef}; @@ -613,6 +510,8 @@ mod tests { use futures::FutureExt; + use InputOrderMode::{Linear, PartiallySorted, Sorted}; + fn create_test_schema() -> Result { let nullable_column = Field::new("nullable_col", DataType::Int32, true); let non_nullable_column = Field::new("non_nullable_col", DataType::Int32, false); @@ -750,7 +649,7 @@ mod tests { let refs = blocking_exec.refs(); let window_agg_exec = Arc::new(WindowAggExec::try_new( vec![create_window_expr( - &WindowFunction::AggregateFunction(AggregateFunction::Count), + &WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), "count".to_owned(), &[col("a", &schema)?], &[], @@ -759,7 +658,6 @@ mod tests { schema.as_ref(), )?], blocking_exec, - schema, vec![], )?); @@ -774,15 +672,16 @@ mod tests { } #[tokio::test] - async fn test_is_column_aligned_nullable() -> Result<()> { + async fn test_satisfiy_nullable() -> Result<()> { let schema = create_test_schema()?; let params = vec![ - ((true, true), (false, false), Some(true)), - ((true, true), (false, true), None), - ((true, true), (true, false), None), - ((true, false), (false, true), Some(true)), - ((true, false), (false, false), None), - ((true, false), (true, true), None), + ((true, true), (false, false), false), + ((true, true), (false, true), false), + ((true, true), (true, false), false), + ((true, false), (false, true), false), + ((true, false), (false, false), false), + ((true, false), (true, true), false), + ((true, false), (true, false), true), ]; for ( (physical_desc, physical_nulls_first), @@ -804,7 +703,7 @@ mod tests { nulls_first: req_nulls_first, }, }; - let res = check_alignment(&schema, &physical_ordering, &required_ordering)?; + let res = physical_ordering.satisfy(&required_ordering.into(), &schema); assert_eq!(res, expected); } @@ -812,16 +711,17 @@ mod tests { } #[tokio::test] - async fn test_is_column_aligned_non_nullable() -> Result<()> { + async fn test_satisfy_non_nullable() -> Result<()> { let schema = create_test_schema()?; let params = vec![ - ((true, true), (false, false), Some(true)), - ((true, true), (false, true), Some(true)), - ((true, true), (true, false), Some(false)), - ((true, false), (false, true), Some(true)), - ((true, false), (false, false), Some(true)), - ((true, false), (true, true), Some(false)), + ((true, true), (false, false), false), + ((true, true), (false, true), false), + ((true, true), (true, false), true), + ((true, false), (false, true), false), + ((true, false), (false, false), false), + ((true, false), (true, true), true), + ((true, false), (true, false), true), ]; for ( (physical_desc, physical_nulls_first), @@ -843,7 +743,7 @@ mod tests { nulls_first: req_nulls_first, }, }; - let res = check_alignment(&schema, &physical_ordering, &required_ordering)?; + let res = physical_ordering.satisfy(&required_ordering.into(), &schema); assert_eq!(res, expected); } @@ -851,7 +751,7 @@ mod tests { } #[tokio::test] - async fn test_can_skip_ordering_exhaustive() -> Result<()> { + async fn test_get_window_mode_exhaustive() -> Result<()> { let test_schema = create_test_schema3()?; // Columns a,c are nullable whereas b,d are not nullable. // Source is sorted by a ASC NULLS FIRST, b ASC NULLS FIRST, c ASC NULLS FIRST, d ASC NULLS FIRST @@ -870,11 +770,11 @@ mod tests { // Second field in the tuple is Vec where each element in the vector represents ORDER BY columns // For instance, vec!["c"], corresponds to ORDER BY c ASC NULLS FIRST, (ordering is default ordering. We do not check // for reversibility in this test). - // Third field in the tuple is Option, which corresponds to expected algorithm mode. + // Third field in the tuple is Option, which corresponds to expected algorithm mode. // None represents that existing ordering is not sufficient to run executor with any one of the algorithms // (We need to add SortExec to be able to run it). - // Some(PartitionSearchMode) represents, we can run algorithm with existing ordering; and algorithm should work in - // PartitionSearchMode. + // Some(InputOrderMode) represents, we can run algorithm with existing ordering; and algorithm should work in + // InputOrderMode. let test_cases = vec![ (vec!["a"], vec!["a"], Some(Sorted)), (vec!["a"], vec!["b"], Some(Sorted)), @@ -884,7 +784,7 @@ mod tests { (vec!["a"], vec!["a", "c"], None), (vec!["a"], vec!["a", "b", "c"], Some(Sorted)), (vec!["b"], vec!["a"], Some(Linear)), - (vec!["b"], vec!["b"], None), + (vec!["b"], vec!["b"], Some(Linear)), (vec!["b"], vec!["c"], None), (vec!["b"], vec!["a", "b"], Some(Linear)), (vec!["b"], vec!["b", "c"], None), @@ -892,7 +792,7 @@ mod tests { (vec!["b"], vec!["a", "b", "c"], Some(Linear)), (vec!["c"], vec!["a"], Some(Linear)), (vec!["c"], vec!["b"], None), - (vec!["c"], vec!["c"], None), + (vec!["c"], vec!["c"], Some(Linear)), (vec!["c"], vec!["a", "b"], Some(Linear)), (vec!["c"], vec!["b", "c"], None), (vec!["c"], vec!["a", "c"], Some(Linear)), @@ -905,10 +805,10 @@ mod tests { (vec!["b", "a"], vec!["a", "c"], Some(Sorted)), (vec!["b", "a"], vec!["a", "b", "c"], Some(Sorted)), (vec!["c", "b"], vec!["a"], Some(Linear)), - (vec!["c", "b"], vec!["b"], None), - (vec!["c", "b"], vec!["c"], None), + (vec!["c", "b"], vec!["b"], Some(Linear)), + (vec!["c", "b"], vec!["c"], Some(Linear)), (vec!["c", "b"], vec!["a", "b"], Some(Linear)), - (vec!["c", "b"], vec!["b", "c"], None), + (vec!["c", "b"], vec!["b", "c"], Some(Linear)), (vec!["c", "b"], vec!["a", "c"], Some(Linear)), (vec!["c", "b"], vec!["a", "b", "c"], Some(Linear)), (vec!["c", "a"], vec!["a"], Some(PartiallySorted(vec![1]))), @@ -958,8 +858,8 @@ mod tests { order_by_exprs.push(PhysicalSortExpr { expr, options }); } let res = - can_skip_sort(&partition_by_exprs, &order_by_exprs, &exec_unbounded)?; - // Since reversibility is not important in this test. Convert Option<(bool, PartitionSearchMode)> to Option + get_window_mode(&partition_by_exprs, &order_by_exprs, &exec_unbounded); + // Since reversibility is not important in this test. Convert Option<(bool, InputOrderMode)> to Option let res = res.map(|(_, mode)| mode); assert_eq!( res, *expected, @@ -971,7 +871,7 @@ mod tests { } #[tokio::test] - async fn test_can_skip_ordering() -> Result<()> { + async fn test_get_window_mode() -> Result<()> { let test_schema = create_test_schema3()?; // Columns a,c are nullable whereas b,d are not nullable. // Source is sorted by a ASC NULLS FIRST, b ASC NULLS FIRST, c ASC NULLS FIRST, d ASC NULLS FIRST @@ -990,12 +890,12 @@ mod tests { // Second field in the tuple is Vec<(str, bool, bool)> where each element in the vector represents ORDER BY columns // For instance, vec![("c", false, false)], corresponds to ORDER BY c ASC NULLS LAST, // similarly, vec![("c", true, true)], corresponds to ORDER BY c DESC NULLS FIRST, - // Third field in the tuple is Option<(bool, PartitionSearchMode)>, which corresponds to expected result. + // Third field in the tuple is Option<(bool, InputOrderMode)>, which corresponds to expected result. // None represents that existing ordering is not sufficient to run executor with any one of the algorithms // (We need to add SortExec to be able to run it). - // Some((bool, PartitionSearchMode)) represents, we can run algorithm with existing ordering. Algorithm should work in - // PartitionSearchMode, bool field represents whether we should reverse window expressions to run executor with existing ordering. - // For instance, `Some((false, PartitionSearchMode::Sorted))`, represents that we shouldn't reverse window expressions. And algorithm + // Some((bool, InputOrderMode)) represents, we can run algorithm with existing ordering. Algorithm should work in + // InputOrderMode, bool field represents whether we should reverse window expressions to run executor with existing ordering. + // For instance, `Some((false, InputOrderMode::Sorted))`, represents that we shouldn't reverse window expressions. And algorithm // should work in Sorted mode to work with existing ordering. let test_cases = vec![ // PARTITION BY a, b ORDER BY c ASC NULLS LAST @@ -1122,7 +1022,7 @@ mod tests { } assert_eq!( - can_skip_sort(&partition_by_exprs, &order_by_exprs, &exec_unbounded)?, + get_window_mode(&partition_by_exprs, &order_by_exprs, &exec_unbounded), *expected, "Unexpected result for in unbounded test case#: {case_idx:?}, case: {test_case:?}" ); diff --git a/datafusion/physical-plan/src/windows/window_agg_exec.rs b/datafusion/physical-plan/src/windows/window_agg_exec.rs index b56a9c194c8f..6c245f65ba4f 100644 --- a/datafusion/physical-plan/src/windows/window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/window_agg_exec.rs @@ -26,12 +26,13 @@ use crate::common::transpose; use crate::expressions::PhysicalSortExpr; use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use crate::windows::{ - calc_requirements, get_ordered_partition_by_indices, window_ordering_equivalence, + calc_requirements, get_ordered_partition_by_indices, get_partition_by_sort_exprs, + window_equivalence_properties, }; use crate::{ - ColumnStatistics, DisplayAs, DisplayFormatType, Distribution, EquivalenceProperties, - ExecutionPlan, Partitioning, PhysicalExpr, RecordBatchStream, - SendableRecordBatchStream, Statistics, WindowExpr, + ColumnStatistics, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, + Partitioning, PhysicalExpr, RecordBatchStream, SendableRecordBatchStream, Statistics, + WindowExpr, }; use arrow::compute::{concat, concat_batches}; @@ -42,10 +43,11 @@ use arrow::{ datatypes::{Schema, SchemaRef}, record_batch::RecordBatch, }; -use datafusion_common::utils::{evaluate_partition_ranges, get_at_indices}; +use datafusion_common::stats::Precision; +use datafusion_common::utils::evaluate_partition_ranges; use datafusion_common::{internal_err, plan_err, DataFusionError, Result}; use datafusion_execution::TaskContext; -use datafusion_physical_expr::{OrderingEquivalenceProperties, PhysicalSortRequirement}; +use datafusion_physical_expr::{EquivalenceProperties, PhysicalSortRequirement}; use futures::stream::Stream; use futures::{ready, StreamExt}; @@ -59,8 +61,6 @@ pub struct WindowAggExec { window_expr: Vec>, /// Schema after the window is run schema: SchemaRef, - /// Schema before the window - input_schema: SchemaRef, /// Partition Keys pub partition_keys: Vec>, /// Execution metrics @@ -75,10 +75,9 @@ impl WindowAggExec { pub fn try_new( window_expr: Vec>, input: Arc, - input_schema: SchemaRef, partition_keys: Vec>, ) -> Result { - let schema = create_schema(&input_schema, &window_expr)?; + let schema = create_schema(&input.schema(), &window_expr)?; let schema = Arc::new(schema); let ordered_partition_by_indices = @@ -87,7 +86,6 @@ impl WindowAggExec { input, window_expr, schema, - input_schema, partition_keys, metrics: ExecutionPlanMetricsSet::new(), ordered_partition_by_indices, @@ -104,20 +102,18 @@ impl WindowAggExec { &self.input } - /// Get the input schema before any window functions are applied - pub fn input_schema(&self) -> SchemaRef { - self.input_schema.clone() - } - /// Return the output sort order of partition keys: For example /// OVER(PARTITION BY a, ORDER BY b) -> would give sorting of the column a // We are sure that partition by columns are always at the beginning of sort_keys // Hence returned `PhysicalSortExpr` corresponding to `PARTITION BY` columns can be used safely // to calculate partition separation points pub fn partition_by_sort_keys(&self) -> Result> { - // Partition by sort keys indices are stored in self.ordered_partition_by_indices. - let sort_keys = self.input.output_ordering().unwrap_or(&[]); - get_at_indices(sort_keys, &self.ordered_partition_by_indices) + let partition_by = self.window_expr()[0].partition_by(); + get_partition_by_sort_exprs( + &self.input, + partition_by, + &self.ordered_partition_by_indices, + ) } } @@ -214,13 +210,9 @@ impl ExecutionPlan for WindowAggExec { } } + /// Get the [`EquivalenceProperties`] within the plan fn equivalence_properties(&self) -> EquivalenceProperties { - self.input().equivalence_properties() - } - - /// Get the OrderingEquivalenceProperties within the plan - fn ordering_equivalence_properties(&self) -> OrderingEquivalenceProperties { - window_ordering_equivalence(&self.schema, &self.input, &self.window_expr) + window_equivalence_properties(&self.schema, &self.input, &self.window_expr) } fn with_new_children( @@ -230,7 +222,6 @@ impl ExecutionPlan for WindowAggExec { Ok(Arc::new(WindowAggExec::try_new( self.window_expr.clone(), children[0].clone(), - self.input_schema.clone(), self.partition_keys.clone(), )?)) } @@ -256,24 +247,22 @@ impl ExecutionPlan for WindowAggExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Statistics { - let input_stat = self.input.statistics(); + fn statistics(&self) -> Result { + let input_stat = self.input.statistics()?; let win_cols = self.window_expr.len(); - let input_cols = self.input_schema.fields().len(); + let input_cols = self.input.schema().fields().len(); // TODO stats: some windowing function will maintain invariants such as min, max... let mut column_statistics = Vec::with_capacity(win_cols + input_cols); - if let Some(input_col_stats) = input_stat.column_statistics { - column_statistics.extend(input_col_stats); - } else { - column_statistics.extend(vec![ColumnStatistics::default(); input_cols]); + // copy stats of the input to the beginning of the schema. + column_statistics.extend(input_stat.column_statistics); + for _ in 0..win_cols { + column_statistics.push(ColumnStatistics::new_unknown()) } - column_statistics.extend(vec![ColumnStatistics::default(); win_cols]); - Statistics { - is_exact: input_stat.is_exact, + Ok(Statistics { num_rows: input_stat.num_rows, - column_statistics: Some(column_statistics), - total_byte_size: None, - } + column_statistics, + total_byte_size: Precision::Absent, + }) } } diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index e27eff537ba6..f9f24b28db81 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -19,9 +19,9 @@ name = "datafusion-proto" description = "Protobuf serialization of DataFusion logical plan expressions" keywords = ["arrow", "query", "sql"] +readme = "README.md" version = { workspace = true } edition = { workspace = true } -readme = { workspace = true } homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } @@ -36,21 +36,22 @@ name = "datafusion_proto" path = "src/lib.rs" [features] -default = [] +default = ["parquet"] json = ["pbjson", "serde", "serde_json"] +parquet = ["datafusion/parquet", "datafusion-common/parquet"] [dependencies] arrow = { workspace = true } chrono = { workspace = true } -datafusion = { path = "../core", version = "31.0.0" } -datafusion-common = { path = "../common", version = "31.0.0", default-features = false } -datafusion-expr = { path = "../expr", version = "31.0.0" } -object_store = { version = "0.7.0" } -pbjson = { version = "0.5", optional = true } +datafusion = { path = "../core", version = "34.0.0" } +datafusion-common = { workspace = true } +datafusion-expr = { workspace = true } +object_store = { workspace = true } +pbjson = { version = "0.6.0", optional = true } prost = "0.12.0" serde = { version = "1.0", optional = true } -serde_json = { version = "1.0", optional = true } +serde_json = { workspace = true, optional = true } [dev-dependencies] -doc-comment = "0.3" +doc-comment = { workspace = true } tokio = "1.18" diff --git a/datafusion/proto/README.md b/datafusion/proto/README.md index fd66d54aa2de..171aadb744d6 100644 --- a/datafusion/proto/README.md +++ b/datafusion/proto/README.md @@ -19,7 +19,7 @@ # DataFusion Proto -[DataFusion](df) is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. This crate is a submodule of DataFusion that provides a protocol buffer format for representing query plans and expressions. diff --git a/datafusion/proto/gen/Cargo.toml b/datafusion/proto/gen/Cargo.toml index 37c49666d3d7..8b3f3f98a8a1 100644 --- a/datafusion/proto/gen/Cargo.toml +++ b/datafusion/proto/gen/Cargo.toml @@ -32,4 +32,4 @@ publish = false [dependencies] # Pin these dependencies so that the generated output is deterministic pbjson-build = "=0.6.2" -prost-build = "=0.12.1" +prost-build = "=0.12.3" diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 0ebcf2537dda..d5f8397aa30c 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -73,6 +73,8 @@ message LogicalPlanNode { CustomTableScanNode custom_scan = 25; PrepareNode prepare = 26; DropViewNode drop_view = 27; + DistinctOnNode distinct_on = 28; + CopyToNode copy_to = 29; } } @@ -180,6 +182,25 @@ message EmptyRelationNode { bool produce_one_row = 1; } +message PrimaryKeyConstraint{ + repeated uint64 indices = 1; +} + +message UniqueConstraint{ + repeated uint64 indices = 1; +} + +message Constraint{ + oneof constraint_mode{ + PrimaryKeyConstraint primary_key = 1; + UniqueConstraint unique = 2; + } +} + +message Constraints{ + repeated Constraint constraints = 1; +} + message CreateExternalTableNode { reserved 1; // was string name OwnedTableReference name = 12; @@ -195,6 +216,8 @@ message CreateExternalTableNode { repeated LogicalExprNodeCollection order_exprs = 13; bool unbounded = 14; map options = 11; + Constraints constraints = 15; + map column_defaults = 16; } message PrepareNode { @@ -288,6 +311,33 @@ message DistinctNode { LogicalPlanNode input = 1; } +message DistinctOnNode { + repeated LogicalExprNode on_expr = 1; + repeated LogicalExprNode select_expr = 2; + repeated LogicalExprNode sort_expr = 3; + LogicalPlanNode input = 4; +} + +message CopyToNode { + LogicalPlanNode input = 1; + string output_url = 2; + bool single_file_output = 3; + oneof CopyOptions { + SQLOptions sql_options = 4; + FileTypeWriterOptions writer_options = 5; + } + string file_type = 6; +} + +message SQLOptions { + repeated SQLOption option = 1; +} + +message SQLOption { + string key = 1; + string value = 2; +} + message UnionNode { repeated LogicalPlanNode inputs = 1; } @@ -343,7 +393,7 @@ message LogicalExprNode { SortExprNode sort = 12; NegativeNode negative = 13; InListNode in_list = 14; - bool wildcard = 15; + Wildcard wildcard = 15; ScalarFunctionNode scalar_function = 16; TryCastNode try_cast = 17; @@ -379,6 +429,10 @@ message LogicalExprNode { } } +message Wildcard { + string qualifier = 1; +} + message PlaceholderNode { string id = 1; ArrowType data_type = 2; @@ -461,6 +515,7 @@ message Not { message AliasNode { LogicalExprNode expr = 1; string alias = 2; + repeated OwnedTableReference relation = 3; } message BinaryExprNode { @@ -600,6 +655,18 @@ enum ScalarFunction { ArrayEmpty = 115; ArrayPopBack = 116; StringToArray = 117; + ToTimestampNanos = 118; + ArrayIntersect = 119; + ArrayUnion = 120; + OverLay = 121; + Range = 122; + ArrayExcept = 123; + ArrayPopFront = 124; + Levenshtein = 125; + SubstrIndex = 126; + FindInSet = 127; + ArraySort = 128; + ArrayDistinct = 129; } message ScalarFunctionNode { @@ -645,6 +712,7 @@ enum AggregateFunction { REGR_SXX = 32; REGR_SYY = 33; REGR_SXY = 34; + STRING_AGG = 35; } message AggregateExprNode { @@ -794,6 +862,8 @@ message Field { // for complex data types like structs, unions repeated Field children = 4; map metadata = 5; + int64 dict_id = 6; + bool dict_ordered = 7; } message FixedSizeBinary{ @@ -863,12 +933,10 @@ message Union{ repeated int32 type_ids = 3; } -message ScalarListValue{ - // encode null explicitly to distinguish a list with a null value - // from a list with no values) - bool is_null = 3; - Field field = 1; - repeated ScalarValue values = 2; +message ScalarListValue { + bytes ipc_message = 1; + bytes arrow_data = 2; + Schema schema = 3; } message ScalarTime32Value { @@ -944,8 +1012,9 @@ message ScalarValue{ // Literal Date32 value always has a unit of day int32 date_32_value = 14; ScalarTime32Value time32_value = 15; + ScalarListValue large_list_value = 16; ScalarListValue list_value = 17; - //WAS: ScalarType null_list_value = 18; + ScalarListValue fixed_size_list_value = 18; Decimal128 decimal128_value = 20; Decimal256 decimal256_value = 39; @@ -1052,8 +1121,10 @@ message PlanType { OptimizedLogicalPlanType OptimizedLogicalPlan = 2; EmptyMessage FinalLogicalPlan = 3; EmptyMessage InitialPhysicalPlan = 4; + EmptyMessage InitialPhysicalPlanWithStats = 9; OptimizedPhysicalPlanType OptimizedPhysicalPlan = 5; EmptyMessage FinalPhysicalPlan = 6; + EmptyMessage FinalPhysicalPlanWithStats = 10; } } @@ -1112,9 +1183,119 @@ message PhysicalPlanNode { SortPreservingMergeExecNode sort_preserving_merge = 21; NestedLoopJoinExecNode nested_loop_join = 22; AnalyzeExecNode analyze = 23; + JsonSinkExecNode json_sink = 24; + SymmetricHashJoinExecNode symmetric_hash_join = 25; + InterleaveExecNode interleave = 26; + PlaceholderRowExecNode placeholder_row = 27; + CsvSinkExecNode csv_sink = 28; + ParquetSinkExecNode parquet_sink = 29; + } +} + +enum CompressionTypeVariant { + GZIP = 0; + BZIP2 = 1; + XZ = 2; + ZSTD = 3; + UNCOMPRESSED = 4; +} + +message PartitionColumn { + string name = 1; + ArrowType arrow_type = 2; +} + +message FileTypeWriterOptions { + oneof FileType { + JsonWriterOptions json_options = 1; + ParquetWriterOptions parquet_options = 2; + CsvWriterOptions csv_options = 3; } } +message JsonWriterOptions { + CompressionTypeVariant compression = 1; +} + +message ParquetWriterOptions { + WriterProperties writer_properties = 1; +} + +message CsvWriterOptions { + // Compression type + CompressionTypeVariant compression = 1; + // Optional column delimiter. Defaults to `b','` + string delimiter = 2; + // Whether to write column names as file headers. Defaults to `true` + bool has_header = 3; + // Optional date format for date arrays + string date_format = 4; + // Optional datetime format for datetime arrays + string datetime_format = 5; + // Optional timestamp format for timestamp arrays + string timestamp_format = 6; + // Optional time format for time arrays + string time_format = 7; + // Optional value to represent null + string null_value = 8; +} + +message WriterProperties { + uint64 data_page_size_limit = 1; + uint64 dictionary_page_size_limit = 2; + uint64 data_page_row_count_limit = 3; + uint64 write_batch_size = 4; + uint64 max_row_group_size = 5; + string writer_version = 6; + string created_by = 7; +} + +message FileSinkConfig { + reserved 6; // writer_mode + + string object_store_url = 1; + repeated PartitionedFile file_groups = 2; + repeated string table_paths = 3; + Schema output_schema = 4; + repeated PartitionColumn table_partition_cols = 5; + bool single_file_output = 7; + bool overwrite = 8; + FileTypeWriterOptions file_type_writer_options = 9; +} + +message JsonSink { + FileSinkConfig config = 1; +} + +message JsonSinkExecNode { + PhysicalPlanNode input = 1; + JsonSink sink = 2; + Schema sink_schema = 3; + PhysicalSortExprNodeCollection sort_order = 4; +} + +message CsvSink { + FileSinkConfig config = 1; +} + +message CsvSinkExecNode { + PhysicalPlanNode input = 1; + CsvSink sink = 2; + Schema sink_schema = 3; + PhysicalSortExprNodeCollection sort_order = 4; +} + +message ParquetSink { + FileSinkConfig config = 1; +} + +message ParquetSinkExecNode { + PhysicalPlanNode input = 1; + ParquetSink sink = 2; + Schema sink_schema = 3; + PhysicalSortExprNodeCollection sort_order = 4; +} + message PhysicalExtensionNode { bytes node = 1; repeated PhysicalPlanNode inputs = 2; @@ -1273,6 +1454,7 @@ message PhysicalNegativeNode { message FilterExecNode { PhysicalPlanNode input = 1; PhysicalExprNode expr = 2; + uint32 default_filter_selectivity = 3; } message FileGroup { @@ -1341,6 +1523,25 @@ message HashJoinExecNode { JoinFilter filter = 8; } +enum StreamPartitionMode { + SINGLE_PARTITION = 0; + PARTITIONED_EXEC = 1; +} + +message SymmetricHashJoinExecNode { + PhysicalPlanNode left = 1; + PhysicalPlanNode right = 2; + repeated JoinOn on = 3; + JoinType join_type = 4; + StreamPartitionMode partition_mode = 6; + bool null_equals_null = 7; + JoinFilter filter = 8; +} + +message InterleaveExecNode { + repeated PhysicalPlanNode inputs = 1; +} + message UnionExecNode { repeated PhysicalPlanNode inputs = 1; } @@ -1374,8 +1575,11 @@ message JoinOn { } message EmptyExecNode { - bool produce_one_row = 1; - Schema schema = 2; + Schema schema = 1; +} + +message PlaceholderRowExecNode { + Schema schema = 1; } message ProjectionExecNode { @@ -1392,19 +1596,18 @@ enum AggregateMode { SINGLE_PARTITIONED = 4; } -message PartiallySortedPartitionSearchMode { +message PartiallySortedInputOrderMode { repeated uint64 columns = 6; } message WindowAggExecNode { PhysicalPlanNode input = 1; repeated PhysicalWindowExprNode window_expr = 2; - Schema input_schema = 4; repeated PhysicalExprNode partition_keys = 5; // Set optional to `None` for `BoundedWindowAggExec`. - oneof partition_search_mode { + oneof input_order_mode { EmptyMessage linear = 7; - PartiallySortedPartitionSearchMode partially_sorted = 8; + PartiallySortedInputOrderMode partially_sorted = 8; EmptyMessage sorted = 9; } } @@ -1429,7 +1632,6 @@ message AggregateExecNode { repeated PhysicalExprNode null_expr = 8; repeated bool groups = 9; repeated MaybeFilter filter_expr = 10; - repeated MaybePhysicalSortExprs order_by_expr = 11; } message GlobalLimitExecNode { @@ -1526,18 +1728,28 @@ message PartitionStats { repeated ColumnStats column_stats = 4; } +message Precision{ + PrecisionInfo precision_info = 1; + ScalarValue val = 2; +} + +enum PrecisionInfo { + EXACT = 0; + INEXACT = 1; + ABSENT = 2; +} + message Statistics { - int64 num_rows = 1; - int64 total_byte_size = 2; + Precision num_rows = 1; + Precision total_byte_size = 2; repeated ColumnStats column_stats = 3; - bool is_exact = 4; } message ColumnStats { - ScalarValue min_value = 1; - ScalarValue max_value = 2; - uint32 null_count = 3; - uint32 distinct_count = 4; + Precision min_value = 1; + Precision max_value = 2; + Precision null_count = 3; + Precision distinct_count = 4; } message NamedStructFieldExpr { diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index d3ac33bab535..9377501499e2 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -24,7 +24,7 @@ use crate::physical_plan::{ }; use crate::protobuf; use datafusion::physical_plan::functions::make_scalar_function; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{plan_datafusion_err, DataFusionError, Result}; use datafusion_expr::{ create_udaf, create_udf, create_udwf, AggregateUDF, Expr, LogicalPlan, Volatility, WindowUDF, @@ -88,13 +88,13 @@ pub trait Serializeable: Sized { impl Serializeable for Expr { fn to_bytes(&self) -> Result { let mut buffer = BytesMut::new(); - let protobuf: protobuf::LogicalExprNode = self.try_into().map_err(|e| { - DataFusionError::Plan(format!("Error encoding expr as protobuf: {e}")) - })?; + let protobuf: protobuf::LogicalExprNode = self + .try_into() + .map_err(|e| plan_datafusion_err!("Error encoding expr as protobuf: {e}"))?; - protobuf.encode(&mut buffer).map_err(|e| { - DataFusionError::Plan(format!("Error encoding protobuf as bytes: {e}")) - })?; + protobuf + .encode(&mut buffer) + .map_err(|e| plan_datafusion_err!("Error encoding protobuf as bytes: {e}"))?; let bytes: Bytes = buffer.into(); @@ -151,13 +151,11 @@ impl Serializeable for Expr { bytes: &[u8], registry: &dyn FunctionRegistry, ) -> Result { - let protobuf = protobuf::LogicalExprNode::decode(bytes).map_err(|e| { - DataFusionError::Plan(format!("Error decoding expr as protobuf: {e}")) - })?; + let protobuf = protobuf::LogicalExprNode::decode(bytes) + .map_err(|e| plan_datafusion_err!("Error decoding expr as protobuf: {e}"))?; - logical_plan::from_proto::parse_expr(&protobuf, registry).map_err(|e| { - DataFusionError::Plan(format!("Error parsing protobuf into Expr: {e}")) - }) + logical_plan::from_proto::parse_expr(&protobuf, registry) + .map_err(|e| plan_datafusion_err!("Error parsing protobuf into Expr: {e}")) } } @@ -173,9 +171,9 @@ pub fn logical_plan_to_json(plan: &LogicalPlan) -> Result { let extension_codec = DefaultLogicalExtensionCodec {}; let protobuf = protobuf::LogicalPlanNode::try_from_logical_plan(plan, &extension_codec) - .map_err(|e| DataFusionError::Plan(format!("Error serializing plan: {e}")))?; + .map_err(|e| plan_datafusion_err!("Error serializing plan: {e}"))?; serde_json::to_string(&protobuf) - .map_err(|e| DataFusionError::Plan(format!("Error serializing plan: {e}"))) + .map_err(|e| plan_datafusion_err!("Error serializing plan: {e}")) } /// Serialize a LogicalPlan as bytes, using the provided extension codec @@ -186,9 +184,9 @@ pub fn logical_plan_to_bytes_with_extension_codec( let protobuf = protobuf::LogicalPlanNode::try_from_logical_plan(plan, extension_codec)?; let mut buffer = BytesMut::new(); - protobuf.encode(&mut buffer).map_err(|e| { - DataFusionError::Plan(format!("Error encoding protobuf as bytes: {e}")) - })?; + protobuf + .encode(&mut buffer) + .map_err(|e| plan_datafusion_err!("Error encoding protobuf as bytes: {e}"))?; Ok(buffer.into()) } @@ -196,7 +194,7 @@ pub fn logical_plan_to_bytes_with_extension_codec( #[cfg(feature = "json")] pub fn logical_plan_from_json(json: &str, ctx: &SessionContext) -> Result { let back: protobuf::LogicalPlanNode = serde_json::from_str(json) - .map_err(|e| DataFusionError::Plan(format!("Error serializing plan: {e}")))?; + .map_err(|e| plan_datafusion_err!("Error serializing plan: {e}"))?; let extension_codec = DefaultLogicalExtensionCodec {}; back.try_into_logical_plan(ctx, &extension_codec) } @@ -216,9 +214,8 @@ pub fn logical_plan_from_bytes_with_extension_codec( ctx: &SessionContext, extension_codec: &dyn LogicalExtensionCodec, ) -> Result { - let protobuf = protobuf::LogicalPlanNode::decode(bytes).map_err(|e| { - DataFusionError::Plan(format!("Error decoding expr as protobuf: {e}")) - })?; + let protobuf = protobuf::LogicalPlanNode::decode(bytes) + .map_err(|e| plan_datafusion_err!("Error decoding expr as protobuf: {e}"))?; protobuf.try_into_logical_plan(ctx, extension_codec) } @@ -234,9 +231,9 @@ pub fn physical_plan_to_json(plan: Arc) -> Result { let extension_codec = DefaultPhysicalExtensionCodec {}; let protobuf = protobuf::PhysicalPlanNode::try_from_physical_plan(plan, &extension_codec) - .map_err(|e| DataFusionError::Plan(format!("Error serializing plan: {e}")))?; + .map_err(|e| plan_datafusion_err!("Error serializing plan: {e}"))?; serde_json::to_string(&protobuf) - .map_err(|e| DataFusionError::Plan(format!("Error serializing plan: {e}"))) + .map_err(|e| plan_datafusion_err!("Error serializing plan: {e}")) } /// Serialize a PhysicalPlan as bytes, using the provided extension codec @@ -247,9 +244,9 @@ pub fn physical_plan_to_bytes_with_extension_codec( let protobuf = protobuf::PhysicalPlanNode::try_from_physical_plan(plan, extension_codec)?; let mut buffer = BytesMut::new(); - protobuf.encode(&mut buffer).map_err(|e| { - DataFusionError::Plan(format!("Error encoding protobuf as bytes: {e}")) - })?; + protobuf + .encode(&mut buffer) + .map_err(|e| plan_datafusion_err!("Error encoding protobuf as bytes: {e}"))?; Ok(buffer.into()) } @@ -260,7 +257,7 @@ pub fn physical_plan_from_json( ctx: &SessionContext, ) -> Result> { let back: protobuf::PhysicalPlanNode = serde_json::from_str(json) - .map_err(|e| DataFusionError::Plan(format!("Error serializing plan: {e}")))?; + .map_err(|e| plan_datafusion_err!("Error serializing plan: {e}"))?; let extension_codec = DefaultPhysicalExtensionCodec {}; back.try_into_physical_plan(ctx, &ctx.runtime_env(), &extension_codec) } @@ -280,8 +277,7 @@ pub fn physical_plan_from_bytes_with_extension_codec( ctx: &SessionContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { - let protobuf = protobuf::PhysicalPlanNode::decode(bytes).map_err(|e| { - DataFusionError::Plan(format!("Error decoding expr as protobuf: {e}")) - })?; + let protobuf = protobuf::PhysicalPlanNode::decode(bytes) + .map_err(|e| plan_datafusion_err!("Error decoding expr as protobuf: {e}"))?; protobuf.try_into_physical_plan(ctx, &ctx.runtime_env(), extension_codec) } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index d1e9e886e7d5..12e834d75adf 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -36,9 +36,6 @@ impl serde::Serialize for AggregateExecNode { if !self.filter_expr.is_empty() { len += 1; } - if !self.order_by_expr.is_empty() { - len += 1; - } let mut struct_ser = serializer.serialize_struct("datafusion.AggregateExecNode", len)?; if !self.group_expr.is_empty() { struct_ser.serialize_field("groupExpr", &self.group_expr)?; @@ -72,9 +69,6 @@ impl serde::Serialize for AggregateExecNode { if !self.filter_expr.is_empty() { struct_ser.serialize_field("filterExpr", &self.filter_expr)?; } - if !self.order_by_expr.is_empty() { - struct_ser.serialize_field("orderByExpr", &self.order_by_expr)?; - } struct_ser.end() } } @@ -102,8 +96,6 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { "groups", "filter_expr", "filterExpr", - "order_by_expr", - "orderByExpr", ]; #[allow(clippy::enum_variant_names)] @@ -118,7 +110,6 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { NullExpr, Groups, FilterExpr, - OrderByExpr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -150,7 +141,6 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { "nullExpr" | "null_expr" => Ok(GeneratedField::NullExpr), "groups" => Ok(GeneratedField::Groups), "filterExpr" | "filter_expr" => Ok(GeneratedField::FilterExpr), - "orderByExpr" | "order_by_expr" => Ok(GeneratedField::OrderByExpr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -180,7 +170,6 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { let mut null_expr__ = None; let mut groups__ = None; let mut filter_expr__ = None; - let mut order_by_expr__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::GroupExpr => { @@ -243,12 +232,6 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { } filter_expr__ = Some(map_.next_value()?); } - GeneratedField::OrderByExpr => { - if order_by_expr__.is_some() { - return Err(serde::de::Error::duplicate_field("orderByExpr")); - } - order_by_expr__ = Some(map_.next_value()?); - } } } Ok(AggregateExecNode { @@ -262,7 +245,6 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { null_expr: null_expr__.unwrap_or_default(), groups: groups__.unwrap_or_default(), filter_expr: filter_expr__.unwrap_or_default(), - order_by_expr: order_by_expr__.unwrap_or_default(), }) } } @@ -474,6 +456,7 @@ impl serde::Serialize for AggregateFunction { Self::RegrSxx => "REGR_SXX", Self::RegrSyy => "REGR_SYY", Self::RegrSxy => "REGR_SXY", + Self::StringAgg => "STRING_AGG", }; serializer.serialize_str(variant) } @@ -520,6 +503,7 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "REGR_SXX", "REGR_SYY", "REGR_SXY", + "STRING_AGG", ]; struct GeneratedVisitor; @@ -595,6 +579,7 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "REGR_SXX" => Ok(AggregateFunction::RegrSxx), "REGR_SYY" => Ok(AggregateFunction::RegrSyy), "REGR_SXY" => Ok(AggregateFunction::RegrSxy), + "STRING_AGG" => Ok(AggregateFunction::StringAgg), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } @@ -967,6 +952,9 @@ impl serde::Serialize for AliasNode { if !self.alias.is_empty() { len += 1; } + if !self.relation.is_empty() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.AliasNode", len)?; if let Some(v) = self.expr.as_ref() { struct_ser.serialize_field("expr", v)?; @@ -974,6 +962,9 @@ impl serde::Serialize for AliasNode { if !self.alias.is_empty() { struct_ser.serialize_field("alias", &self.alias)?; } + if !self.relation.is_empty() { + struct_ser.serialize_field("relation", &self.relation)?; + } struct_ser.end() } } @@ -986,12 +977,14 @@ impl<'de> serde::Deserialize<'de> for AliasNode { const FIELDS: &[&str] = &[ "expr", "alias", + "relation", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Expr, Alias, + Relation, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -1015,6 +1008,7 @@ impl<'de> serde::Deserialize<'de> for AliasNode { match value { "expr" => Ok(GeneratedField::Expr), "alias" => Ok(GeneratedField::Alias), + "relation" => Ok(GeneratedField::Relation), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -1036,6 +1030,7 @@ impl<'de> serde::Deserialize<'de> for AliasNode { { let mut expr__ = None; let mut alias__ = None; + let mut relation__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { @@ -1050,11 +1045,18 @@ impl<'de> serde::Deserialize<'de> for AliasNode { } alias__ = Some(map_.next_value()?); } + GeneratedField::Relation => { + if relation__.is_some() { + return Err(serde::de::Error::duplicate_field("relation")); + } + relation__ = Some(map_.next_value()?); + } } } Ok(AliasNode { expr: expr__, alias: alias__.unwrap_or_default(), + relation: relation__.unwrap_or_default(), }) } } @@ -3289,10 +3291,10 @@ impl serde::Serialize for ColumnStats { if self.max_value.is_some() { len += 1; } - if self.null_count != 0 { + if self.null_count.is_some() { len += 1; } - if self.distinct_count != 0 { + if self.distinct_count.is_some() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.ColumnStats", len)?; @@ -3302,11 +3304,11 @@ impl serde::Serialize for ColumnStats { if let Some(v) = self.max_value.as_ref() { struct_ser.serialize_field("maxValue", v)?; } - if self.null_count != 0 { - struct_ser.serialize_field("nullCount", &self.null_count)?; + if let Some(v) = self.null_count.as_ref() { + struct_ser.serialize_field("nullCount", v)?; } - if self.distinct_count != 0 { - struct_ser.serialize_field("distinctCount", &self.distinct_count)?; + if let Some(v) = self.distinct_count.as_ref() { + struct_ser.serialize_field("distinctCount", v)?; } struct_ser.end() } @@ -3400,32 +3402,108 @@ impl<'de> serde::Deserialize<'de> for ColumnStats { if null_count__.is_some() { return Err(serde::de::Error::duplicate_field("nullCount")); } - null_count__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + null_count__ = map_.next_value()?; } GeneratedField::DistinctCount => { if distinct_count__.is_some() { return Err(serde::de::Error::duplicate_field("distinctCount")); } - distinct_count__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + distinct_count__ = map_.next_value()?; } } } Ok(ColumnStats { min_value: min_value__, max_value: max_value__, - null_count: null_count__.unwrap_or_default(), - distinct_count: distinct_count__.unwrap_or_default(), + null_count: null_count__, + distinct_count: distinct_count__, }) } } deserializer.deserialize_struct("datafusion.ColumnStats", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for CreateCatalogNode { +impl serde::Serialize for CompressionTypeVariant { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::Gzip => "GZIP", + Self::Bzip2 => "BZIP2", + Self::Xz => "XZ", + Self::Zstd => "ZSTD", + Self::Uncompressed => "UNCOMPRESSED", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for CompressionTypeVariant { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "GZIP", + "BZIP2", + "XZ", + "ZSTD", + "UNCOMPRESSED", + ]; + + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = CompressionTypeVariant; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "GZIP" => Ok(CompressionTypeVariant::Gzip), + "BZIP2" => Ok(CompressionTypeVariant::Bzip2), + "XZ" => Ok(CompressionTypeVariant::Xz), + "ZSTD" => Ok(CompressionTypeVariant::Zstd), + "UNCOMPRESSED" => Ok(CompressionTypeVariant::Uncompressed), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} +impl serde::Serialize for Constraint { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -3433,47 +3511,39 @@ impl serde::Serialize for CreateCatalogNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.catalog_name.is_empty() { - len += 1; - } - if self.if_not_exists { - len += 1; - } - if self.schema.is_some() { + if self.constraint_mode.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.CreateCatalogNode", len)?; - if !self.catalog_name.is_empty() { - struct_ser.serialize_field("catalogName", &self.catalog_name)?; - } - if self.if_not_exists { - struct_ser.serialize_field("ifNotExists", &self.if_not_exists)?; - } - if let Some(v) = self.schema.as_ref() { - struct_ser.serialize_field("schema", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.Constraint", len)?; + if let Some(v) = self.constraint_mode.as_ref() { + match v { + constraint::ConstraintMode::PrimaryKey(v) => { + struct_ser.serialize_field("primaryKey", v)?; + } + constraint::ConstraintMode::Unique(v) => { + struct_ser.serialize_field("unique", v)?; + } + } } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for CreateCatalogNode { +impl<'de> serde::Deserialize<'de> for Constraint { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "catalog_name", - "catalogName", - "if_not_exists", - "ifNotExists", - "schema", + "primary_key", + "primaryKey", + "unique", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - CatalogName, - IfNotExists, - Schema, + PrimaryKey, + Unique, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -3495,9 +3565,8 @@ impl<'de> serde::Deserialize<'de> for CreateCatalogNode { E: serde::de::Error, { match value { - "catalogName" | "catalog_name" => Ok(GeneratedField::CatalogName), - "ifNotExists" | "if_not_exists" => Ok(GeneratedField::IfNotExists), - "schema" => Ok(GeneratedField::Schema), + "primaryKey" | "primary_key" => Ok(GeneratedField::PrimaryKey), + "unique" => Ok(GeneratedField::Unique), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -3507,52 +3576,44 @@ impl<'de> serde::Deserialize<'de> for CreateCatalogNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = CreateCatalogNode; + type Value = Constraint; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.CreateCatalogNode") + formatter.write_str("struct datafusion.Constraint") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut catalog_name__ = None; - let mut if_not_exists__ = None; - let mut schema__ = None; + let mut constraint_mode__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::CatalogName => { - if catalog_name__.is_some() { - return Err(serde::de::Error::duplicate_field("catalogName")); - } - catalog_name__ = Some(map_.next_value()?); - } - GeneratedField::IfNotExists => { - if if_not_exists__.is_some() { - return Err(serde::de::Error::duplicate_field("ifNotExists")); + GeneratedField::PrimaryKey => { + if constraint_mode__.is_some() { + return Err(serde::de::Error::duplicate_field("primaryKey")); } - if_not_exists__ = Some(map_.next_value()?); + constraint_mode__ = map_.next_value::<::std::option::Option<_>>()?.map(constraint::ConstraintMode::PrimaryKey) +; } - GeneratedField::Schema => { - if schema__.is_some() { - return Err(serde::de::Error::duplicate_field("schema")); + GeneratedField::Unique => { + if constraint_mode__.is_some() { + return Err(serde::de::Error::duplicate_field("unique")); } - schema__ = map_.next_value()?; + constraint_mode__ = map_.next_value::<::std::option::Option<_>>()?.map(constraint::ConstraintMode::Unique) +; } } } - Ok(CreateCatalogNode { - catalog_name: catalog_name__.unwrap_or_default(), - if_not_exists: if_not_exists__.unwrap_or_default(), - schema: schema__, + Ok(Constraint { + constraint_mode: constraint_mode__, }) } } - deserializer.deserialize_struct("datafusion.CreateCatalogNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.Constraint", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for CreateCatalogSchemaNode { +impl serde::Serialize for Constraints { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -3560,47 +3621,29 @@ impl serde::Serialize for CreateCatalogSchemaNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.schema_name.is_empty() { - len += 1; - } - if self.if_not_exists { - len += 1; - } - if self.schema.is_some() { + if !self.constraints.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.CreateCatalogSchemaNode", len)?; - if !self.schema_name.is_empty() { - struct_ser.serialize_field("schemaName", &self.schema_name)?; - } - if self.if_not_exists { - struct_ser.serialize_field("ifNotExists", &self.if_not_exists)?; - } - if let Some(v) = self.schema.as_ref() { - struct_ser.serialize_field("schema", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.Constraints", len)?; + if !self.constraints.is_empty() { + struct_ser.serialize_field("constraints", &self.constraints)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for CreateCatalogSchemaNode { +impl<'de> serde::Deserialize<'de> for Constraints { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "schema_name", - "schemaName", - "if_not_exists", - "ifNotExists", - "schema", + "constraints", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - SchemaName, - IfNotExists, - Schema, + Constraints, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -3622,9 +3665,7 @@ impl<'de> serde::Deserialize<'de> for CreateCatalogSchemaNode { E: serde::de::Error, { match value { - "schemaName" | "schema_name" => Ok(GeneratedField::SchemaName), - "ifNotExists" | "if_not_exists" => Ok(GeneratedField::IfNotExists), - "schema" => Ok(GeneratedField::Schema), + "constraints" => Ok(GeneratedField::Constraints), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -3634,52 +3675,36 @@ impl<'de> serde::Deserialize<'de> for CreateCatalogSchemaNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = CreateCatalogSchemaNode; + type Value = Constraints; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.CreateCatalogSchemaNode") + formatter.write_str("struct datafusion.Constraints") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut schema_name__ = None; - let mut if_not_exists__ = None; - let mut schema__ = None; + let mut constraints__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::SchemaName => { - if schema_name__.is_some() { - return Err(serde::de::Error::duplicate_field("schemaName")); - } - schema_name__ = Some(map_.next_value()?); - } - GeneratedField::IfNotExists => { - if if_not_exists__.is_some() { - return Err(serde::de::Error::duplicate_field("ifNotExists")); + GeneratedField::Constraints => { + if constraints__.is_some() { + return Err(serde::de::Error::duplicate_field("constraints")); } - if_not_exists__ = Some(map_.next_value()?); - } - GeneratedField::Schema => { - if schema__.is_some() { - return Err(serde::de::Error::duplicate_field("schema")); - } - schema__ = map_.next_value()?; + constraints__ = Some(map_.next_value()?); } } } - Ok(CreateCatalogSchemaNode { - schema_name: schema_name__.unwrap_or_default(), - if_not_exists: if_not_exists__.unwrap_or_default(), - schema: schema__, + Ok(Constraints { + constraints: constraints__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.CreateCatalogSchemaNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.Constraints", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for CreateExternalTableNode { +impl serde::Serialize for CopyToNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -3687,131 +3712,75 @@ impl serde::Serialize for CreateExternalTableNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.name.is_some() { - len += 1; - } - if !self.location.is_empty() { - len += 1; - } - if !self.file_type.is_empty() { - len += 1; - } - if self.has_header { - len += 1; - } - if self.schema.is_some() { - len += 1; - } - if !self.table_partition_cols.is_empty() { - len += 1; - } - if self.if_not_exists { - len += 1; - } - if !self.delimiter.is_empty() { + if self.input.is_some() { len += 1; } - if !self.definition.is_empty() { + if !self.output_url.is_empty() { len += 1; } - if !self.file_compression_type.is_empty() { + if self.single_file_output { len += 1; } - if !self.order_exprs.is_empty() { + if !self.file_type.is_empty() { len += 1; } - if self.unbounded { + if self.copy_options.is_some() { len += 1; } - if !self.options.is_empty() { - len += 1; + let mut struct_ser = serializer.serialize_struct("datafusion.CopyToNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; } - let mut struct_ser = serializer.serialize_struct("datafusion.CreateExternalTableNode", len)?; - if let Some(v) = self.name.as_ref() { - struct_ser.serialize_field("name", v)?; + if !self.output_url.is_empty() { + struct_ser.serialize_field("outputUrl", &self.output_url)?; } - if !self.location.is_empty() { - struct_ser.serialize_field("location", &self.location)?; + if self.single_file_output { + struct_ser.serialize_field("singleFileOutput", &self.single_file_output)?; } if !self.file_type.is_empty() { struct_ser.serialize_field("fileType", &self.file_type)?; } - if self.has_header { - struct_ser.serialize_field("hasHeader", &self.has_header)?; - } - if let Some(v) = self.schema.as_ref() { - struct_ser.serialize_field("schema", v)?; - } - if !self.table_partition_cols.is_empty() { - struct_ser.serialize_field("tablePartitionCols", &self.table_partition_cols)?; - } - if self.if_not_exists { - struct_ser.serialize_field("ifNotExists", &self.if_not_exists)?; - } - if !self.delimiter.is_empty() { - struct_ser.serialize_field("delimiter", &self.delimiter)?; - } - if !self.definition.is_empty() { - struct_ser.serialize_field("definition", &self.definition)?; - } - if !self.file_compression_type.is_empty() { - struct_ser.serialize_field("fileCompressionType", &self.file_compression_type)?; - } - if !self.order_exprs.is_empty() { - struct_ser.serialize_field("orderExprs", &self.order_exprs)?; - } - if self.unbounded { - struct_ser.serialize_field("unbounded", &self.unbounded)?; - } - if !self.options.is_empty() { - struct_ser.serialize_field("options", &self.options)?; + if let Some(v) = self.copy_options.as_ref() { + match v { + copy_to_node::CopyOptions::SqlOptions(v) => { + struct_ser.serialize_field("sqlOptions", v)?; + } + copy_to_node::CopyOptions::WriterOptions(v) => { + struct_ser.serialize_field("writerOptions", v)?; + } + } } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { +impl<'de> serde::Deserialize<'de> for CopyToNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "name", - "location", + "input", + "output_url", + "outputUrl", + "single_file_output", + "singleFileOutput", "file_type", "fileType", - "has_header", - "hasHeader", - "schema", - "table_partition_cols", - "tablePartitionCols", - "if_not_exists", - "ifNotExists", - "delimiter", - "definition", - "file_compression_type", - "fileCompressionType", - "order_exprs", - "orderExprs", - "unbounded", - "options", + "sql_options", + "sqlOptions", + "writer_options", + "writerOptions", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Name, - Location, + Input, + OutputUrl, + SingleFileOutput, FileType, - HasHeader, - Schema, - TablePartitionCols, - IfNotExists, - Delimiter, - Definition, - FileCompressionType, - OrderExprs, - Unbounded, - Options, + SqlOptions, + WriterOptions, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -3833,19 +3802,12 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { E: serde::de::Error, { match value { - "name" => Ok(GeneratedField::Name), - "location" => Ok(GeneratedField::Location), + "input" => Ok(GeneratedField::Input), + "outputUrl" | "output_url" => Ok(GeneratedField::OutputUrl), + "singleFileOutput" | "single_file_output" => Ok(GeneratedField::SingleFileOutput), "fileType" | "file_type" => Ok(GeneratedField::FileType), - "hasHeader" | "has_header" => Ok(GeneratedField::HasHeader), - "schema" => Ok(GeneratedField::Schema), - "tablePartitionCols" | "table_partition_cols" => Ok(GeneratedField::TablePartitionCols), - "ifNotExists" | "if_not_exists" => Ok(GeneratedField::IfNotExists), - "delimiter" => Ok(GeneratedField::Delimiter), - "definition" => Ok(GeneratedField::Definition), - "fileCompressionType" | "file_compression_type" => Ok(GeneratedField::FileCompressionType), - "orderExprs" | "order_exprs" => Ok(GeneratedField::OrderExprs), - "unbounded" => Ok(GeneratedField::Unbounded), - "options" => Ok(GeneratedField::Options), + "sqlOptions" | "sql_options" => Ok(GeneratedField::SqlOptions), + "writerOptions" | "writer_options" => Ok(GeneratedField::WriterOptions), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -3855,42 +3817,40 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = CreateExternalTableNode; + type Value = CopyToNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.CreateExternalTableNode") + formatter.write_str("struct datafusion.CopyToNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut name__ = None; - let mut location__ = None; + let mut input__ = None; + let mut output_url__ = None; + let mut single_file_output__ = None; let mut file_type__ = None; - let mut has_header__ = None; - let mut schema__ = None; - let mut table_partition_cols__ = None; - let mut if_not_exists__ = None; - let mut delimiter__ = None; - let mut definition__ = None; - let mut file_compression_type__ = None; - let mut order_exprs__ = None; - let mut unbounded__ = None; - let mut options__ = None; + let mut copy_options__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Name => { - if name__.is_some() { - return Err(serde::de::Error::duplicate_field("name")); + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); } - name__ = map_.next_value()?; + input__ = map_.next_value()?; } - GeneratedField::Location => { - if location__.is_some() { - return Err(serde::de::Error::duplicate_field("location")); + GeneratedField::OutputUrl => { + if output_url__.is_some() { + return Err(serde::de::Error::duplicate_field("outputUrl")); } - location__ = Some(map_.next_value()?); + output_url__ = Some(map_.next_value()?); + } + GeneratedField::SingleFileOutput => { + if single_file_output__.is_some() { + return Err(serde::de::Error::duplicate_field("singleFileOutput")); + } + single_file_output__ = Some(map_.next_value()?); } GeneratedField::FileType => { if file_type__.is_some() { @@ -3898,91 +3858,35 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { } file_type__ = Some(map_.next_value()?); } - GeneratedField::HasHeader => { - if has_header__.is_some() { - return Err(serde::de::Error::duplicate_field("hasHeader")); - } - has_header__ = Some(map_.next_value()?); - } - GeneratedField::Schema => { - if schema__.is_some() { - return Err(serde::de::Error::duplicate_field("schema")); - } - schema__ = map_.next_value()?; - } - GeneratedField::TablePartitionCols => { - if table_partition_cols__.is_some() { - return Err(serde::de::Error::duplicate_field("tablePartitionCols")); - } - table_partition_cols__ = Some(map_.next_value()?); - } - GeneratedField::IfNotExists => { - if if_not_exists__.is_some() { - return Err(serde::de::Error::duplicate_field("ifNotExists")); - } - if_not_exists__ = Some(map_.next_value()?); - } - GeneratedField::Delimiter => { - if delimiter__.is_some() { - return Err(serde::de::Error::duplicate_field("delimiter")); - } - delimiter__ = Some(map_.next_value()?); - } - GeneratedField::Definition => { - if definition__.is_some() { - return Err(serde::de::Error::duplicate_field("definition")); - } - definition__ = Some(map_.next_value()?); - } - GeneratedField::FileCompressionType => { - if file_compression_type__.is_some() { - return Err(serde::de::Error::duplicate_field("fileCompressionType")); - } - file_compression_type__ = Some(map_.next_value()?); - } - GeneratedField::OrderExprs => { - if order_exprs__.is_some() { - return Err(serde::de::Error::duplicate_field("orderExprs")); - } - order_exprs__ = Some(map_.next_value()?); - } - GeneratedField::Unbounded => { - if unbounded__.is_some() { - return Err(serde::de::Error::duplicate_field("unbounded")); + GeneratedField::SqlOptions => { + if copy_options__.is_some() { + return Err(serde::de::Error::duplicate_field("sqlOptions")); } - unbounded__ = Some(map_.next_value()?); + copy_options__ = map_.next_value::<::std::option::Option<_>>()?.map(copy_to_node::CopyOptions::SqlOptions) +; } - GeneratedField::Options => { - if options__.is_some() { - return Err(serde::de::Error::duplicate_field("options")); + GeneratedField::WriterOptions => { + if copy_options__.is_some() { + return Err(serde::de::Error::duplicate_field("writerOptions")); } - options__ = Some( - map_.next_value::>()? - ); + copy_options__ = map_.next_value::<::std::option::Option<_>>()?.map(copy_to_node::CopyOptions::WriterOptions) +; } } } - Ok(CreateExternalTableNode { - name: name__, - location: location__.unwrap_or_default(), + Ok(CopyToNode { + input: input__, + output_url: output_url__.unwrap_or_default(), + single_file_output: single_file_output__.unwrap_or_default(), file_type: file_type__.unwrap_or_default(), - has_header: has_header__.unwrap_or_default(), - schema: schema__, - table_partition_cols: table_partition_cols__.unwrap_or_default(), - if_not_exists: if_not_exists__.unwrap_or_default(), - delimiter: delimiter__.unwrap_or_default(), - definition: definition__.unwrap_or_default(), - file_compression_type: file_compression_type__.unwrap_or_default(), - order_exprs: order_exprs__.unwrap_or_default(), - unbounded: unbounded__.unwrap_or_default(), - options: options__.unwrap_or_default(), + copy_options: copy_options__, }) } } - deserializer.deserialize_struct("datafusion.CreateExternalTableNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.CopyToNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for CreateViewNode { +impl serde::Serialize for CreateCatalogNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -3990,54 +3894,47 @@ impl serde::Serialize for CreateViewNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.name.is_some() { - len += 1; - } - if self.input.is_some() { + if !self.catalog_name.is_empty() { len += 1; } - if self.or_replace { + if self.if_not_exists { len += 1; } - if !self.definition.is_empty() { + if self.schema.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.CreateViewNode", len)?; - if let Some(v) = self.name.as_ref() { - struct_ser.serialize_field("name", v)?; - } - if let Some(v) = self.input.as_ref() { - struct_ser.serialize_field("input", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.CreateCatalogNode", len)?; + if !self.catalog_name.is_empty() { + struct_ser.serialize_field("catalogName", &self.catalog_name)?; } - if self.or_replace { - struct_ser.serialize_field("orReplace", &self.or_replace)?; + if self.if_not_exists { + struct_ser.serialize_field("ifNotExists", &self.if_not_exists)?; } - if !self.definition.is_empty() { - struct_ser.serialize_field("definition", &self.definition)?; + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for CreateViewNode { +impl<'de> serde::Deserialize<'de> for CreateCatalogNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "name", - "input", - "or_replace", - "orReplace", - "definition", + "catalog_name", + "catalogName", + "if_not_exists", + "ifNotExists", + "schema", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Name, - Input, - OrReplace, - Definition, + CatalogName, + IfNotExists, + Schema, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -4059,10 +3956,9 @@ impl<'de> serde::Deserialize<'de> for CreateViewNode { E: serde::de::Error, { match value { - "name" => Ok(GeneratedField::Name), - "input" => Ok(GeneratedField::Input), - "orReplace" | "or_replace" => Ok(GeneratedField::OrReplace), - "definition" => Ok(GeneratedField::Definition), + "catalogName" | "catalog_name" => Ok(GeneratedField::CatalogName), + "ifNotExists" | "if_not_exists" => Ok(GeneratedField::IfNotExists), + "schema" => Ok(GeneratedField::Schema), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -4072,60 +3968,52 @@ impl<'de> serde::Deserialize<'de> for CreateViewNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = CreateViewNode; + type Value = CreateCatalogNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.CreateViewNode") + formatter.write_str("struct datafusion.CreateCatalogNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut name__ = None; - let mut input__ = None; - let mut or_replace__ = None; - let mut definition__ = None; + let mut catalog_name__ = None; + let mut if_not_exists__ = None; + let mut schema__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Name => { - if name__.is_some() { - return Err(serde::de::Error::duplicate_field("name")); + GeneratedField::CatalogName => { + if catalog_name__.is_some() { + return Err(serde::de::Error::duplicate_field("catalogName")); } - name__ = map_.next_value()?; + catalog_name__ = Some(map_.next_value()?); } - GeneratedField::Input => { - if input__.is_some() { - return Err(serde::de::Error::duplicate_field("input")); + GeneratedField::IfNotExists => { + if if_not_exists__.is_some() { + return Err(serde::de::Error::duplicate_field("ifNotExists")); } - input__ = map_.next_value()?; + if_not_exists__ = Some(map_.next_value()?); } - GeneratedField::OrReplace => { - if or_replace__.is_some() { - return Err(serde::de::Error::duplicate_field("orReplace")); - } - or_replace__ = Some(map_.next_value()?); - } - GeneratedField::Definition => { - if definition__.is_some() { - return Err(serde::de::Error::duplicate_field("definition")); + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); } - definition__ = Some(map_.next_value()?); + schema__ = map_.next_value()?; } } } - Ok(CreateViewNode { - name: name__, - input: input__, - or_replace: or_replace__.unwrap_or_default(), - definition: definition__.unwrap_or_default(), + Ok(CreateCatalogNode { + catalog_name: catalog_name__.unwrap_or_default(), + if_not_exists: if_not_exists__.unwrap_or_default(), + schema: schema__, }) } } - deserializer.deserialize_struct("datafusion.CreateViewNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.CreateCatalogNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for CrossJoinExecNode { +impl serde::Serialize for CreateCatalogSchemaNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -4133,37 +4021,47 @@ impl serde::Serialize for CrossJoinExecNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.left.is_some() { + if !self.schema_name.is_empty() { len += 1; } - if self.right.is_some() { + if self.if_not_exists { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.CrossJoinExecNode", len)?; - if let Some(v) = self.left.as_ref() { - struct_ser.serialize_field("left", v)?; + if self.schema.is_some() { + len += 1; } - if let Some(v) = self.right.as_ref() { - struct_ser.serialize_field("right", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.CreateCatalogSchemaNode", len)?; + if !self.schema_name.is_empty() { + struct_ser.serialize_field("schemaName", &self.schema_name)?; + } + if self.if_not_exists { + struct_ser.serialize_field("ifNotExists", &self.if_not_exists)?; + } + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for CrossJoinExecNode { +impl<'de> serde::Deserialize<'de> for CreateCatalogSchemaNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "left", - "right", + "schema_name", + "schemaName", + "if_not_exists", + "ifNotExists", + "schema", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Left, - Right, + SchemaName, + IfNotExists, + Schema, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -4185,8 +4083,9 @@ impl<'de> serde::Deserialize<'de> for CrossJoinExecNode { E: serde::de::Error, { match value { - "left" => Ok(GeneratedField::Left), - "right" => Ok(GeneratedField::Right), + "schemaName" | "schema_name" => Ok(GeneratedField::SchemaName), + "ifNotExists" | "if_not_exists" => Ok(GeneratedField::IfNotExists), + "schema" => Ok(GeneratedField::Schema), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -4196,44 +4095,52 @@ impl<'de> serde::Deserialize<'de> for CrossJoinExecNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = CrossJoinExecNode; + type Value = CreateCatalogSchemaNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.CrossJoinExecNode") + formatter.write_str("struct datafusion.CreateCatalogSchemaNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut left__ = None; - let mut right__ = None; + let mut schema_name__ = None; + let mut if_not_exists__ = None; + let mut schema__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Left => { - if left__.is_some() { - return Err(serde::de::Error::duplicate_field("left")); + GeneratedField::SchemaName => { + if schema_name__.is_some() { + return Err(serde::de::Error::duplicate_field("schemaName")); } - left__ = map_.next_value()?; + schema_name__ = Some(map_.next_value()?); } - GeneratedField::Right => { - if right__.is_some() { - return Err(serde::de::Error::duplicate_field("right")); + GeneratedField::IfNotExists => { + if if_not_exists__.is_some() { + return Err(serde::de::Error::duplicate_field("ifNotExists")); } - right__ = map_.next_value()?; + if_not_exists__ = Some(map_.next_value()?); + } + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); + } + schema__ = map_.next_value()?; } } } - Ok(CrossJoinExecNode { - left: left__, - right: right__, + Ok(CreateCatalogSchemaNode { + schema_name: schema_name__.unwrap_or_default(), + if_not_exists: if_not_exists__.unwrap_or_default(), + schema: schema__, }) } } - deserializer.deserialize_struct("datafusion.CrossJoinExecNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.CreateCatalogSchemaNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for CrossJoinNode { +impl serde::Serialize for CreateExternalTableNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -4241,37 +4148,148 @@ impl serde::Serialize for CrossJoinNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.left.is_some() { + if self.name.is_some() { len += 1; } - if self.right.is_some() { + if !self.location.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.CrossJoinNode", len)?; - if let Some(v) = self.left.as_ref() { - struct_ser.serialize_field("left", v)?; + if !self.file_type.is_empty() { + len += 1; } - if let Some(v) = self.right.as_ref() { - struct_ser.serialize_field("right", v)?; + if self.has_header { + len += 1; + } + if self.schema.is_some() { + len += 1; + } + if !self.table_partition_cols.is_empty() { + len += 1; + } + if self.if_not_exists { + len += 1; + } + if !self.delimiter.is_empty() { + len += 1; + } + if !self.definition.is_empty() { + len += 1; + } + if !self.file_compression_type.is_empty() { + len += 1; + } + if !self.order_exprs.is_empty() { + len += 1; + } + if self.unbounded { + len += 1; + } + if !self.options.is_empty() { + len += 1; + } + if self.constraints.is_some() { + len += 1; + } + if !self.column_defaults.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.CreateExternalTableNode", len)?; + if let Some(v) = self.name.as_ref() { + struct_ser.serialize_field("name", v)?; + } + if !self.location.is_empty() { + struct_ser.serialize_field("location", &self.location)?; + } + if !self.file_type.is_empty() { + struct_ser.serialize_field("fileType", &self.file_type)?; + } + if self.has_header { + struct_ser.serialize_field("hasHeader", &self.has_header)?; + } + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; + } + if !self.table_partition_cols.is_empty() { + struct_ser.serialize_field("tablePartitionCols", &self.table_partition_cols)?; + } + if self.if_not_exists { + struct_ser.serialize_field("ifNotExists", &self.if_not_exists)?; + } + if !self.delimiter.is_empty() { + struct_ser.serialize_field("delimiter", &self.delimiter)?; + } + if !self.definition.is_empty() { + struct_ser.serialize_field("definition", &self.definition)?; + } + if !self.file_compression_type.is_empty() { + struct_ser.serialize_field("fileCompressionType", &self.file_compression_type)?; + } + if !self.order_exprs.is_empty() { + struct_ser.serialize_field("orderExprs", &self.order_exprs)?; + } + if self.unbounded { + struct_ser.serialize_field("unbounded", &self.unbounded)?; + } + if !self.options.is_empty() { + struct_ser.serialize_field("options", &self.options)?; + } + if let Some(v) = self.constraints.as_ref() { + struct_ser.serialize_field("constraints", v)?; + } + if !self.column_defaults.is_empty() { + struct_ser.serialize_field("columnDefaults", &self.column_defaults)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for CrossJoinNode { +impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "left", - "right", + "name", + "location", + "file_type", + "fileType", + "has_header", + "hasHeader", + "schema", + "table_partition_cols", + "tablePartitionCols", + "if_not_exists", + "ifNotExists", + "delimiter", + "definition", + "file_compression_type", + "fileCompressionType", + "order_exprs", + "orderExprs", + "unbounded", + "options", + "constraints", + "column_defaults", + "columnDefaults", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Left, - Right, + Name, + Location, + FileType, + HasHeader, + Schema, + TablePartitionCols, + IfNotExists, + Delimiter, + Definition, + FileCompressionType, + OrderExprs, + Unbounded, + Options, + Constraints, + ColumnDefaults, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -4293,8 +4311,21 @@ impl<'de> serde::Deserialize<'de> for CrossJoinNode { E: serde::de::Error, { match value { - "left" => Ok(GeneratedField::Left), - "right" => Ok(GeneratedField::Right), + "name" => Ok(GeneratedField::Name), + "location" => Ok(GeneratedField::Location), + "fileType" | "file_type" => Ok(GeneratedField::FileType), + "hasHeader" | "has_header" => Ok(GeneratedField::HasHeader), + "schema" => Ok(GeneratedField::Schema), + "tablePartitionCols" | "table_partition_cols" => Ok(GeneratedField::TablePartitionCols), + "ifNotExists" | "if_not_exists" => Ok(GeneratedField::IfNotExists), + "delimiter" => Ok(GeneratedField::Delimiter), + "definition" => Ok(GeneratedField::Definition), + "fileCompressionType" | "file_compression_type" => Ok(GeneratedField::FileCompressionType), + "orderExprs" | "order_exprs" => Ok(GeneratedField::OrderExprs), + "unbounded" => Ok(GeneratedField::Unbounded), + "options" => Ok(GeneratedField::Options), + "constraints" => Ok(GeneratedField::Constraints), + "columnDefaults" | "column_defaults" => Ok(GeneratedField::ColumnDefaults), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -4304,44 +4335,152 @@ impl<'de> serde::Deserialize<'de> for CrossJoinNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = CrossJoinNode; + type Value = CreateExternalTableNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.CrossJoinNode") + formatter.write_str("struct datafusion.CreateExternalTableNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut left__ = None; - let mut right__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Left => { - if left__.is_some() { - return Err(serde::de::Error::duplicate_field("left")); + let mut name__ = None; + let mut location__ = None; + let mut file_type__ = None; + let mut has_header__ = None; + let mut schema__ = None; + let mut table_partition_cols__ = None; + let mut if_not_exists__ = None; + let mut delimiter__ = None; + let mut definition__ = None; + let mut file_compression_type__ = None; + let mut order_exprs__ = None; + let mut unbounded__ = None; + let mut options__ = None; + let mut constraints__ = None; + let mut column_defaults__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); } - left__ = map_.next_value()?; + name__ = map_.next_value()?; } - GeneratedField::Right => { - if right__.is_some() { - return Err(serde::de::Error::duplicate_field("right")); + GeneratedField::Location => { + if location__.is_some() { + return Err(serde::de::Error::duplicate_field("location")); } - right__ = map_.next_value()?; + location__ = Some(map_.next_value()?); + } + GeneratedField::FileType => { + if file_type__.is_some() { + return Err(serde::de::Error::duplicate_field("fileType")); + } + file_type__ = Some(map_.next_value()?); + } + GeneratedField::HasHeader => { + if has_header__.is_some() { + return Err(serde::de::Error::duplicate_field("hasHeader")); + } + has_header__ = Some(map_.next_value()?); + } + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); + } + schema__ = map_.next_value()?; + } + GeneratedField::TablePartitionCols => { + if table_partition_cols__.is_some() { + return Err(serde::de::Error::duplicate_field("tablePartitionCols")); + } + table_partition_cols__ = Some(map_.next_value()?); + } + GeneratedField::IfNotExists => { + if if_not_exists__.is_some() { + return Err(serde::de::Error::duplicate_field("ifNotExists")); + } + if_not_exists__ = Some(map_.next_value()?); + } + GeneratedField::Delimiter => { + if delimiter__.is_some() { + return Err(serde::de::Error::duplicate_field("delimiter")); + } + delimiter__ = Some(map_.next_value()?); + } + GeneratedField::Definition => { + if definition__.is_some() { + return Err(serde::de::Error::duplicate_field("definition")); + } + definition__ = Some(map_.next_value()?); + } + GeneratedField::FileCompressionType => { + if file_compression_type__.is_some() { + return Err(serde::de::Error::duplicate_field("fileCompressionType")); + } + file_compression_type__ = Some(map_.next_value()?); + } + GeneratedField::OrderExprs => { + if order_exprs__.is_some() { + return Err(serde::de::Error::duplicate_field("orderExprs")); + } + order_exprs__ = Some(map_.next_value()?); + } + GeneratedField::Unbounded => { + if unbounded__.is_some() { + return Err(serde::de::Error::duplicate_field("unbounded")); + } + unbounded__ = Some(map_.next_value()?); + } + GeneratedField::Options => { + if options__.is_some() { + return Err(serde::de::Error::duplicate_field("options")); + } + options__ = Some( + map_.next_value::>()? + ); + } + GeneratedField::Constraints => { + if constraints__.is_some() { + return Err(serde::de::Error::duplicate_field("constraints")); + } + constraints__ = map_.next_value()?; + } + GeneratedField::ColumnDefaults => { + if column_defaults__.is_some() { + return Err(serde::de::Error::duplicate_field("columnDefaults")); + } + column_defaults__ = Some( + map_.next_value::>()? + ); } } } - Ok(CrossJoinNode { - left: left__, - right: right__, + Ok(CreateExternalTableNode { + name: name__, + location: location__.unwrap_or_default(), + file_type: file_type__.unwrap_or_default(), + has_header: has_header__.unwrap_or_default(), + schema: schema__, + table_partition_cols: table_partition_cols__.unwrap_or_default(), + if_not_exists: if_not_exists__.unwrap_or_default(), + delimiter: delimiter__.unwrap_or_default(), + definition: definition__.unwrap_or_default(), + file_compression_type: file_compression_type__.unwrap_or_default(), + order_exprs: order_exprs__.unwrap_or_default(), + unbounded: unbounded__.unwrap_or_default(), + options: options__.unwrap_or_default(), + constraints: constraints__, + column_defaults: column_defaults__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.CrossJoinNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.CreateExternalTableNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for CsvFormat { +impl serde::Serialize for CreateViewNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -4349,58 +4488,54 @@ impl serde::Serialize for CsvFormat { { use serde::ser::SerializeStruct; let mut len = 0; - if self.has_header { + if self.name.is_some() { len += 1; } - if !self.delimiter.is_empty() { + if self.input.is_some() { len += 1; } - if !self.quote.is_empty() { + if self.or_replace { len += 1; } - if self.optional_escape.is_some() { + if !self.definition.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.CsvFormat", len)?; - if self.has_header { - struct_ser.serialize_field("hasHeader", &self.has_header)?; + let mut struct_ser = serializer.serialize_struct("datafusion.CreateViewNode", len)?; + if let Some(v) = self.name.as_ref() { + struct_ser.serialize_field("name", v)?; } - if !self.delimiter.is_empty() { - struct_ser.serialize_field("delimiter", &self.delimiter)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; } - if !self.quote.is_empty() { - struct_ser.serialize_field("quote", &self.quote)?; + if self.or_replace { + struct_ser.serialize_field("orReplace", &self.or_replace)?; } - if let Some(v) = self.optional_escape.as_ref() { - match v { - csv_format::OptionalEscape::Escape(v) => { - struct_ser.serialize_field("escape", v)?; - } - } + if !self.definition.is_empty() { + struct_ser.serialize_field("definition", &self.definition)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for CsvFormat { +impl<'de> serde::Deserialize<'de> for CreateViewNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "has_header", - "hasHeader", - "delimiter", - "quote", - "escape", + "name", + "input", + "or_replace", + "orReplace", + "definition", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - HasHeader, - Delimiter, - Quote, - Escape, + Name, + Input, + OrReplace, + Definition, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -4422,10 +4557,10 @@ impl<'de> serde::Deserialize<'de> for CsvFormat { E: serde::de::Error, { match value { - "hasHeader" | "has_header" => Ok(GeneratedField::HasHeader), - "delimiter" => Ok(GeneratedField::Delimiter), - "quote" => Ok(GeneratedField::Quote), - "escape" => Ok(GeneratedField::Escape), + "name" => Ok(GeneratedField::Name), + "input" => Ok(GeneratedField::Input), + "orReplace" | "or_replace" => Ok(GeneratedField::OrReplace), + "definition" => Ok(GeneratedField::Definition), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -4435,60 +4570,60 @@ impl<'de> serde::Deserialize<'de> for CsvFormat { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = CsvFormat; + type Value = CreateViewNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.CsvFormat") + formatter.write_str("struct datafusion.CreateViewNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut has_header__ = None; - let mut delimiter__ = None; - let mut quote__ = None; - let mut optional_escape__ = None; + let mut name__ = None; + let mut input__ = None; + let mut or_replace__ = None; + let mut definition__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::HasHeader => { - if has_header__.is_some() { - return Err(serde::de::Error::duplicate_field("hasHeader")); + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); } - has_header__ = Some(map_.next_value()?); + name__ = map_.next_value()?; } - GeneratedField::Delimiter => { - if delimiter__.is_some() { - return Err(serde::de::Error::duplicate_field("delimiter")); + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); } - delimiter__ = Some(map_.next_value()?); + input__ = map_.next_value()?; } - GeneratedField::Quote => { - if quote__.is_some() { - return Err(serde::de::Error::duplicate_field("quote")); + GeneratedField::OrReplace => { + if or_replace__.is_some() { + return Err(serde::de::Error::duplicate_field("orReplace")); } - quote__ = Some(map_.next_value()?); + or_replace__ = Some(map_.next_value()?); } - GeneratedField::Escape => { - if optional_escape__.is_some() { - return Err(serde::de::Error::duplicate_field("escape")); + GeneratedField::Definition => { + if definition__.is_some() { + return Err(serde::de::Error::duplicate_field("definition")); } - optional_escape__ = map_.next_value::<::std::option::Option<_>>()?.map(csv_format::OptionalEscape::Escape); + definition__ = Some(map_.next_value()?); } } } - Ok(CsvFormat { - has_header: has_header__.unwrap_or_default(), - delimiter: delimiter__.unwrap_or_default(), - quote: quote__.unwrap_or_default(), - optional_escape: optional_escape__, + Ok(CreateViewNode { + name: name__, + input: input__, + or_replace: or_replace__.unwrap_or_default(), + definition: definition__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.CsvFormat", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.CreateViewNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for CsvScanExecNode { +impl serde::Serialize for CrossJoinExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -4496,67 +4631,37 @@ impl serde::Serialize for CsvScanExecNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.base_conf.is_some() { + if self.left.is_some() { len += 1; } - if self.has_header { + if self.right.is_some() { len += 1; } - if !self.delimiter.is_empty() { - len += 1; + let mut struct_ser = serializer.serialize_struct("datafusion.CrossJoinExecNode", len)?; + if let Some(v) = self.left.as_ref() { + struct_ser.serialize_field("left", v)?; } - if !self.quote.is_empty() { - len += 1; - } - if self.optional_escape.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.CsvScanExecNode", len)?; - if let Some(v) = self.base_conf.as_ref() { - struct_ser.serialize_field("baseConf", v)?; - } - if self.has_header { - struct_ser.serialize_field("hasHeader", &self.has_header)?; - } - if !self.delimiter.is_empty() { - struct_ser.serialize_field("delimiter", &self.delimiter)?; - } - if !self.quote.is_empty() { - struct_ser.serialize_field("quote", &self.quote)?; - } - if let Some(v) = self.optional_escape.as_ref() { - match v { - csv_scan_exec_node::OptionalEscape::Escape(v) => { - struct_ser.serialize_field("escape", v)?; - } - } + if let Some(v) = self.right.as_ref() { + struct_ser.serialize_field("right", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for CsvScanExecNode { +impl<'de> serde::Deserialize<'de> for CrossJoinExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "base_conf", - "baseConf", - "has_header", - "hasHeader", - "delimiter", - "quote", - "escape", + "left", + "right", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - BaseConf, - HasHeader, - Delimiter, - Quote, - Escape, + Left, + Right, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -4578,11 +4683,8 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { E: serde::de::Error, { match value { - "baseConf" | "base_conf" => Ok(GeneratedField::BaseConf), - "hasHeader" | "has_header" => Ok(GeneratedField::HasHeader), - "delimiter" => Ok(GeneratedField::Delimiter), - "quote" => Ok(GeneratedField::Quote), - "escape" => Ok(GeneratedField::Escape), + "left" => Ok(GeneratedField::Left), + "right" => Ok(GeneratedField::Right), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -4592,68 +4694,44 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = CsvScanExecNode; + type Value = CrossJoinExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.CsvScanExecNode") + formatter.write_str("struct datafusion.CrossJoinExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut base_conf__ = None; - let mut has_header__ = None; - let mut delimiter__ = None; - let mut quote__ = None; - let mut optional_escape__ = None; + let mut left__ = None; + let mut right__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::BaseConf => { - if base_conf__.is_some() { - return Err(serde::de::Error::duplicate_field("baseConf")); - } - base_conf__ = map_.next_value()?; - } - GeneratedField::HasHeader => { - if has_header__.is_some() { - return Err(serde::de::Error::duplicate_field("hasHeader")); - } - has_header__ = Some(map_.next_value()?); - } - GeneratedField::Delimiter => { - if delimiter__.is_some() { - return Err(serde::de::Error::duplicate_field("delimiter")); - } - delimiter__ = Some(map_.next_value()?); - } - GeneratedField::Quote => { - if quote__.is_some() { - return Err(serde::de::Error::duplicate_field("quote")); + GeneratedField::Left => { + if left__.is_some() { + return Err(serde::de::Error::duplicate_field("left")); } - quote__ = Some(map_.next_value()?); + left__ = map_.next_value()?; } - GeneratedField::Escape => { - if optional_escape__.is_some() { - return Err(serde::de::Error::duplicate_field("escape")); + GeneratedField::Right => { + if right__.is_some() { + return Err(serde::de::Error::duplicate_field("right")); } - optional_escape__ = map_.next_value::<::std::option::Option<_>>()?.map(csv_scan_exec_node::OptionalEscape::Escape); + right__ = map_.next_value()?; } } } - Ok(CsvScanExecNode { - base_conf: base_conf__, - has_header: has_header__.unwrap_or_default(), - delimiter: delimiter__.unwrap_or_default(), - quote: quote__.unwrap_or_default(), - optional_escape: optional_escape__, + Ok(CrossJoinExecNode { + left: left__, + right: right__, }) } } - deserializer.deserialize_struct("datafusion.CsvScanExecNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.CrossJoinExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for CubeNode { +impl serde::Serialize for CrossJoinNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -4661,29 +4739,37 @@ impl serde::Serialize for CubeNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.expr.is_empty() { + if self.left.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.CubeNode", len)?; - if !self.expr.is_empty() { - struct_ser.serialize_field("expr", &self.expr)?; + if self.right.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.CrossJoinNode", len)?; + if let Some(v) = self.left.as_ref() { + struct_ser.serialize_field("left", v)?; + } + if let Some(v) = self.right.as_ref() { + struct_ser.serialize_field("right", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for CubeNode { +impl<'de> serde::Deserialize<'de> for CrossJoinNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", + "left", + "right", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, + Left, + Right, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -4705,7 +4791,8 @@ impl<'de> serde::Deserialize<'de> for CubeNode { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), + "left" => Ok(GeneratedField::Left), + "right" => Ok(GeneratedField::Right), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -4715,36 +4802,44 @@ impl<'de> serde::Deserialize<'de> for CubeNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = CubeNode; + type Value = CrossJoinNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.CubeNode") + formatter.write_str("struct datafusion.CrossJoinNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; + let mut left__ = None; + let mut right__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::Left => { + if left__.is_some() { + return Err(serde::de::Error::duplicate_field("left")); } - expr__ = Some(map_.next_value()?); + left__ = map_.next_value()?; + } + GeneratedField::Right => { + if right__.is_some() { + return Err(serde::de::Error::duplicate_field("right")); + } + right__ = map_.next_value()?; } } } - Ok(CubeNode { - expr: expr__.unwrap_or_default(), + Ok(CrossJoinNode { + left: left__, + right: right__, }) } } - deserializer.deserialize_struct("datafusion.CubeNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.CrossJoinNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for CustomTableScanNode { +impl serde::Serialize for CsvFormat { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -4752,64 +4847,58 @@ impl serde::Serialize for CustomTableScanNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.table_name.is_some() { - len += 1; - } - if self.projection.is_some() { + if self.has_header { len += 1; } - if self.schema.is_some() { + if !self.delimiter.is_empty() { len += 1; } - if !self.filters.is_empty() { + if !self.quote.is_empty() { len += 1; } - if !self.custom_table_data.is_empty() { + if self.optional_escape.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.CustomTableScanNode", len)?; - if let Some(v) = self.table_name.as_ref() { - struct_ser.serialize_field("tableName", v)?; - } - if let Some(v) = self.projection.as_ref() { - struct_ser.serialize_field("projection", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.CsvFormat", len)?; + if self.has_header { + struct_ser.serialize_field("hasHeader", &self.has_header)?; } - if let Some(v) = self.schema.as_ref() { - struct_ser.serialize_field("schema", v)?; + if !self.delimiter.is_empty() { + struct_ser.serialize_field("delimiter", &self.delimiter)?; } - if !self.filters.is_empty() { - struct_ser.serialize_field("filters", &self.filters)?; + if !self.quote.is_empty() { + struct_ser.serialize_field("quote", &self.quote)?; } - if !self.custom_table_data.is_empty() { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("customTableData", pbjson::private::base64::encode(&self.custom_table_data).as_str())?; + if let Some(v) = self.optional_escape.as_ref() { + match v { + csv_format::OptionalEscape::Escape(v) => { + struct_ser.serialize_field("escape", v)?; + } + } } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for CustomTableScanNode { +impl<'de> serde::Deserialize<'de> for CsvFormat { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "table_name", - "tableName", - "projection", - "schema", - "filters", - "custom_table_data", - "customTableData", + "has_header", + "hasHeader", + "delimiter", + "quote", + "escape", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - TableName, - Projection, - Schema, - Filters, - CustomTableData, + HasHeader, + Delimiter, + Quote, + Escape, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -4831,11 +4920,10 @@ impl<'de> serde::Deserialize<'de> for CustomTableScanNode { E: serde::de::Error, { match value { - "tableName" | "table_name" => Ok(GeneratedField::TableName), - "projection" => Ok(GeneratedField::Projection), - "schema" => Ok(GeneratedField::Schema), - "filters" => Ok(GeneratedField::Filters), - "customTableData" | "custom_table_data" => Ok(GeneratedField::CustomTableData), + "hasHeader" | "has_header" => Ok(GeneratedField::HasHeader), + "delimiter" => Ok(GeneratedField::Delimiter), + "quote" => Ok(GeneratedField::Quote), + "escape" => Ok(GeneratedField::Escape), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -4845,141 +4933,225 @@ impl<'de> serde::Deserialize<'de> for CustomTableScanNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = CustomTableScanNode; + type Value = CsvFormat; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.CustomTableScanNode") + formatter.write_str("struct datafusion.CsvFormat") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut table_name__ = None; - let mut projection__ = None; - let mut schema__ = None; - let mut filters__ = None; - let mut custom_table_data__ = None; + let mut has_header__ = None; + let mut delimiter__ = None; + let mut quote__ = None; + let mut optional_escape__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::TableName => { - if table_name__.is_some() { - return Err(serde::de::Error::duplicate_field("tableName")); + GeneratedField::HasHeader => { + if has_header__.is_some() { + return Err(serde::de::Error::duplicate_field("hasHeader")); } - table_name__ = map_.next_value()?; - } - GeneratedField::Projection => { - if projection__.is_some() { - return Err(serde::de::Error::duplicate_field("projection")); - } - projection__ = map_.next_value()?; + has_header__ = Some(map_.next_value()?); } - GeneratedField::Schema => { - if schema__.is_some() { - return Err(serde::de::Error::duplicate_field("schema")); + GeneratedField::Delimiter => { + if delimiter__.is_some() { + return Err(serde::de::Error::duplicate_field("delimiter")); } - schema__ = map_.next_value()?; + delimiter__ = Some(map_.next_value()?); } - GeneratedField::Filters => { - if filters__.is_some() { - return Err(serde::de::Error::duplicate_field("filters")); + GeneratedField::Quote => { + if quote__.is_some() { + return Err(serde::de::Error::duplicate_field("quote")); } - filters__ = Some(map_.next_value()?); + quote__ = Some(map_.next_value()?); } - GeneratedField::CustomTableData => { - if custom_table_data__.is_some() { - return Err(serde::de::Error::duplicate_field("customTableData")); + GeneratedField::Escape => { + if optional_escape__.is_some() { + return Err(serde::de::Error::duplicate_field("escape")); } - custom_table_data__ = - Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) - ; + optional_escape__ = map_.next_value::<::std::option::Option<_>>()?.map(csv_format::OptionalEscape::Escape); } } } - Ok(CustomTableScanNode { - table_name: table_name__, - projection: projection__, - schema: schema__, - filters: filters__.unwrap_or_default(), - custom_table_data: custom_table_data__.unwrap_or_default(), + Ok(CsvFormat { + has_header: has_header__.unwrap_or_default(), + delimiter: delimiter__.unwrap_or_default(), + quote: quote__.unwrap_or_default(), + optional_escape: optional_escape__, }) } } - deserializer.deserialize_struct("datafusion.CustomTableScanNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.CsvFormat", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for DateUnit { +impl serde::Serialize for CsvScanExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where S: serde::Serializer, { - let variant = match self { - Self::Day => "Day", - Self::DateMillisecond => "DateMillisecond", - }; - serializer.serialize_str(variant) + use serde::ser::SerializeStruct; + let mut len = 0; + if self.base_conf.is_some() { + len += 1; + } + if self.has_header { + len += 1; + } + if !self.delimiter.is_empty() { + len += 1; + } + if !self.quote.is_empty() { + len += 1; + } + if self.optional_escape.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.CsvScanExecNode", len)?; + if let Some(v) = self.base_conf.as_ref() { + struct_ser.serialize_field("baseConf", v)?; + } + if self.has_header { + struct_ser.serialize_field("hasHeader", &self.has_header)?; + } + if !self.delimiter.is_empty() { + struct_ser.serialize_field("delimiter", &self.delimiter)?; + } + if !self.quote.is_empty() { + struct_ser.serialize_field("quote", &self.quote)?; + } + if let Some(v) = self.optional_escape.as_ref() { + match v { + csv_scan_exec_node::OptionalEscape::Escape(v) => { + struct_ser.serialize_field("escape", v)?; + } + } + } + struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for DateUnit { +impl<'de> serde::Deserialize<'de> for CsvScanExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "Day", - "DateMillisecond", + "base_conf", + "baseConf", + "has_header", + "hasHeader", + "delimiter", + "quote", + "escape", ]; - struct GeneratedVisitor; + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + BaseConf, + HasHeader, + Delimiter, + Quote, + Escape, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = DateUnit; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } - fn visit_i64(self, v: i64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) - }) + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "baseConf" | "base_conf" => Ok(GeneratedField::BaseConf), + "hasHeader" | "has_header" => Ok(GeneratedField::HasHeader), + "delimiter" => Ok(GeneratedField::Delimiter), + "quote" => Ok(GeneratedField::Quote), + "escape" => Ok(GeneratedField::Escape), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = CsvScanExecNode; - fn visit_u64(self, v: u64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) - }) + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.CsvScanExecNode") } - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, { - match value { - "Day" => Ok(DateUnit::Day), - "DateMillisecond" => Ok(DateUnit::DateMillisecond), - _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + let mut base_conf__ = None; + let mut has_header__ = None; + let mut delimiter__ = None; + let mut quote__ = None; + let mut optional_escape__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::BaseConf => { + if base_conf__.is_some() { + return Err(serde::de::Error::duplicate_field("baseConf")); + } + base_conf__ = map_.next_value()?; + } + GeneratedField::HasHeader => { + if has_header__.is_some() { + return Err(serde::de::Error::duplicate_field("hasHeader")); + } + has_header__ = Some(map_.next_value()?); + } + GeneratedField::Delimiter => { + if delimiter__.is_some() { + return Err(serde::de::Error::duplicate_field("delimiter")); + } + delimiter__ = Some(map_.next_value()?); + } + GeneratedField::Quote => { + if quote__.is_some() { + return Err(serde::de::Error::duplicate_field("quote")); + } + quote__ = Some(map_.next_value()?); + } + GeneratedField::Escape => { + if optional_escape__.is_some() { + return Err(serde::de::Error::duplicate_field("escape")); + } + optional_escape__ = map_.next_value::<::std::option::Option<_>>()?.map(csv_scan_exec_node::OptionalEscape::Escape); + } + } } + Ok(CsvScanExecNode { + base_conf: base_conf__, + has_header: has_header__.unwrap_or_default(), + delimiter: delimiter__.unwrap_or_default(), + quote: quote__.unwrap_or_default(), + optional_escape: optional_escape__, + }) } } - deserializer.deserialize_any(GeneratedVisitor) + deserializer.deserialize_struct("datafusion.CsvScanExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for Decimal { +impl serde::Serialize for CsvSink { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -4987,37 +5159,29 @@ impl serde::Serialize for Decimal { { use serde::ser::SerializeStruct; let mut len = 0; - if self.precision != 0 { + if self.config.is_some() { len += 1; } - if self.scale != 0 { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.Decimal", len)?; - if self.precision != 0 { - struct_ser.serialize_field("precision", &self.precision)?; - } - if self.scale != 0 { - struct_ser.serialize_field("scale", &self.scale)?; + let mut struct_ser = serializer.serialize_struct("datafusion.CsvSink", len)?; + if let Some(v) = self.config.as_ref() { + struct_ser.serialize_field("config", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for Decimal { +impl<'de> serde::Deserialize<'de> for CsvSink { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "precision", - "scale", + "config", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Precision, - Scale, + Config, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -5039,8 +5203,7 @@ impl<'de> serde::Deserialize<'de> for Decimal { E: serde::de::Error, { match value { - "precision" => Ok(GeneratedField::Precision), - "scale" => Ok(GeneratedField::Scale), + "config" => Ok(GeneratedField::Config), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -5050,48 +5213,36 @@ impl<'de> serde::Deserialize<'de> for Decimal { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = Decimal; + type Value = CsvSink; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.Decimal") + formatter.write_str("struct datafusion.CsvSink") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut precision__ = None; - let mut scale__ = None; + let mut config__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Precision => { - if precision__.is_some() { - return Err(serde::de::Error::duplicate_field("precision")); + GeneratedField::Config => { + if config__.is_some() { + return Err(serde::de::Error::duplicate_field("config")); } - precision__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - GeneratedField::Scale => { - if scale__.is_some() { - return Err(serde::de::Error::duplicate_field("scale")); - } - scale__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + config__ = map_.next_value()?; } } } - Ok(Decimal { - precision: precision__.unwrap_or_default(), - scale: scale__.unwrap_or_default(), + Ok(CsvSink { + config: config__, }) } } - deserializer.deserialize_struct("datafusion.Decimal", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.CsvSink", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for Decimal128 { +impl serde::Serialize for CsvSinkExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -5099,48 +5250,55 @@ impl serde::Serialize for Decimal128 { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.value.is_empty() { + if self.input.is_some() { len += 1; } - if self.p != 0 { + if self.sink.is_some() { len += 1; } - if self.s != 0 { + if self.sink_schema.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.Decimal128", len)?; - if !self.value.is_empty() { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("value", pbjson::private::base64::encode(&self.value).as_str())?; + if self.sort_order.is_some() { + len += 1; } - if self.p != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("p", ToString::to_string(&self.p).as_str())?; + let mut struct_ser = serializer.serialize_struct("datafusion.CsvSinkExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; } - if self.s != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("s", ToString::to_string(&self.s).as_str())?; + if let Some(v) = self.sink.as_ref() { + struct_ser.serialize_field("sink", v)?; + } + if let Some(v) = self.sink_schema.as_ref() { + struct_ser.serialize_field("sinkSchema", v)?; + } + if let Some(v) = self.sort_order.as_ref() { + struct_ser.serialize_field("sortOrder", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for Decimal128 { +impl<'de> serde::Deserialize<'de> for CsvSinkExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "value", - "p", - "s", + "input", + "sink", + "sink_schema", + "sinkSchema", + "sort_order", + "sortOrder", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Value, - P, - S, + Input, + Sink, + SinkSchema, + SortOrder, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -5162,9 +5320,10 @@ impl<'de> serde::Deserialize<'de> for Decimal128 { E: serde::de::Error, { match value { - "value" => Ok(GeneratedField::Value), - "p" => Ok(GeneratedField::P), - "s" => Ok(GeneratedField::S), + "input" => Ok(GeneratedField::Input), + "sink" => Ok(GeneratedField::Sink), + "sinkSchema" | "sink_schema" => Ok(GeneratedField::SinkSchema), + "sortOrder" | "sort_order" => Ok(GeneratedField::SortOrder), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -5174,58 +5333,60 @@ impl<'de> serde::Deserialize<'de> for Decimal128 { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = Decimal128; + type Value = CsvSinkExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.Decimal128") + formatter.write_str("struct datafusion.CsvSinkExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut value__ = None; - let mut p__ = None; - let mut s__ = None; + let mut input__ = None; + let mut sink__ = None; + let mut sink_schema__ = None; + let mut sort_order__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Value => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("value")); + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); } - value__ = - Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) - ; + input__ = map_.next_value()?; } - GeneratedField::P => { - if p__.is_some() { - return Err(serde::de::Error::duplicate_field("p")); + GeneratedField::Sink => { + if sink__.is_some() { + return Err(serde::de::Error::duplicate_field("sink")); } - p__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + sink__ = map_.next_value()?; } - GeneratedField::S => { - if s__.is_some() { - return Err(serde::de::Error::duplicate_field("s")); + GeneratedField::SinkSchema => { + if sink_schema__.is_some() { + return Err(serde::de::Error::duplicate_field("sinkSchema")); } - s__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + sink_schema__ = map_.next_value()?; + } + GeneratedField::SortOrder => { + if sort_order__.is_some() { + return Err(serde::de::Error::duplicate_field("sortOrder")); + } + sort_order__ = map_.next_value()?; } } } - Ok(Decimal128 { - value: value__.unwrap_or_default(), - p: p__.unwrap_or_default(), - s: s__.unwrap_or_default(), + Ok(CsvSinkExecNode { + input: input__, + sink: sink__, + sink_schema: sink_schema__, + sort_order: sort_order__, }) } } - deserializer.deserialize_struct("datafusion.Decimal128", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.CsvSinkExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for Decimal256 { +impl serde::Serialize for CsvWriterOptions { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -5233,48 +5394,93 @@ impl serde::Serialize for Decimal256 { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.value.is_empty() { + if self.compression != 0 { len += 1; } - if self.p != 0 { + if !self.delimiter.is_empty() { len += 1; } - if self.s != 0 { + if self.has_header { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.Decimal256", len)?; - if !self.value.is_empty() { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("value", pbjson::private::base64::encode(&self.value).as_str())?; + if !self.date_format.is_empty() { + len += 1; } - if self.p != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("p", ToString::to_string(&self.p).as_str())?; + if !self.datetime_format.is_empty() { + len += 1; } - if self.s != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("s", ToString::to_string(&self.s).as_str())?; + if !self.timestamp_format.is_empty() { + len += 1; + } + if !self.time_format.is_empty() { + len += 1; + } + if !self.null_value.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.CsvWriterOptions", len)?; + if self.compression != 0 { + let v = CompressionTypeVariant::try_from(self.compression) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.compression)))?; + struct_ser.serialize_field("compression", &v)?; + } + if !self.delimiter.is_empty() { + struct_ser.serialize_field("delimiter", &self.delimiter)?; + } + if self.has_header { + struct_ser.serialize_field("hasHeader", &self.has_header)?; + } + if !self.date_format.is_empty() { + struct_ser.serialize_field("dateFormat", &self.date_format)?; + } + if !self.datetime_format.is_empty() { + struct_ser.serialize_field("datetimeFormat", &self.datetime_format)?; + } + if !self.timestamp_format.is_empty() { + struct_ser.serialize_field("timestampFormat", &self.timestamp_format)?; + } + if !self.time_format.is_empty() { + struct_ser.serialize_field("timeFormat", &self.time_format)?; + } + if !self.null_value.is_empty() { + struct_ser.serialize_field("nullValue", &self.null_value)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for Decimal256 { +impl<'de> serde::Deserialize<'de> for CsvWriterOptions { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "value", - "p", - "s", + "compression", + "delimiter", + "has_header", + "hasHeader", + "date_format", + "dateFormat", + "datetime_format", + "datetimeFormat", + "timestamp_format", + "timestampFormat", + "time_format", + "timeFormat", + "null_value", + "nullValue", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Value, - P, - S, + Compression, + Delimiter, + HasHeader, + DateFormat, + DatetimeFormat, + TimestampFormat, + TimeFormat, + NullValue, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -5296,9 +5502,14 @@ impl<'de> serde::Deserialize<'de> for Decimal256 { E: serde::de::Error, { match value { - "value" => Ok(GeneratedField::Value), - "p" => Ok(GeneratedField::P), - "s" => Ok(GeneratedField::S), + "compression" => Ok(GeneratedField::Compression), + "delimiter" => Ok(GeneratedField::Delimiter), + "hasHeader" | "has_header" => Ok(GeneratedField::HasHeader), + "dateFormat" | "date_format" => Ok(GeneratedField::DateFormat), + "datetimeFormat" | "datetime_format" => Ok(GeneratedField::DatetimeFormat), + "timestampFormat" | "timestamp_format" => Ok(GeneratedField::TimestampFormat), + "timeFormat" | "time_format" => Ok(GeneratedField::TimeFormat), + "nullValue" | "null_value" => Ok(GeneratedField::NullValue), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -5308,58 +5519,92 @@ impl<'de> serde::Deserialize<'de> for Decimal256 { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = Decimal256; + type Value = CsvWriterOptions; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.Decimal256") + formatter.write_str("struct datafusion.CsvWriterOptions") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut value__ = None; - let mut p__ = None; - let mut s__ = None; + let mut compression__ = None; + let mut delimiter__ = None; + let mut has_header__ = None; + let mut date_format__ = None; + let mut datetime_format__ = None; + let mut timestamp_format__ = None; + let mut time_format__ = None; + let mut null_value__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Value => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("value")); + GeneratedField::Compression => { + if compression__.is_some() { + return Err(serde::de::Error::duplicate_field("compression")); } - value__ = - Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) - ; + compression__ = Some(map_.next_value::()? as i32); } - GeneratedField::P => { - if p__.is_some() { - return Err(serde::de::Error::duplicate_field("p")); + GeneratedField::Delimiter => { + if delimiter__.is_some() { + return Err(serde::de::Error::duplicate_field("delimiter")); } - p__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + delimiter__ = Some(map_.next_value()?); } - GeneratedField::S => { - if s__.is_some() { - return Err(serde::de::Error::duplicate_field("s")); + GeneratedField::HasHeader => { + if has_header__.is_some() { + return Err(serde::de::Error::duplicate_field("hasHeader")); } - s__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + has_header__ = Some(map_.next_value()?); + } + GeneratedField::DateFormat => { + if date_format__.is_some() { + return Err(serde::de::Error::duplicate_field("dateFormat")); + } + date_format__ = Some(map_.next_value()?); + } + GeneratedField::DatetimeFormat => { + if datetime_format__.is_some() { + return Err(serde::de::Error::duplicate_field("datetimeFormat")); + } + datetime_format__ = Some(map_.next_value()?); + } + GeneratedField::TimestampFormat => { + if timestamp_format__.is_some() { + return Err(serde::de::Error::duplicate_field("timestampFormat")); + } + timestamp_format__ = Some(map_.next_value()?); + } + GeneratedField::TimeFormat => { + if time_format__.is_some() { + return Err(serde::de::Error::duplicate_field("timeFormat")); + } + time_format__ = Some(map_.next_value()?); + } + GeneratedField::NullValue => { + if null_value__.is_some() { + return Err(serde::de::Error::duplicate_field("nullValue")); + } + null_value__ = Some(map_.next_value()?); } } } - Ok(Decimal256 { - value: value__.unwrap_or_default(), - p: p__.unwrap_or_default(), - s: s__.unwrap_or_default(), + Ok(CsvWriterOptions { + compression: compression__.unwrap_or_default(), + delimiter: delimiter__.unwrap_or_default(), + has_header: has_header__.unwrap_or_default(), + date_format: date_format__.unwrap_or_default(), + datetime_format: datetime_format__.unwrap_or_default(), + timestamp_format: timestamp_format__.unwrap_or_default(), + time_format: time_format__.unwrap_or_default(), + null_value: null_value__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.Decimal256", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.CsvWriterOptions", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for DfField { +impl serde::Serialize for CubeNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -5367,37 +5612,29 @@ impl serde::Serialize for DfField { { use serde::ser::SerializeStruct; let mut len = 0; - if self.field.is_some() { - len += 1; - } - if self.qualifier.is_some() { + if !self.expr.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.DfField", len)?; - if let Some(v) = self.field.as_ref() { - struct_ser.serialize_field("field", v)?; - } - if let Some(v) = self.qualifier.as_ref() { - struct_ser.serialize_field("qualifier", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.CubeNode", len)?; + if !self.expr.is_empty() { + struct_ser.serialize_field("expr", &self.expr)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for DfField { +impl<'de> serde::Deserialize<'de> for CubeNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "field", - "qualifier", + "expr", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Field, - Qualifier, + Expr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -5419,8 +5656,7 @@ impl<'de> serde::Deserialize<'de> for DfField { E: serde::de::Error, { match value { - "field" => Ok(GeneratedField::Field), - "qualifier" => Ok(GeneratedField::Qualifier), + "expr" => Ok(GeneratedField::Expr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -5430,44 +5666,36 @@ impl<'de> serde::Deserialize<'de> for DfField { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = DfField; + type Value = CubeNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.DfField") + formatter.write_str("struct datafusion.CubeNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut field__ = None; - let mut qualifier__ = None; + let mut expr__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Field => { - if field__.is_some() { - return Err(serde::de::Error::duplicate_field("field")); - } - field__ = map_.next_value()?; - } - GeneratedField::Qualifier => { - if qualifier__.is_some() { - return Err(serde::de::Error::duplicate_field("qualifier")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - qualifier__ = map_.next_value()?; + expr__ = Some(map_.next_value()?); } } } - Ok(DfField { - field: field__, - qualifier: qualifier__, + Ok(CubeNode { + expr: expr__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.DfField", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.CubeNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for DfSchema { +impl serde::Serialize for CustomTableScanNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -5475,37 +5703,64 @@ impl serde::Serialize for DfSchema { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.columns.is_empty() { + if self.table_name.is_some() { len += 1; } - if !self.metadata.is_empty() { + if self.projection.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.DfSchema", len)?; - if !self.columns.is_empty() { - struct_ser.serialize_field("columns", &self.columns)?; + if self.schema.is_some() { + len += 1; } - if !self.metadata.is_empty() { - struct_ser.serialize_field("metadata", &self.metadata)?; + if !self.filters.is_empty() { + len += 1; + } + if !self.custom_table_data.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.CustomTableScanNode", len)?; + if let Some(v) = self.table_name.as_ref() { + struct_ser.serialize_field("tableName", v)?; + } + if let Some(v) = self.projection.as_ref() { + struct_ser.serialize_field("projection", v)?; + } + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; + } + if !self.filters.is_empty() { + struct_ser.serialize_field("filters", &self.filters)?; + } + if !self.custom_table_data.is_empty() { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("customTableData", pbjson::private::base64::encode(&self.custom_table_data).as_str())?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for DfSchema { +impl<'de> serde::Deserialize<'de> for CustomTableScanNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "columns", - "metadata", + "table_name", + "tableName", + "projection", + "schema", + "filters", + "custom_table_data", + "customTableData", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Columns, - Metadata, + TableName, + Projection, + Schema, + Filters, + CustomTableData, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -5527,8 +5782,11 @@ impl<'de> serde::Deserialize<'de> for DfSchema { E: serde::de::Error, { match value { - "columns" => Ok(GeneratedField::Columns), - "metadata" => Ok(GeneratedField::Metadata), + "tableName" | "table_name" => Ok(GeneratedField::TableName), + "projection" => Ok(GeneratedField::Projection), + "schema" => Ok(GeneratedField::Schema), + "filters" => Ok(GeneratedField::Filters), + "customTableData" | "custom_table_data" => Ok(GeneratedField::CustomTableData), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -5538,46 +5796,141 @@ impl<'de> serde::Deserialize<'de> for DfSchema { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = DfSchema; + type Value = CustomTableScanNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.DfSchema") + formatter.write_str("struct datafusion.CustomTableScanNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut columns__ = None; - let mut metadata__ = None; + let mut table_name__ = None; + let mut projection__ = None; + let mut schema__ = None; + let mut filters__ = None; + let mut custom_table_data__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Columns => { - if columns__.is_some() { - return Err(serde::de::Error::duplicate_field("columns")); + GeneratedField::TableName => { + if table_name__.is_some() { + return Err(serde::de::Error::duplicate_field("tableName")); } - columns__ = Some(map_.next_value()?); + table_name__ = map_.next_value()?; } - GeneratedField::Metadata => { - if metadata__.is_some() { - return Err(serde::de::Error::duplicate_field("metadata")); + GeneratedField::Projection => { + if projection__.is_some() { + return Err(serde::de::Error::duplicate_field("projection")); } - metadata__ = Some( - map_.next_value::>()? - ); + projection__ = map_.next_value()?; + } + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); + } + schema__ = map_.next_value()?; + } + GeneratedField::Filters => { + if filters__.is_some() { + return Err(serde::de::Error::duplicate_field("filters")); + } + filters__ = Some(map_.next_value()?); + } + GeneratedField::CustomTableData => { + if custom_table_data__.is_some() { + return Err(serde::de::Error::duplicate_field("customTableData")); + } + custom_table_data__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; } } } - Ok(DfSchema { - columns: columns__.unwrap_or_default(), - metadata: metadata__.unwrap_or_default(), + Ok(CustomTableScanNode { + table_name: table_name__, + projection: projection__, + schema: schema__, + filters: filters__.unwrap_or_default(), + custom_table_data: custom_table_data__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.DfSchema", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.CustomTableScanNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for Dictionary { +impl serde::Serialize for DateUnit { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::Day => "Day", + Self::DateMillisecond => "DateMillisecond", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for DateUnit { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "Day", + "DateMillisecond", + ]; + + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = DateUnit; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "Day" => Ok(DateUnit::Day), + "DateMillisecond" => Ok(DateUnit::DateMillisecond), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} +impl serde::Serialize for Decimal { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -5585,37 +5938,37 @@ impl serde::Serialize for Dictionary { { use serde::ser::SerializeStruct; let mut len = 0; - if self.key.is_some() { + if self.precision != 0 { len += 1; } - if self.value.is_some() { + if self.scale != 0 { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.Dictionary", len)?; - if let Some(v) = self.key.as_ref() { - struct_ser.serialize_field("key", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.Decimal", len)?; + if self.precision != 0 { + struct_ser.serialize_field("precision", &self.precision)?; } - if let Some(v) = self.value.as_ref() { - struct_ser.serialize_field("value", v)?; + if self.scale != 0 { + struct_ser.serialize_field("scale", &self.scale)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for Dictionary { +impl<'de> serde::Deserialize<'de> for Decimal { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "key", - "value", + "precision", + "scale", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Key, - Value, + Precision, + Scale, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -5637,8 +5990,8 @@ impl<'de> serde::Deserialize<'de> for Dictionary { E: serde::de::Error, { match value { - "key" => Ok(GeneratedField::Key), - "value" => Ok(GeneratedField::Value), + "precision" => Ok(GeneratedField::Precision), + "scale" => Ok(GeneratedField::Scale), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -5648,44 +6001,48 @@ impl<'de> serde::Deserialize<'de> for Dictionary { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = Dictionary; + type Value = Decimal; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.Dictionary") + formatter.write_str("struct datafusion.Decimal") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut key__ = None; - let mut value__ = None; + let mut precision__ = None; + let mut scale__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Key => { - if key__.is_some() { - return Err(serde::de::Error::duplicate_field("key")); + GeneratedField::Precision => { + if precision__.is_some() { + return Err(serde::de::Error::duplicate_field("precision")); } - key__ = map_.next_value()?; + precision__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; } - GeneratedField::Value => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("value")); + GeneratedField::Scale => { + if scale__.is_some() { + return Err(serde::de::Error::duplicate_field("scale")); } - value__ = map_.next_value()?; + scale__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; } } } - Ok(Dictionary { - key: key__, - value: value__, + Ok(Decimal { + precision: precision__.unwrap_or_default(), + scale: scale__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.Dictionary", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.Decimal", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for DistinctNode { +impl serde::Serialize for Decimal128 { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -5693,29 +6050,48 @@ impl serde::Serialize for DistinctNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.input.is_some() { + if !self.value.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.DistinctNode", len)?; - if let Some(v) = self.input.as_ref() { - struct_ser.serialize_field("input", v)?; + if self.p != 0 { + len += 1; + } + if self.s != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.Decimal128", len)?; + if !self.value.is_empty() { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("value", pbjson::private::base64::encode(&self.value).as_str())?; + } + if self.p != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("p", ToString::to_string(&self.p).as_str())?; + } + if self.s != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("s", ToString::to_string(&self.s).as_str())?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for DistinctNode { +impl<'de> serde::Deserialize<'de> for Decimal128 { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "input", + "value", + "p", + "s", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Input, + Value, + P, + S, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -5737,7 +6113,9 @@ impl<'de> serde::Deserialize<'de> for DistinctNode { E: serde::de::Error, { match value { - "input" => Ok(GeneratedField::Input), + "value" => Ok(GeneratedField::Value), + "p" => Ok(GeneratedField::P), + "s" => Ok(GeneratedField::S), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -5747,36 +6125,58 @@ impl<'de> serde::Deserialize<'de> for DistinctNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = DistinctNode; + type Value = Decimal128; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.DistinctNode") + formatter.write_str("struct datafusion.Decimal128") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut input__ = None; + let mut value__ = None; + let mut p__ = None; + let mut s__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Input => { - if input__.is_some() { - return Err(serde::de::Error::duplicate_field("input")); + GeneratedField::Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("value")); } - input__ = map_.next_value()?; + value__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } + GeneratedField::P => { + if p__.is_some() { + return Err(serde::de::Error::duplicate_field("p")); + } + p__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::S => { + if s__.is_some() { + return Err(serde::de::Error::duplicate_field("s")); + } + s__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; } } } - Ok(DistinctNode { - input: input__, + Ok(Decimal128 { + value: value__.unwrap_or_default(), + p: p__.unwrap_or_default(), + s: s__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.DistinctNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.Decimal128", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for DropViewNode { +impl serde::Serialize for Decimal256 { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -5784,46 +6184,48 @@ impl serde::Serialize for DropViewNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.name.is_some() { + if !self.value.is_empty() { len += 1; } - if self.if_exists { + if self.p != 0 { len += 1; } - if self.schema.is_some() { + if self.s != 0 { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.DropViewNode", len)?; - if let Some(v) = self.name.as_ref() { - struct_ser.serialize_field("name", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.Decimal256", len)?; + if !self.value.is_empty() { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("value", pbjson::private::base64::encode(&self.value).as_str())?; } - if self.if_exists { - struct_ser.serialize_field("ifExists", &self.if_exists)?; + if self.p != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("p", ToString::to_string(&self.p).as_str())?; } - if let Some(v) = self.schema.as_ref() { - struct_ser.serialize_field("schema", v)?; + if self.s != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("s", ToString::to_string(&self.s).as_str())?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for DropViewNode { +impl<'de> serde::Deserialize<'de> for Decimal256 { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "name", - "if_exists", - "ifExists", - "schema", + "value", + "p", + "s", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Name, - IfExists, - Schema, + Value, + P, + S, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -5845,9 +6247,9 @@ impl<'de> serde::Deserialize<'de> for DropViewNode { E: serde::de::Error, { match value { - "name" => Ok(GeneratedField::Name), - "ifExists" | "if_exists" => Ok(GeneratedField::IfExists), - "schema" => Ok(GeneratedField::Schema), + "value" => Ok(GeneratedField::Value), + "p" => Ok(GeneratedField::P), + "s" => Ok(GeneratedField::S), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -5857,52 +6259,58 @@ impl<'de> serde::Deserialize<'de> for DropViewNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = DropViewNode; + type Value = Decimal256; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.DropViewNode") + formatter.write_str("struct datafusion.Decimal256") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut name__ = None; - let mut if_exists__ = None; - let mut schema__ = None; + let mut value__ = None; + let mut p__ = None; + let mut s__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Name => { - if name__.is_some() { - return Err(serde::de::Error::duplicate_field("name")); + GeneratedField::Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("value")); } - name__ = map_.next_value()?; + value__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; } - GeneratedField::IfExists => { - if if_exists__.is_some() { - return Err(serde::de::Error::duplicate_field("ifExists")); + GeneratedField::P => { + if p__.is_some() { + return Err(serde::de::Error::duplicate_field("p")); } - if_exists__ = Some(map_.next_value()?); + p__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; } - GeneratedField::Schema => { - if schema__.is_some() { - return Err(serde::de::Error::duplicate_field("schema")); + GeneratedField::S => { + if s__.is_some() { + return Err(serde::de::Error::duplicate_field("s")); } - schema__ = map_.next_value()?; + s__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; } } } - Ok(DropViewNode { - name: name__, - if_exists: if_exists__.unwrap_or_default(), - schema: schema__, + Ok(Decimal256 { + value: value__.unwrap_or_default(), + p: p__.unwrap_or_default(), + s: s__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.DropViewNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.Decimal256", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for EmptyExecNode { +impl serde::Serialize for DfField { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -5910,38 +6318,37 @@ impl serde::Serialize for EmptyExecNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.produce_one_row { + if self.field.is_some() { len += 1; } - if self.schema.is_some() { + if self.qualifier.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.EmptyExecNode", len)?; - if self.produce_one_row { - struct_ser.serialize_field("produceOneRow", &self.produce_one_row)?; + let mut struct_ser = serializer.serialize_struct("datafusion.DfField", len)?; + if let Some(v) = self.field.as_ref() { + struct_ser.serialize_field("field", v)?; } - if let Some(v) = self.schema.as_ref() { - struct_ser.serialize_field("schema", v)?; + if let Some(v) = self.qualifier.as_ref() { + struct_ser.serialize_field("qualifier", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for EmptyExecNode { +impl<'de> serde::Deserialize<'de> for DfField { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "produce_one_row", - "produceOneRow", - "schema", + "field", + "qualifier", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - ProduceOneRow, - Schema, + Field, + Qualifier, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -5963,8 +6370,8 @@ impl<'de> serde::Deserialize<'de> for EmptyExecNode { E: serde::de::Error, { match value { - "produceOneRow" | "produce_one_row" => Ok(GeneratedField::ProduceOneRow), - "schema" => Ok(GeneratedField::Schema), + "field" => Ok(GeneratedField::Field), + "qualifier" => Ok(GeneratedField::Qualifier), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -5974,66 +6381,82 @@ impl<'de> serde::Deserialize<'de> for EmptyExecNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = EmptyExecNode; + type Value = DfField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.EmptyExecNode") + formatter.write_str("struct datafusion.DfField") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut produce_one_row__ = None; - let mut schema__ = None; + let mut field__ = None; + let mut qualifier__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::ProduceOneRow => { - if produce_one_row__.is_some() { - return Err(serde::de::Error::duplicate_field("produceOneRow")); + GeneratedField::Field => { + if field__.is_some() { + return Err(serde::de::Error::duplicate_field("field")); } - produce_one_row__ = Some(map_.next_value()?); + field__ = map_.next_value()?; } - GeneratedField::Schema => { - if schema__.is_some() { - return Err(serde::de::Error::duplicate_field("schema")); + GeneratedField::Qualifier => { + if qualifier__.is_some() { + return Err(serde::de::Error::duplicate_field("qualifier")); } - schema__ = map_.next_value()?; + qualifier__ = map_.next_value()?; } } } - Ok(EmptyExecNode { - produce_one_row: produce_one_row__.unwrap_or_default(), - schema: schema__, + Ok(DfField { + field: field__, + qualifier: qualifier__, }) } } - deserializer.deserialize_struct("datafusion.EmptyExecNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.DfField", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for EmptyMessage { +impl serde::Serialize for DfSchema { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where S: serde::Serializer, { use serde::ser::SerializeStruct; - let len = 0; - let struct_ser = serializer.serialize_struct("datafusion.EmptyMessage", len)?; - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for EmptyMessage { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ + let mut len = 0; + if !self.columns.is_empty() { + len += 1; + } + if !self.metadata.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.DfSchema", len)?; + if !self.columns.is_empty() { + struct_ser.serialize_field("columns", &self.columns)?; + } + if !self.metadata.is_empty() { + struct_ser.serialize_field("metadata", &self.metadata)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for DfSchema { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "columns", + "metadata", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { + Columns, + Metadata, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -6054,7 +6477,11 @@ impl<'de> serde::Deserialize<'de> for EmptyMessage { where E: serde::de::Error, { - Err(serde::de::Error::unknown_field(value, FIELDS)) + match value { + "columns" => Ok(GeneratedField::Columns), + "metadata" => Ok(GeneratedField::Metadata), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } } } deserializer.deserialize_identifier(GeneratedVisitor) @@ -6062,27 +6489,46 @@ impl<'de> serde::Deserialize<'de> for EmptyMessage { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = EmptyMessage; + type Value = DfSchema; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.EmptyMessage") + formatter.write_str("struct datafusion.DfSchema") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - while map_.next_key::()?.is_some() { - let _ = map_.next_value::()?; + let mut columns__ = None; + let mut metadata__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Columns => { + if columns__.is_some() { + return Err(serde::de::Error::duplicate_field("columns")); + } + columns__ = Some(map_.next_value()?); + } + GeneratedField::Metadata => { + if metadata__.is_some() { + return Err(serde::de::Error::duplicate_field("metadata")); + } + metadata__ = Some( + map_.next_value::>()? + ); + } + } } - Ok(EmptyMessage { + Ok(DfSchema { + columns: columns__.unwrap_or_default(), + metadata: metadata__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.EmptyMessage", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.DfSchema", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for EmptyRelationNode { +impl serde::Serialize for Dictionary { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -6090,30 +6536,37 @@ impl serde::Serialize for EmptyRelationNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.produce_one_row { + if self.key.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.EmptyRelationNode", len)?; - if self.produce_one_row { - struct_ser.serialize_field("produceOneRow", &self.produce_one_row)?; + if self.value.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.Dictionary", len)?; + if let Some(v) = self.key.as_ref() { + struct_ser.serialize_field("key", v)?; + } + if let Some(v) = self.value.as_ref() { + struct_ser.serialize_field("value", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for EmptyRelationNode { +impl<'de> serde::Deserialize<'de> for Dictionary { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "produce_one_row", - "produceOneRow", + "key", + "value", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - ProduceOneRow, + Key, + Value, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -6135,7 +6588,8 @@ impl<'de> serde::Deserialize<'de> for EmptyRelationNode { E: serde::de::Error, { match value { - "produceOneRow" | "produce_one_row" => Ok(GeneratedField::ProduceOneRow), + "key" => Ok(GeneratedField::Key), + "value" => Ok(GeneratedField::Value), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -6145,36 +6599,44 @@ impl<'de> serde::Deserialize<'de> for EmptyRelationNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = EmptyRelationNode; + type Value = Dictionary; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.EmptyRelationNode") + formatter.write_str("struct datafusion.Dictionary") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut produce_one_row__ = None; + let mut key__ = None; + let mut value__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::ProduceOneRow => { - if produce_one_row__.is_some() { - return Err(serde::de::Error::duplicate_field("produceOneRow")); + GeneratedField::Key => { + if key__.is_some() { + return Err(serde::de::Error::duplicate_field("key")); } - produce_one_row__ = Some(map_.next_value()?); + key__ = map_.next_value()?; + } + GeneratedField::Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("value")); + } + value__ = map_.next_value()?; } } } - Ok(EmptyRelationNode { - produce_one_row: produce_one_row__.unwrap_or_default(), + Ok(Dictionary { + key: key__, + value: value__, }) } } - deserializer.deserialize_struct("datafusion.EmptyRelationNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.Dictionary", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ExplainExecNode { +impl serde::Serialize for DistinctNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -6182,46 +6644,29 @@ impl serde::Serialize for ExplainExecNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.schema.is_some() { - len += 1; - } - if !self.stringified_plans.is_empty() { - len += 1; - } - if self.verbose { + if self.input.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.ExplainExecNode", len)?; - if let Some(v) = self.schema.as_ref() { - struct_ser.serialize_field("schema", v)?; - } - if !self.stringified_plans.is_empty() { - struct_ser.serialize_field("stringifiedPlans", &self.stringified_plans)?; - } - if self.verbose { - struct_ser.serialize_field("verbose", &self.verbose)?; + let mut struct_ser = serializer.serialize_struct("datafusion.DistinctNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for ExplainExecNode { +impl<'de> serde::Deserialize<'de> for DistinctNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "schema", - "stringified_plans", - "stringifiedPlans", - "verbose", + "input", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Schema, - StringifiedPlans, - Verbose, + Input, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -6243,9 +6688,7 @@ impl<'de> serde::Deserialize<'de> for ExplainExecNode { E: serde::de::Error, { match value { - "schema" => Ok(GeneratedField::Schema), - "stringifiedPlans" | "stringified_plans" => Ok(GeneratedField::StringifiedPlans), - "verbose" => Ok(GeneratedField::Verbose), + "input" => Ok(GeneratedField::Input), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -6255,52 +6698,36 @@ impl<'de> serde::Deserialize<'de> for ExplainExecNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ExplainExecNode; + type Value = DistinctNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ExplainExecNode") + formatter.write_str("struct datafusion.DistinctNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut schema__ = None; - let mut stringified_plans__ = None; - let mut verbose__ = None; + let mut input__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Schema => { - if schema__.is_some() { - return Err(serde::de::Error::duplicate_field("schema")); - } - schema__ = map_.next_value()?; - } - GeneratedField::StringifiedPlans => { - if stringified_plans__.is_some() { - return Err(serde::de::Error::duplicate_field("stringifiedPlans")); - } - stringified_plans__ = Some(map_.next_value()?); - } - GeneratedField::Verbose => { - if verbose__.is_some() { - return Err(serde::de::Error::duplicate_field("verbose")); + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); } - verbose__ = Some(map_.next_value()?); + input__ = map_.next_value()?; } } } - Ok(ExplainExecNode { - schema: schema__, - stringified_plans: stringified_plans__.unwrap_or_default(), - verbose: verbose__.unwrap_or_default(), + Ok(DistinctNode { + input: input__, }) } } - deserializer.deserialize_struct("datafusion.ExplainExecNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.DistinctNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ExplainNode { +impl serde::Serialize for DistinctOnNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -6308,37 +6735,56 @@ impl serde::Serialize for ExplainNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.input.is_some() { + if !self.on_expr.is_empty() { len += 1; } - if self.verbose { + if !self.select_expr.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.ExplainNode", len)?; + if !self.sort_expr.is_empty() { + len += 1; + } + if self.input.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.DistinctOnNode", len)?; + if !self.on_expr.is_empty() { + struct_ser.serialize_field("onExpr", &self.on_expr)?; + } + if !self.select_expr.is_empty() { + struct_ser.serialize_field("selectExpr", &self.select_expr)?; + } + if !self.sort_expr.is_empty() { + struct_ser.serialize_field("sortExpr", &self.sort_expr)?; + } if let Some(v) = self.input.as_ref() { struct_ser.serialize_field("input", v)?; } - if self.verbose { - struct_ser.serialize_field("verbose", &self.verbose)?; - } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for ExplainNode { +impl<'de> serde::Deserialize<'de> for DistinctOnNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ + "on_expr", + "onExpr", + "select_expr", + "selectExpr", + "sort_expr", + "sortExpr", "input", - "verbose", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { + OnExpr, + SelectExpr, + SortExpr, Input, - Verbose, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -6360,8 +6806,10 @@ impl<'de> serde::Deserialize<'de> for ExplainNode { E: serde::de::Error, { match value { + "onExpr" | "on_expr" => Ok(GeneratedField::OnExpr), + "selectExpr" | "select_expr" => Ok(GeneratedField::SelectExpr), + "sortExpr" | "sort_expr" => Ok(GeneratedField::SortExpr), "input" => Ok(GeneratedField::Input), - "verbose" => Ok(GeneratedField::Verbose), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -6371,44 +6819,60 @@ impl<'de> serde::Deserialize<'de> for ExplainNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ExplainNode; + type Value = DistinctOnNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ExplainNode") + formatter.write_str("struct datafusion.DistinctOnNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { + let mut on_expr__ = None; + let mut select_expr__ = None; + let mut sort_expr__ = None; let mut input__ = None; - let mut verbose__ = None; while let Some(k) = map_.next_key()? { match k { + GeneratedField::OnExpr => { + if on_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("onExpr")); + } + on_expr__ = Some(map_.next_value()?); + } + GeneratedField::SelectExpr => { + if select_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("selectExpr")); + } + select_expr__ = Some(map_.next_value()?); + } + GeneratedField::SortExpr => { + if sort_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("sortExpr")); + } + sort_expr__ = Some(map_.next_value()?); + } GeneratedField::Input => { if input__.is_some() { return Err(serde::de::Error::duplicate_field("input")); } input__ = map_.next_value()?; } - GeneratedField::Verbose => { - if verbose__.is_some() { - return Err(serde::de::Error::duplicate_field("verbose")); - } - verbose__ = Some(map_.next_value()?); - } } } - Ok(ExplainNode { + Ok(DistinctOnNode { + on_expr: on_expr__.unwrap_or_default(), + select_expr: select_expr__.unwrap_or_default(), + sort_expr: sort_expr__.unwrap_or_default(), input: input__, - verbose: verbose__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.ExplainNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.DistinctOnNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for Field { +impl serde::Serialize for DropViewNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -6416,41 +6880,29 @@ impl serde::Serialize for Field { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.name.is_empty() { + if self.name.is_some() { len += 1; } - if self.arrow_type.is_some() { - len += 1; - } - if self.nullable { - len += 1; - } - if !self.children.is_empty() { + if self.if_exists { len += 1; } - if !self.metadata.is_empty() { + if self.schema.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.Field", len)?; - if !self.name.is_empty() { - struct_ser.serialize_field("name", &self.name)?; - } - if let Some(v) = self.arrow_type.as_ref() { - struct_ser.serialize_field("arrowType", v)?; - } - if self.nullable { - struct_ser.serialize_field("nullable", &self.nullable)?; + let mut struct_ser = serializer.serialize_struct("datafusion.DropViewNode", len)?; + if let Some(v) = self.name.as_ref() { + struct_ser.serialize_field("name", v)?; } - if !self.children.is_empty() { - struct_ser.serialize_field("children", &self.children)?; + if self.if_exists { + struct_ser.serialize_field("ifExists", &self.if_exists)?; } - if !self.metadata.is_empty() { - struct_ser.serialize_field("metadata", &self.metadata)?; + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for Field { +impl<'de> serde::Deserialize<'de> for DropViewNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where @@ -6458,20 +6910,16 @@ impl<'de> serde::Deserialize<'de> for Field { { const FIELDS: &[&str] = &[ "name", - "arrow_type", - "arrowType", - "nullable", - "children", - "metadata", + "if_exists", + "ifExists", + "schema", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Name, - ArrowType, - Nullable, - Children, - Metadata, + IfExists, + Schema, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -6494,10 +6942,8 @@ impl<'de> serde::Deserialize<'de> for Field { { match value { "name" => Ok(GeneratedField::Name), - "arrowType" | "arrow_type" => Ok(GeneratedField::ArrowType), - "nullable" => Ok(GeneratedField::Nullable), - "children" => Ok(GeneratedField::Children), - "metadata" => Ok(GeneratedField::Metadata), + "ifExists" | "if_exists" => Ok(GeneratedField::IfExists), + "schema" => Ok(GeneratedField::Schema), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -6507,70 +6953,52 @@ impl<'de> serde::Deserialize<'de> for Field { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = Field; + type Value = DropViewNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.Field") + formatter.write_str("struct datafusion.DropViewNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut name__ = None; - let mut arrow_type__ = None; - let mut nullable__ = None; - let mut children__ = None; - let mut metadata__ = None; + let mut if_exists__ = None; + let mut schema__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Name => { if name__.is_some() { return Err(serde::de::Error::duplicate_field("name")); } - name__ = Some(map_.next_value()?); - } - GeneratedField::ArrowType => { - if arrow_type__.is_some() { - return Err(serde::de::Error::duplicate_field("arrowType")); - } - arrow_type__ = map_.next_value()?; - } - GeneratedField::Nullable => { - if nullable__.is_some() { - return Err(serde::de::Error::duplicate_field("nullable")); - } - nullable__ = Some(map_.next_value()?); + name__ = map_.next_value()?; } - GeneratedField::Children => { - if children__.is_some() { - return Err(serde::de::Error::duplicate_field("children")); + GeneratedField::IfExists => { + if if_exists__.is_some() { + return Err(serde::de::Error::duplicate_field("ifExists")); } - children__ = Some(map_.next_value()?); + if_exists__ = Some(map_.next_value()?); } - GeneratedField::Metadata => { - if metadata__.is_some() { - return Err(serde::de::Error::duplicate_field("metadata")); + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); } - metadata__ = Some( - map_.next_value::>()? - ); + schema__ = map_.next_value()?; } } } - Ok(Field { - name: name__.unwrap_or_default(), - arrow_type: arrow_type__, - nullable: nullable__.unwrap_or_default(), - children: children__.unwrap_or_default(), - metadata: metadata__.unwrap_or_default(), + Ok(DropViewNode { + name: name__, + if_exists: if_exists__.unwrap_or_default(), + schema: schema__, }) } } - deserializer.deserialize_struct("datafusion.Field", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.DropViewNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for FileGroup { +impl serde::Serialize for EmptyExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -6578,29 +7006,29 @@ impl serde::Serialize for FileGroup { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.files.is_empty() { + if self.schema.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.FileGroup", len)?; - if !self.files.is_empty() { - struct_ser.serialize_field("files", &self.files)?; + let mut struct_ser = serializer.serialize_struct("datafusion.EmptyExecNode", len)?; + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for FileGroup { +impl<'de> serde::Deserialize<'de> for EmptyExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "files", + "schema", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Files, + Schema, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -6622,7 +7050,7 @@ impl<'de> serde::Deserialize<'de> for FileGroup { E: serde::de::Error, { match value { - "files" => Ok(GeneratedField::Files), + "schema" => Ok(GeneratedField::Schema), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -6632,76 +7060,58 @@ impl<'de> serde::Deserialize<'de> for FileGroup { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = FileGroup; + type Value = EmptyExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.FileGroup") + formatter.write_str("struct datafusion.EmptyExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut files__ = None; + let mut schema__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Files => { - if files__.is_some() { - return Err(serde::de::Error::duplicate_field("files")); + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); } - files__ = Some(map_.next_value()?); + schema__ = map_.next_value()?; } } } - Ok(FileGroup { - files: files__.unwrap_or_default(), + Ok(EmptyExecNode { + schema: schema__, }) } } - deserializer.deserialize_struct("datafusion.FileGroup", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.EmptyExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for FileRange { +impl serde::Serialize for EmptyMessage { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where S: serde::Serializer, { use serde::ser::SerializeStruct; - let mut len = 0; - if self.start != 0 { - len += 1; - } - if self.end != 0 { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.FileRange", len)?; - if self.start != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("start", ToString::to_string(&self.start).as_str())?; - } - if self.end != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("end", ToString::to_string(&self.end).as_str())?; - } + let len = 0; + let struct_ser = serializer.serialize_struct("datafusion.EmptyMessage", len)?; struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for FileRange { +impl<'de> serde::Deserialize<'de> for EmptyMessage { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "start", - "end", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Start, - End, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -6722,11 +7132,7 @@ impl<'de> serde::Deserialize<'de> for FileRange { where E: serde::de::Error, { - match value { - "start" => Ok(GeneratedField::Start), - "end" => Ok(GeneratedField::End), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } + Err(serde::de::Error::unknown_field(value, FIELDS)) } } deserializer.deserialize_identifier(GeneratedVisitor) @@ -6734,48 +7140,27 @@ impl<'de> serde::Deserialize<'de> for FileRange { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = FileRange; + type Value = EmptyMessage; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.FileRange") + formatter.write_str("struct datafusion.EmptyMessage") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut start__ = None; - let mut end__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Start => { - if start__.is_some() { - return Err(serde::de::Error::duplicate_field("start")); - } - start__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - GeneratedField::End => { - if end__.is_some() { - return Err(serde::de::Error::duplicate_field("end")); - } - end__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - } + while map_.next_key::()?.is_some() { + let _ = map_.next_value::()?; } - Ok(FileRange { - start: start__.unwrap_or_default(), - end: end__.unwrap_or_default(), + Ok(EmptyMessage { }) } } - deserializer.deserialize_struct("datafusion.FileRange", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.EmptyMessage", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for FileScanExecConf { +impl serde::Serialize for EmptyRelationNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -6783,89 +7168,138 @@ impl serde::Serialize for FileScanExecConf { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.file_groups.is_empty() { - len += 1; - } - if self.schema.is_some() { - len += 1; - } - if !self.projection.is_empty() { - len += 1; - } - if self.limit.is_some() { - len += 1; - } - if self.statistics.is_some() { + if self.produce_one_row { len += 1; } - if !self.table_partition_cols.is_empty() { - len += 1; + let mut struct_ser = serializer.serialize_struct("datafusion.EmptyRelationNode", len)?; + if self.produce_one_row { + struct_ser.serialize_field("produceOneRow", &self.produce_one_row)?; } - if !self.object_store_url.is_empty() { - len += 1; + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for EmptyRelationNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "produce_one_row", + "produceOneRow", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + ProduceOneRow, } - if !self.output_ordering.is_empty() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.FileScanExecConf", len)?; - if !self.file_groups.is_empty() { - struct_ser.serialize_field("fileGroups", &self.file_groups)?; + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "produceOneRow" | "produce_one_row" => Ok(GeneratedField::ProduceOneRow), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } } - if let Some(v) = self.schema.as_ref() { - struct_ser.serialize_field("schema", v)?; + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = EmptyRelationNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.EmptyRelationNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut produce_one_row__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::ProduceOneRow => { + if produce_one_row__.is_some() { + return Err(serde::de::Error::duplicate_field("produceOneRow")); + } + produce_one_row__ = Some(map_.next_value()?); + } + } + } + Ok(EmptyRelationNode { + produce_one_row: produce_one_row__.unwrap_or_default(), + }) + } } - if !self.projection.is_empty() { - struct_ser.serialize_field("projection", &self.projection)?; + deserializer.deserialize_struct("datafusion.EmptyRelationNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ExplainExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.schema.is_some() { + len += 1; } - if let Some(v) = self.limit.as_ref() { - struct_ser.serialize_field("limit", v)?; + if !self.stringified_plans.is_empty() { + len += 1; } - if let Some(v) = self.statistics.as_ref() { - struct_ser.serialize_field("statistics", v)?; + if self.verbose { + len += 1; } - if !self.table_partition_cols.is_empty() { - struct_ser.serialize_field("tablePartitionCols", &self.table_partition_cols)?; + let mut struct_ser = serializer.serialize_struct("datafusion.ExplainExecNode", len)?; + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; } - if !self.object_store_url.is_empty() { - struct_ser.serialize_field("objectStoreUrl", &self.object_store_url)?; + if !self.stringified_plans.is_empty() { + struct_ser.serialize_field("stringifiedPlans", &self.stringified_plans)?; } - if !self.output_ordering.is_empty() { - struct_ser.serialize_field("outputOrdering", &self.output_ordering)?; + if self.verbose { + struct_ser.serialize_field("verbose", &self.verbose)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for FileScanExecConf { +impl<'de> serde::Deserialize<'de> for ExplainExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "file_groups", - "fileGroups", "schema", - "projection", - "limit", - "statistics", - "table_partition_cols", - "tablePartitionCols", - "object_store_url", - "objectStoreUrl", - "output_ordering", - "outputOrdering", + "stringified_plans", + "stringifiedPlans", + "verbose", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - FileGroups, Schema, - Projection, - Limit, - Statistics, - TablePartitionCols, - ObjectStoreUrl, - OutputOrdering, + StringifiedPlans, + Verbose, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -6887,14 +7321,9 @@ impl<'de> serde::Deserialize<'de> for FileScanExecConf { E: serde::de::Error, { match value { - "fileGroups" | "file_groups" => Ok(GeneratedField::FileGroups), "schema" => Ok(GeneratedField::Schema), - "projection" => Ok(GeneratedField::Projection), - "limit" => Ok(GeneratedField::Limit), - "statistics" => Ok(GeneratedField::Statistics), - "tablePartitionCols" | "table_partition_cols" => Ok(GeneratedField::TablePartitionCols), - "objectStoreUrl" | "object_store_url" => Ok(GeneratedField::ObjectStoreUrl), - "outputOrdering" | "output_ordering" => Ok(GeneratedField::OutputOrdering), + "stringifiedPlans" | "stringified_plans" => Ok(GeneratedField::StringifiedPlans), + "verbose" => Ok(GeneratedField::Verbose), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -6904,95 +7333,52 @@ impl<'de> serde::Deserialize<'de> for FileScanExecConf { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = FileScanExecConf; + type Value = ExplainExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.FileScanExecConf") + formatter.write_str("struct datafusion.ExplainExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut file_groups__ = None; let mut schema__ = None; - let mut projection__ = None; - let mut limit__ = None; - let mut statistics__ = None; - let mut table_partition_cols__ = None; - let mut object_store_url__ = None; - let mut output_ordering__ = None; + let mut stringified_plans__ = None; + let mut verbose__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::FileGroups => { - if file_groups__.is_some() { - return Err(serde::de::Error::duplicate_field("fileGroups")); - } - file_groups__ = Some(map_.next_value()?); - } GeneratedField::Schema => { if schema__.is_some() { return Err(serde::de::Error::duplicate_field("schema")); } schema__ = map_.next_value()?; } - GeneratedField::Projection => { - if projection__.is_some() { - return Err(serde::de::Error::duplicate_field("projection")); - } - projection__ = - Some(map_.next_value::>>()? - .into_iter().map(|x| x.0).collect()) - ; - } - GeneratedField::Limit => { - if limit__.is_some() { - return Err(serde::de::Error::duplicate_field("limit")); - } - limit__ = map_.next_value()?; - } - GeneratedField::Statistics => { - if statistics__.is_some() { - return Err(serde::de::Error::duplicate_field("statistics")); - } - statistics__ = map_.next_value()?; - } - GeneratedField::TablePartitionCols => { - if table_partition_cols__.is_some() { - return Err(serde::de::Error::duplicate_field("tablePartitionCols")); - } - table_partition_cols__ = Some(map_.next_value()?); - } - GeneratedField::ObjectStoreUrl => { - if object_store_url__.is_some() { - return Err(serde::de::Error::duplicate_field("objectStoreUrl")); + GeneratedField::StringifiedPlans => { + if stringified_plans__.is_some() { + return Err(serde::de::Error::duplicate_field("stringifiedPlans")); } - object_store_url__ = Some(map_.next_value()?); + stringified_plans__ = Some(map_.next_value()?); } - GeneratedField::OutputOrdering => { - if output_ordering__.is_some() { - return Err(serde::de::Error::duplicate_field("outputOrdering")); + GeneratedField::Verbose => { + if verbose__.is_some() { + return Err(serde::de::Error::duplicate_field("verbose")); } - output_ordering__ = Some(map_.next_value()?); + verbose__ = Some(map_.next_value()?); } } } - Ok(FileScanExecConf { - file_groups: file_groups__.unwrap_or_default(), + Ok(ExplainExecNode { schema: schema__, - projection: projection__.unwrap_or_default(), - limit: limit__, - statistics: statistics__, - table_partition_cols: table_partition_cols__.unwrap_or_default(), - object_store_url: object_store_url__.unwrap_or_default(), - output_ordering: output_ordering__.unwrap_or_default(), + stringified_plans: stringified_plans__.unwrap_or_default(), + verbose: verbose__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.FileScanExecConf", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.ExplainExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for FilterExecNode { +impl serde::Serialize for ExplainNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -7003,20 +7389,20 @@ impl serde::Serialize for FilterExecNode { if self.input.is_some() { len += 1; } - if self.expr.is_some() { + if self.verbose { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.FilterExecNode", len)?; + let mut struct_ser = serializer.serialize_struct("datafusion.ExplainNode", len)?; if let Some(v) = self.input.as_ref() { struct_ser.serialize_field("input", v)?; } - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + if self.verbose { + struct_ser.serialize_field("verbose", &self.verbose)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for FilterExecNode { +impl<'de> serde::Deserialize<'de> for ExplainNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where @@ -7024,13 +7410,13 @@ impl<'de> serde::Deserialize<'de> for FilterExecNode { { const FIELDS: &[&str] = &[ "input", - "expr", + "verbose", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Input, - Expr, + Verbose, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -7053,7 +7439,7 @@ impl<'de> serde::Deserialize<'de> for FilterExecNode { { match value { "input" => Ok(GeneratedField::Input), - "expr" => Ok(GeneratedField::Expr), + "verbose" => Ok(GeneratedField::Verbose), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -7063,18 +7449,18 @@ impl<'de> serde::Deserialize<'de> for FilterExecNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = FilterExecNode; + type Value = ExplainNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.FilterExecNode") + formatter.write_str("struct datafusion.ExplainNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut input__ = None; - let mut expr__ = None; + let mut verbose__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Input => { @@ -7083,24 +7469,24 @@ impl<'de> serde::Deserialize<'de> for FilterExecNode { } input__ = map_.next_value()?; } - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::Verbose => { + if verbose__.is_some() { + return Err(serde::de::Error::duplicate_field("verbose")); } - expr__ = map_.next_value()?; + verbose__ = Some(map_.next_value()?); } } } - Ok(FilterExecNode { + Ok(ExplainNode { input: input__, - expr: expr__, + verbose: verbose__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.FilterExecNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.ExplainNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for FixedSizeBinary { +impl serde::Serialize for Field { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -7108,29 +7494,81 @@ impl serde::Serialize for FixedSizeBinary { { use serde::ser::SerializeStruct; let mut len = 0; - if self.length != 0 { + if !self.name.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.FixedSizeBinary", len)?; - if self.length != 0 { - struct_ser.serialize_field("length", &self.length)?; + if self.arrow_type.is_some() { + len += 1; + } + if self.nullable { + len += 1; + } + if !self.children.is_empty() { + len += 1; + } + if !self.metadata.is_empty() { + len += 1; + } + if self.dict_id != 0 { + len += 1; + } + if self.dict_ordered { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.Field", len)?; + if !self.name.is_empty() { + struct_ser.serialize_field("name", &self.name)?; + } + if let Some(v) = self.arrow_type.as_ref() { + struct_ser.serialize_field("arrowType", v)?; + } + if self.nullable { + struct_ser.serialize_field("nullable", &self.nullable)?; + } + if !self.children.is_empty() { + struct_ser.serialize_field("children", &self.children)?; + } + if !self.metadata.is_empty() { + struct_ser.serialize_field("metadata", &self.metadata)?; + } + if self.dict_id != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("dictId", ToString::to_string(&self.dict_id).as_str())?; + } + if self.dict_ordered { + struct_ser.serialize_field("dictOrdered", &self.dict_ordered)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for FixedSizeBinary { +impl<'de> serde::Deserialize<'de> for Field { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "length", + "name", + "arrow_type", + "arrowType", + "nullable", + "children", + "metadata", + "dict_id", + "dictId", + "dict_ordered", + "dictOrdered", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Length, + Name, + ArrowType, + Nullable, + Children, + Metadata, + DictId, + DictOrdered, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -7152,7 +7590,13 @@ impl<'de> serde::Deserialize<'de> for FixedSizeBinary { E: serde::de::Error, { match value { - "length" => Ok(GeneratedField::Length), + "name" => Ok(GeneratedField::Name), + "arrowType" | "arrow_type" => Ok(GeneratedField::ArrowType), + "nullable" => Ok(GeneratedField::Nullable), + "children" => Ok(GeneratedField::Children), + "metadata" => Ok(GeneratedField::Metadata), + "dictId" | "dict_id" => Ok(GeneratedField::DictId), + "dictOrdered" | "dict_ordered" => Ok(GeneratedField::DictOrdered), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -7162,38 +7606,88 @@ impl<'de> serde::Deserialize<'de> for FixedSizeBinary { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = FixedSizeBinary; + type Value = Field; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.FixedSizeBinary") + formatter.write_str("struct datafusion.Field") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut length__ = None; + let mut name__ = None; + let mut arrow_type__ = None; + let mut nullable__ = None; + let mut children__ = None; + let mut metadata__ = None; + let mut dict_id__ = None; + let mut dict_ordered__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Length => { - if length__.is_some() { - return Err(serde::de::Error::duplicate_field("length")); + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); } - length__ = + name__ = Some(map_.next_value()?); + } + GeneratedField::ArrowType => { + if arrow_type__.is_some() { + return Err(serde::de::Error::duplicate_field("arrowType")); + } + arrow_type__ = map_.next_value()?; + } + GeneratedField::Nullable => { + if nullable__.is_some() { + return Err(serde::de::Error::duplicate_field("nullable")); + } + nullable__ = Some(map_.next_value()?); + } + GeneratedField::Children => { + if children__.is_some() { + return Err(serde::de::Error::duplicate_field("children")); + } + children__ = Some(map_.next_value()?); + } + GeneratedField::Metadata => { + if metadata__.is_some() { + return Err(serde::de::Error::duplicate_field("metadata")); + } + metadata__ = Some( + map_.next_value::>()? + ); + } + GeneratedField::DictId => { + if dict_id__.is_some() { + return Err(serde::de::Error::duplicate_field("dictId")); + } + dict_id__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } + GeneratedField::DictOrdered => { + if dict_ordered__.is_some() { + return Err(serde::de::Error::duplicate_field("dictOrdered")); + } + dict_ordered__ = Some(map_.next_value()?); + } } } - Ok(FixedSizeBinary { - length: length__.unwrap_or_default(), + Ok(Field { + name: name__.unwrap_or_default(), + arrow_type: arrow_type__, + nullable: nullable__.unwrap_or_default(), + children: children__.unwrap_or_default(), + metadata: metadata__.unwrap_or_default(), + dict_id: dict_id__.unwrap_or_default(), + dict_ordered: dict_ordered__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.FixedSizeBinary", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.Field", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for FixedSizeList { +impl serde::Serialize for FileGroup { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -7201,39 +7695,29 @@ impl serde::Serialize for FixedSizeList { { use serde::ser::SerializeStruct; let mut len = 0; - if self.field_type.is_some() { - len += 1; - } - if self.list_size != 0 { + if !self.files.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.FixedSizeList", len)?; - if let Some(v) = self.field_type.as_ref() { - struct_ser.serialize_field("fieldType", v)?; - } - if self.list_size != 0 { - struct_ser.serialize_field("listSize", &self.list_size)?; + let mut struct_ser = serializer.serialize_struct("datafusion.FileGroup", len)?; + if !self.files.is_empty() { + struct_ser.serialize_field("files", &self.files)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for FixedSizeList { +impl<'de> serde::Deserialize<'de> for FileGroup { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "field_type", - "fieldType", - "list_size", - "listSize", + "files", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - FieldType, - ListSize, + Files, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -7255,8 +7739,7 @@ impl<'de> serde::Deserialize<'de> for FixedSizeList { E: serde::de::Error, { match value { - "fieldType" | "field_type" => Ok(GeneratedField::FieldType), - "listSize" | "list_size" => Ok(GeneratedField::ListSize), + "files" => Ok(GeneratedField::Files), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -7266,46 +7749,36 @@ impl<'de> serde::Deserialize<'de> for FixedSizeList { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = FixedSizeList; + type Value = FileGroup; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.FixedSizeList") + formatter.write_str("struct datafusion.FileGroup") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut field_type__ = None; - let mut list_size__ = None; + let mut files__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::FieldType => { - if field_type__.is_some() { - return Err(serde::de::Error::duplicate_field("fieldType")); - } - field_type__ = map_.next_value()?; - } - GeneratedField::ListSize => { - if list_size__.is_some() { - return Err(serde::de::Error::duplicate_field("listSize")); + GeneratedField::Files => { + if files__.is_some() { + return Err(serde::de::Error::duplicate_field("files")); } - list_size__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + files__ = Some(map_.next_value()?); } } } - Ok(FixedSizeList { - field_type: field_type__, - list_size: list_size__.unwrap_or_default(), + Ok(FileGroup { + files: files__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.FixedSizeList", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.FileGroup", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for FullTableReference { +impl serde::Serialize for FileRange { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -7313,45 +7786,39 @@ impl serde::Serialize for FullTableReference { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.catalog.is_empty() { - len += 1; - } - if !self.schema.is_empty() { + if self.start != 0 { len += 1; } - if !self.table.is_empty() { + if self.end != 0 { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.FullTableReference", len)?; - if !self.catalog.is_empty() { - struct_ser.serialize_field("catalog", &self.catalog)?; - } - if !self.schema.is_empty() { - struct_ser.serialize_field("schema", &self.schema)?; + let mut struct_ser = serializer.serialize_struct("datafusion.FileRange", len)?; + if self.start != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("start", ToString::to_string(&self.start).as_str())?; } - if !self.table.is_empty() { - struct_ser.serialize_field("table", &self.table)?; + if self.end != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("end", ToString::to_string(&self.end).as_str())?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for FullTableReference { +impl<'de> serde::Deserialize<'de> for FileRange { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "catalog", - "schema", - "table", + "start", + "end", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Catalog, - Schema, - Table, + Start, + End, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -7373,9 +7840,8 @@ impl<'de> serde::Deserialize<'de> for FullTableReference { E: serde::de::Error, { match value { - "catalog" => Ok(GeneratedField::Catalog), - "schema" => Ok(GeneratedField::Schema), - "table" => Ok(GeneratedField::Table), + "start" => Ok(GeneratedField::Start), + "end" => Ok(GeneratedField::End), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -7385,52 +7851,48 @@ impl<'de> serde::Deserialize<'de> for FullTableReference { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = FullTableReference; + type Value = FileRange; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.FullTableReference") + formatter.write_str("struct datafusion.FileRange") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut catalog__ = None; - let mut schema__ = None; - let mut table__ = None; + let mut start__ = None; + let mut end__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Catalog => { - if catalog__.is_some() { - return Err(serde::de::Error::duplicate_field("catalog")); - } - catalog__ = Some(map_.next_value()?); - } - GeneratedField::Schema => { - if schema__.is_some() { - return Err(serde::de::Error::duplicate_field("schema")); + GeneratedField::Start => { + if start__.is_some() { + return Err(serde::de::Error::duplicate_field("start")); } - schema__ = Some(map_.next_value()?); + start__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; } - GeneratedField::Table => { - if table__.is_some() { - return Err(serde::de::Error::duplicate_field("table")); + GeneratedField::End => { + if end__.is_some() { + return Err(serde::de::Error::duplicate_field("end")); } - table__ = Some(map_.next_value()?); + end__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; } } } - Ok(FullTableReference { - catalog: catalog__.unwrap_or_default(), - schema: schema__.unwrap_or_default(), - table: table__.unwrap_or_default(), + Ok(FileRange { + start: start__.unwrap_or_default(), + end: end__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.FullTableReference", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.FileRange", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for GetIndexedField { +impl serde::Serialize for FileScanExecConf { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -7438,54 +7900,89 @@ impl serde::Serialize for GetIndexedField { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { + if !self.file_groups.is_empty() { len += 1; } - if self.field.is_some() { + if self.schema.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.GetIndexedField", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + if !self.projection.is_empty() { + len += 1; } - if let Some(v) = self.field.as_ref() { - match v { - get_indexed_field::Field::NamedStructField(v) => { - struct_ser.serialize_field("namedStructField", v)?; - } - get_indexed_field::Field::ListIndex(v) => { - struct_ser.serialize_field("listIndex", v)?; - } - get_indexed_field::Field::ListRange(v) => { - struct_ser.serialize_field("listRange", v)?; - } - } + if self.limit.is_some() { + len += 1; + } + if self.statistics.is_some() { + len += 1; + } + if !self.table_partition_cols.is_empty() { + len += 1; + } + if !self.object_store_url.is_empty() { + len += 1; + } + if !self.output_ordering.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.FileScanExecConf", len)?; + if !self.file_groups.is_empty() { + struct_ser.serialize_field("fileGroups", &self.file_groups)?; + } + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; + } + if !self.projection.is_empty() { + struct_ser.serialize_field("projection", &self.projection)?; + } + if let Some(v) = self.limit.as_ref() { + struct_ser.serialize_field("limit", v)?; + } + if let Some(v) = self.statistics.as_ref() { + struct_ser.serialize_field("statistics", v)?; + } + if !self.table_partition_cols.is_empty() { + struct_ser.serialize_field("tablePartitionCols", &self.table_partition_cols)?; + } + if !self.object_store_url.is_empty() { + struct_ser.serialize_field("objectStoreUrl", &self.object_store_url)?; + } + if !self.output_ordering.is_empty() { + struct_ser.serialize_field("outputOrdering", &self.output_ordering)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for GetIndexedField { +impl<'de> serde::Deserialize<'de> for FileScanExecConf { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", - "named_struct_field", - "namedStructField", - "list_index", - "listIndex", - "list_range", - "listRange", + "file_groups", + "fileGroups", + "schema", + "projection", + "limit", + "statistics", + "table_partition_cols", + "tablePartitionCols", + "object_store_url", + "objectStoreUrl", + "output_ordering", + "outputOrdering", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, - NamedStructField, - ListIndex, - ListRange, + FileGroups, + Schema, + Projection, + Limit, + Statistics, + TablePartitionCols, + ObjectStoreUrl, + OutputOrdering, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -7507,10 +8004,14 @@ impl<'de> serde::Deserialize<'de> for GetIndexedField { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), - "namedStructField" | "named_struct_field" => Ok(GeneratedField::NamedStructField), - "listIndex" | "list_index" => Ok(GeneratedField::ListIndex), - "listRange" | "list_range" => Ok(GeneratedField::ListRange), + "fileGroups" | "file_groups" => Ok(GeneratedField::FileGroups), + "schema" => Ok(GeneratedField::Schema), + "projection" => Ok(GeneratedField::Projection), + "limit" => Ok(GeneratedField::Limit), + "statistics" => Ok(GeneratedField::Statistics), + "tablePartitionCols" | "table_partition_cols" => Ok(GeneratedField::TablePartitionCols), + "objectStoreUrl" | "object_store_url" => Ok(GeneratedField::ObjectStoreUrl), + "outputOrdering" | "output_ordering" => Ok(GeneratedField::OutputOrdering), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -7520,59 +8021,95 @@ impl<'de> serde::Deserialize<'de> for GetIndexedField { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GetIndexedField; + type Value = FileScanExecConf; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.GetIndexedField") + formatter.write_str("struct datafusion.FileScanExecConf") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; - let mut field__ = None; + let mut file_groups__ = None; + let mut schema__ = None; + let mut projection__ = None; + let mut limit__ = None; + let mut statistics__ = None; + let mut table_partition_cols__ = None; + let mut object_store_url__ = None; + let mut output_ordering__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::FileGroups => { + if file_groups__.is_some() { + return Err(serde::de::Error::duplicate_field("fileGroups")); } - expr__ = map_.next_value()?; + file_groups__ = Some(map_.next_value()?); } - GeneratedField::NamedStructField => { - if field__.is_some() { - return Err(serde::de::Error::duplicate_field("namedStructField")); + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); } - field__ = map_.next_value::<::std::option::Option<_>>()?.map(get_indexed_field::Field::NamedStructField) -; + schema__ = map_.next_value()?; } - GeneratedField::ListIndex => { - if field__.is_some() { - return Err(serde::de::Error::duplicate_field("listIndex")); + GeneratedField::Projection => { + if projection__.is_some() { + return Err(serde::de::Error::duplicate_field("projection")); } - field__ = map_.next_value::<::std::option::Option<_>>()?.map(get_indexed_field::Field::ListIndex) -; + projection__ = + Some(map_.next_value::>>()? + .into_iter().map(|x| x.0).collect()) + ; } - GeneratedField::ListRange => { - if field__.is_some() { - return Err(serde::de::Error::duplicate_field("listRange")); + GeneratedField::Limit => { + if limit__.is_some() { + return Err(serde::de::Error::duplicate_field("limit")); } - field__ = map_.next_value::<::std::option::Option<_>>()?.map(get_indexed_field::Field::ListRange) -; + limit__ = map_.next_value()?; + } + GeneratedField::Statistics => { + if statistics__.is_some() { + return Err(serde::de::Error::duplicate_field("statistics")); + } + statistics__ = map_.next_value()?; + } + GeneratedField::TablePartitionCols => { + if table_partition_cols__.is_some() { + return Err(serde::de::Error::duplicate_field("tablePartitionCols")); + } + table_partition_cols__ = Some(map_.next_value()?); + } + GeneratedField::ObjectStoreUrl => { + if object_store_url__.is_some() { + return Err(serde::de::Error::duplicate_field("objectStoreUrl")); + } + object_store_url__ = Some(map_.next_value()?); + } + GeneratedField::OutputOrdering => { + if output_ordering__.is_some() { + return Err(serde::de::Error::duplicate_field("outputOrdering")); + } + output_ordering__ = Some(map_.next_value()?); } } } - Ok(GetIndexedField { - expr: expr__, - field: field__, + Ok(FileScanExecConf { + file_groups: file_groups__.unwrap_or_default(), + schema: schema__, + projection: projection__.unwrap_or_default(), + limit: limit__, + statistics: statistics__, + table_partition_cols: table_partition_cols__.unwrap_or_default(), + object_store_url: object_store_url__.unwrap_or_default(), + output_ordering: output_ordering__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.GetIndexedField", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.FileScanExecConf", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for GlobalLimitExecNode { +impl serde::Serialize for FileSinkConfig { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -7580,46 +8117,92 @@ impl serde::Serialize for GlobalLimitExecNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.input.is_some() { + if !self.object_store_url.is_empty() { len += 1; } - if self.skip != 0 { + if !self.file_groups.is_empty() { len += 1; } - if self.fetch != 0 { + if !self.table_paths.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.GlobalLimitExecNode", len)?; - if let Some(v) = self.input.as_ref() { - struct_ser.serialize_field("input", v)?; + if self.output_schema.is_some() { + len += 1; } - if self.skip != 0 { - struct_ser.serialize_field("skip", &self.skip)?; + if !self.table_partition_cols.is_empty() { + len += 1; } - if self.fetch != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("fetch", ToString::to_string(&self.fetch).as_str())?; + if self.single_file_output { + len += 1; + } + if self.overwrite { + len += 1; + } + if self.file_type_writer_options.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.FileSinkConfig", len)?; + if !self.object_store_url.is_empty() { + struct_ser.serialize_field("objectStoreUrl", &self.object_store_url)?; + } + if !self.file_groups.is_empty() { + struct_ser.serialize_field("fileGroups", &self.file_groups)?; + } + if !self.table_paths.is_empty() { + struct_ser.serialize_field("tablePaths", &self.table_paths)?; + } + if let Some(v) = self.output_schema.as_ref() { + struct_ser.serialize_field("outputSchema", v)?; + } + if !self.table_partition_cols.is_empty() { + struct_ser.serialize_field("tablePartitionCols", &self.table_partition_cols)?; + } + if self.single_file_output { + struct_ser.serialize_field("singleFileOutput", &self.single_file_output)?; + } + if self.overwrite { + struct_ser.serialize_field("overwrite", &self.overwrite)?; + } + if let Some(v) = self.file_type_writer_options.as_ref() { + struct_ser.serialize_field("fileTypeWriterOptions", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for GlobalLimitExecNode { +impl<'de> serde::Deserialize<'de> for FileSinkConfig { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "input", - "skip", - "fetch", + "object_store_url", + "objectStoreUrl", + "file_groups", + "fileGroups", + "table_paths", + "tablePaths", + "output_schema", + "outputSchema", + "table_partition_cols", + "tablePartitionCols", + "single_file_output", + "singleFileOutput", + "overwrite", + "file_type_writer_options", + "fileTypeWriterOptions", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Input, - Skip, - Fetch, + ObjectStoreUrl, + FileGroups, + TablePaths, + OutputSchema, + TablePartitionCols, + SingleFileOutput, + Overwrite, + FileTypeWriterOptions, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -7641,9 +8224,14 @@ impl<'de> serde::Deserialize<'de> for GlobalLimitExecNode { E: serde::de::Error, { match value { - "input" => Ok(GeneratedField::Input), - "skip" => Ok(GeneratedField::Skip), - "fetch" => Ok(GeneratedField::Fetch), + "objectStoreUrl" | "object_store_url" => Ok(GeneratedField::ObjectStoreUrl), + "fileGroups" | "file_groups" => Ok(GeneratedField::FileGroups), + "tablePaths" | "table_paths" => Ok(GeneratedField::TablePaths), + "outputSchema" | "output_schema" => Ok(GeneratedField::OutputSchema), + "tablePartitionCols" | "table_partition_cols" => Ok(GeneratedField::TablePartitionCols), + "singleFileOutput" | "single_file_output" => Ok(GeneratedField::SingleFileOutput), + "overwrite" => Ok(GeneratedField::Overwrite), + "fileTypeWriterOptions" | "file_type_writer_options" => Ok(GeneratedField::FileTypeWriterOptions), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -7653,56 +8241,92 @@ impl<'de> serde::Deserialize<'de> for GlobalLimitExecNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GlobalLimitExecNode; + type Value = FileSinkConfig; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.GlobalLimitExecNode") + formatter.write_str("struct datafusion.FileSinkConfig") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut input__ = None; - let mut skip__ = None; - let mut fetch__ = None; + let mut object_store_url__ = None; + let mut file_groups__ = None; + let mut table_paths__ = None; + let mut output_schema__ = None; + let mut table_partition_cols__ = None; + let mut single_file_output__ = None; + let mut overwrite__ = None; + let mut file_type_writer_options__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Input => { - if input__.is_some() { - return Err(serde::de::Error::duplicate_field("input")); + GeneratedField::ObjectStoreUrl => { + if object_store_url__.is_some() { + return Err(serde::de::Error::duplicate_field("objectStoreUrl")); } - input__ = map_.next_value()?; + object_store_url__ = Some(map_.next_value()?); } - GeneratedField::Skip => { - if skip__.is_some() { - return Err(serde::de::Error::duplicate_field("skip")); + GeneratedField::FileGroups => { + if file_groups__.is_some() { + return Err(serde::de::Error::duplicate_field("fileGroups")); } - skip__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + file_groups__ = Some(map_.next_value()?); } - GeneratedField::Fetch => { - if fetch__.is_some() { - return Err(serde::de::Error::duplicate_field("fetch")); + GeneratedField::TablePaths => { + if table_paths__.is_some() { + return Err(serde::de::Error::duplicate_field("tablePaths")); } - fetch__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + table_paths__ = Some(map_.next_value()?); + } + GeneratedField::OutputSchema => { + if output_schema__.is_some() { + return Err(serde::de::Error::duplicate_field("outputSchema")); + } + output_schema__ = map_.next_value()?; + } + GeneratedField::TablePartitionCols => { + if table_partition_cols__.is_some() { + return Err(serde::de::Error::duplicate_field("tablePartitionCols")); + } + table_partition_cols__ = Some(map_.next_value()?); + } + GeneratedField::SingleFileOutput => { + if single_file_output__.is_some() { + return Err(serde::de::Error::duplicate_field("singleFileOutput")); + } + single_file_output__ = Some(map_.next_value()?); + } + GeneratedField::Overwrite => { + if overwrite__.is_some() { + return Err(serde::de::Error::duplicate_field("overwrite")); + } + overwrite__ = Some(map_.next_value()?); + } + GeneratedField::FileTypeWriterOptions => { + if file_type_writer_options__.is_some() { + return Err(serde::de::Error::duplicate_field("fileTypeWriterOptions")); + } + file_type_writer_options__ = map_.next_value()?; } } } - Ok(GlobalLimitExecNode { - input: input__, - skip: skip__.unwrap_or_default(), - fetch: fetch__.unwrap_or_default(), + Ok(FileSinkConfig { + object_store_url: object_store_url__.unwrap_or_default(), + file_groups: file_groups__.unwrap_or_default(), + table_paths: table_paths__.unwrap_or_default(), + output_schema: output_schema__, + table_partition_cols: table_partition_cols__.unwrap_or_default(), + single_file_output: single_file_output__.unwrap_or_default(), + overwrite: overwrite__.unwrap_or_default(), + file_type_writer_options: file_type_writer_options__, }) } } - deserializer.deserialize_struct("datafusion.GlobalLimitExecNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.FileSinkConfig", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for GroupingSetNode { +impl serde::Serialize for FileTypeWriterOptions { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -7710,29 +8334,46 @@ impl serde::Serialize for GroupingSetNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.expr.is_empty() { + if self.file_type.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.GroupingSetNode", len)?; - if !self.expr.is_empty() { - struct_ser.serialize_field("expr", &self.expr)?; + let mut struct_ser = serializer.serialize_struct("datafusion.FileTypeWriterOptions", len)?; + if let Some(v) = self.file_type.as_ref() { + match v { + file_type_writer_options::FileType::JsonOptions(v) => { + struct_ser.serialize_field("jsonOptions", v)?; + } + file_type_writer_options::FileType::ParquetOptions(v) => { + struct_ser.serialize_field("parquetOptions", v)?; + } + file_type_writer_options::FileType::CsvOptions(v) => { + struct_ser.serialize_field("csvOptions", v)?; + } + } } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for GroupingSetNode { +impl<'de> serde::Deserialize<'de> for FileTypeWriterOptions { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", + "json_options", + "jsonOptions", + "parquet_options", + "parquetOptions", + "csv_options", + "csvOptions", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, + JsonOptions, + ParquetOptions, + CsvOptions, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -7754,7 +8395,9 @@ impl<'de> serde::Deserialize<'de> for GroupingSetNode { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), + "jsonOptions" | "json_options" => Ok(GeneratedField::JsonOptions), + "parquetOptions" | "parquet_options" => Ok(GeneratedField::ParquetOptions), + "csvOptions" | "csv_options" => Ok(GeneratedField::CsvOptions), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -7764,36 +8407,51 @@ impl<'de> serde::Deserialize<'de> for GroupingSetNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GroupingSetNode; + type Value = FileTypeWriterOptions; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.GroupingSetNode") + formatter.write_str("struct datafusion.FileTypeWriterOptions") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; + let mut file_type__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::JsonOptions => { + if file_type__.is_some() { + return Err(serde::de::Error::duplicate_field("jsonOptions")); } - expr__ = Some(map_.next_value()?); + file_type__ = map_.next_value::<::std::option::Option<_>>()?.map(file_type_writer_options::FileType::JsonOptions) +; + } + GeneratedField::ParquetOptions => { + if file_type__.is_some() { + return Err(serde::de::Error::duplicate_field("parquetOptions")); + } + file_type__ = map_.next_value::<::std::option::Option<_>>()?.map(file_type_writer_options::FileType::ParquetOptions) +; + } + GeneratedField::CsvOptions => { + if file_type__.is_some() { + return Err(serde::de::Error::duplicate_field("csvOptions")); + } + file_type__ = map_.next_value::<::std::option::Option<_>>()?.map(file_type_writer_options::FileType::CsvOptions) +; } } } - Ok(GroupingSetNode { - expr: expr__.unwrap_or_default(), + Ok(FileTypeWriterOptions { + file_type: file_type__, }) } } - deserializer.deserialize_struct("datafusion.GroupingSetNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.FileTypeWriterOptions", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for HashJoinExecNode { +impl serde::Serialize for FilterExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -7801,84 +8459,46 @@ impl serde::Serialize for HashJoinExecNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.left.is_some() { - len += 1; - } - if self.right.is_some() { - len += 1; - } - if !self.on.is_empty() { - len += 1; - } - if self.join_type != 0 { - len += 1; - } - if self.partition_mode != 0 { + if self.input.is_some() { len += 1; } - if self.null_equals_null { + if self.expr.is_some() { len += 1; } - if self.filter.is_some() { + if self.default_filter_selectivity != 0 { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.HashJoinExecNode", len)?; - if let Some(v) = self.left.as_ref() { - struct_ser.serialize_field("left", v)?; - } - if let Some(v) = self.right.as_ref() { - struct_ser.serialize_field("right", v)?; - } - if !self.on.is_empty() { - struct_ser.serialize_field("on", &self.on)?; - } - if self.join_type != 0 { - let v = JoinType::try_from(self.join_type) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.join_type)))?; - struct_ser.serialize_field("joinType", &v)?; - } - if self.partition_mode != 0 { - let v = PartitionMode::try_from(self.partition_mode) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.partition_mode)))?; - struct_ser.serialize_field("partitionMode", &v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.FilterExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; } - if self.null_equals_null { - struct_ser.serialize_field("nullEqualsNull", &self.null_equals_null)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; } - if let Some(v) = self.filter.as_ref() { - struct_ser.serialize_field("filter", v)?; + if self.default_filter_selectivity != 0 { + struct_ser.serialize_field("defaultFilterSelectivity", &self.default_filter_selectivity)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for HashJoinExecNode { +impl<'de> serde::Deserialize<'de> for FilterExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "left", - "right", - "on", - "join_type", - "joinType", - "partition_mode", - "partitionMode", - "null_equals_null", - "nullEqualsNull", - "filter", + "input", + "expr", + "default_filter_selectivity", + "defaultFilterSelectivity", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Left, - Right, - On, - JoinType, - PartitionMode, - NullEqualsNull, - Filter, + Input, + Expr, + DefaultFilterSelectivity, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -7900,13 +8520,9 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { E: serde::de::Error, { match value { - "left" => Ok(GeneratedField::Left), - "right" => Ok(GeneratedField::Right), - "on" => Ok(GeneratedField::On), - "joinType" | "join_type" => Ok(GeneratedField::JoinType), - "partitionMode" | "partition_mode" => Ok(GeneratedField::PartitionMode), - "nullEqualsNull" | "null_equals_null" => Ok(GeneratedField::NullEqualsNull), - "filter" => Ok(GeneratedField::Filter), + "input" => Ok(GeneratedField::Input), + "expr" => Ok(GeneratedField::Expr), + "defaultFilterSelectivity" | "default_filter_selectivity" => Ok(GeneratedField::DefaultFilterSelectivity), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -7916,84 +8532,54 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = HashJoinExecNode; + type Value = FilterExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.HashJoinExecNode") + formatter.write_str("struct datafusion.FilterExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut left__ = None; - let mut right__ = None; - let mut on__ = None; - let mut join_type__ = None; - let mut partition_mode__ = None; - let mut null_equals_null__ = None; - let mut filter__ = None; + let mut input__ = None; + let mut expr__ = None; + let mut default_filter_selectivity__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Left => { - if left__.is_some() { - return Err(serde::de::Error::duplicate_field("left")); - } - left__ = map_.next_value()?; - } - GeneratedField::Right => { - if right__.is_some() { - return Err(serde::de::Error::duplicate_field("right")); - } - right__ = map_.next_value()?; - } - GeneratedField::On => { - if on__.is_some() { - return Err(serde::de::Error::duplicate_field("on")); - } - on__ = Some(map_.next_value()?); - } - GeneratedField::JoinType => { - if join_type__.is_some() { - return Err(serde::de::Error::duplicate_field("joinType")); - } - join_type__ = Some(map_.next_value::()? as i32); - } - GeneratedField::PartitionMode => { - if partition_mode__.is_some() { - return Err(serde::de::Error::duplicate_field("partitionMode")); + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); } - partition_mode__ = Some(map_.next_value::()? as i32); + input__ = map_.next_value()?; } - GeneratedField::NullEqualsNull => { - if null_equals_null__.is_some() { - return Err(serde::de::Error::duplicate_field("nullEqualsNull")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - null_equals_null__ = Some(map_.next_value()?); + expr__ = map_.next_value()?; } - GeneratedField::Filter => { - if filter__.is_some() { - return Err(serde::de::Error::duplicate_field("filter")); + GeneratedField::DefaultFilterSelectivity => { + if default_filter_selectivity__.is_some() { + return Err(serde::de::Error::duplicate_field("defaultFilterSelectivity")); } - filter__ = map_.next_value()?; + default_filter_selectivity__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; } } } - Ok(HashJoinExecNode { - left: left__, - right: right__, - on: on__.unwrap_or_default(), - join_type: join_type__.unwrap_or_default(), - partition_mode: partition_mode__.unwrap_or_default(), - null_equals_null: null_equals_null__.unwrap_or_default(), - filter: filter__, + Ok(FilterExecNode { + input: input__, + expr: expr__, + default_filter_selectivity: default_filter_selectivity__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.HashJoinExecNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.FilterExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for HashRepartition { +impl serde::Serialize for FixedSizeBinary { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -8001,40 +8587,29 @@ impl serde::Serialize for HashRepartition { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.hash_expr.is_empty() { + if self.length != 0 { len += 1; } - if self.partition_count != 0 { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.HashRepartition", len)?; - if !self.hash_expr.is_empty() { - struct_ser.serialize_field("hashExpr", &self.hash_expr)?; - } - if self.partition_count != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("partitionCount", ToString::to_string(&self.partition_count).as_str())?; + let mut struct_ser = serializer.serialize_struct("datafusion.FixedSizeBinary", len)?; + if self.length != 0 { + struct_ser.serialize_field("length", &self.length)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for HashRepartition { +impl<'de> serde::Deserialize<'de> for FixedSizeBinary { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "hash_expr", - "hashExpr", - "partition_count", - "partitionCount", + "length", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - HashExpr, - PartitionCount, + Length, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -8056,8 +8631,7 @@ impl<'de> serde::Deserialize<'de> for HashRepartition { E: serde::de::Error, { match value { - "hashExpr" | "hash_expr" => Ok(GeneratedField::HashExpr), - "partitionCount" | "partition_count" => Ok(GeneratedField::PartitionCount), + "length" => Ok(GeneratedField::Length), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -8067,46 +8641,38 @@ impl<'de> serde::Deserialize<'de> for HashRepartition { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = HashRepartition; + type Value = FixedSizeBinary; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.HashRepartition") + formatter.write_str("struct datafusion.FixedSizeBinary") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut hash_expr__ = None; - let mut partition_count__ = None; + let mut length__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::HashExpr => { - if hash_expr__.is_some() { - return Err(serde::de::Error::duplicate_field("hashExpr")); - } - hash_expr__ = Some(map_.next_value()?); - } - GeneratedField::PartitionCount => { - if partition_count__.is_some() { - return Err(serde::de::Error::duplicate_field("partitionCount")); + GeneratedField::Length => { + if length__.is_some() { + return Err(serde::de::Error::duplicate_field("length")); } - partition_count__ = + length__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } } } - Ok(HashRepartition { - hash_expr: hash_expr__.unwrap_or_default(), - partition_count: partition_count__.unwrap_or_default(), + Ok(FixedSizeBinary { + length: length__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.HashRepartition", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.FixedSizeBinary", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ILikeNode { +impl serde::Serialize for FixedSizeList { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -8114,54 +8680,39 @@ impl serde::Serialize for ILikeNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.negated { - len += 1; - } - if self.expr.is_some() { - len += 1; - } - if self.pattern.is_some() { + if self.field_type.is_some() { len += 1; } - if !self.escape_char.is_empty() { + if self.list_size != 0 { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.ILikeNode", len)?; - if self.negated { - struct_ser.serialize_field("negated", &self.negated)?; - } - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; - } - if let Some(v) = self.pattern.as_ref() { - struct_ser.serialize_field("pattern", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.FixedSizeList", len)?; + if let Some(v) = self.field_type.as_ref() { + struct_ser.serialize_field("fieldType", v)?; } - if !self.escape_char.is_empty() { - struct_ser.serialize_field("escapeChar", &self.escape_char)?; + if self.list_size != 0 { + struct_ser.serialize_field("listSize", &self.list_size)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for ILikeNode { +impl<'de> serde::Deserialize<'de> for FixedSizeList { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "negated", - "expr", - "pattern", - "escape_char", - "escapeChar", + "field_type", + "fieldType", + "list_size", + "listSize", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Negated, - Expr, - Pattern, - EscapeChar, + FieldType, + ListSize, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -8183,10 +8734,8 @@ impl<'de> serde::Deserialize<'de> for ILikeNode { E: serde::de::Error, { match value { - "negated" => Ok(GeneratedField::Negated), - "expr" => Ok(GeneratedField::Expr), - "pattern" => Ok(GeneratedField::Pattern), - "escapeChar" | "escape_char" => Ok(GeneratedField::EscapeChar), + "fieldType" | "field_type" => Ok(GeneratedField::FieldType), + "listSize" | "list_size" => Ok(GeneratedField::ListSize), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -8196,60 +8745,46 @@ impl<'de> serde::Deserialize<'de> for ILikeNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ILikeNode; + type Value = FixedSizeList; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ILikeNode") + formatter.write_str("struct datafusion.FixedSizeList") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut negated__ = None; - let mut expr__ = None; - let mut pattern__ = None; - let mut escape_char__ = None; + let mut field_type__ = None; + let mut list_size__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Negated => { - if negated__.is_some() { - return Err(serde::de::Error::duplicate_field("negated")); - } - negated__ = Some(map_.next_value()?); - } - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); - } - expr__ = map_.next_value()?; - } - GeneratedField::Pattern => { - if pattern__.is_some() { - return Err(serde::de::Error::duplicate_field("pattern")); + GeneratedField::FieldType => { + if field_type__.is_some() { + return Err(serde::de::Error::duplicate_field("fieldType")); } - pattern__ = map_.next_value()?; + field_type__ = map_.next_value()?; } - GeneratedField::EscapeChar => { - if escape_char__.is_some() { - return Err(serde::de::Error::duplicate_field("escapeChar")); + GeneratedField::ListSize => { + if list_size__.is_some() { + return Err(serde::de::Error::duplicate_field("listSize")); } - escape_char__ = Some(map_.next_value()?); + list_size__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; } } } - Ok(ILikeNode { - negated: negated__.unwrap_or_default(), - expr: expr__, - pattern: pattern__, - escape_char: escape_char__.unwrap_or_default(), + Ok(FixedSizeList { + field_type: field_type__, + list_size: list_size__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.ILikeNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.FixedSizeList", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for InListNode { +impl serde::Serialize for FullTableReference { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -8257,45 +8792,45 @@ impl serde::Serialize for InListNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { + if !self.catalog.is_empty() { len += 1; } - if !self.list.is_empty() { + if !self.schema.is_empty() { len += 1; } - if self.negated { + if !self.table.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.InListNode", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.FullTableReference", len)?; + if !self.catalog.is_empty() { + struct_ser.serialize_field("catalog", &self.catalog)?; } - if !self.list.is_empty() { - struct_ser.serialize_field("list", &self.list)?; + if !self.schema.is_empty() { + struct_ser.serialize_field("schema", &self.schema)?; } - if self.negated { - struct_ser.serialize_field("negated", &self.negated)?; + if !self.table.is_empty() { + struct_ser.serialize_field("table", &self.table)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for InListNode { +impl<'de> serde::Deserialize<'de> for FullTableReference { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", - "list", - "negated", + "catalog", + "schema", + "table", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, - List, - Negated, + Catalog, + Schema, + Table, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -8317,9 +8852,9 @@ impl<'de> serde::Deserialize<'de> for InListNode { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), - "list" => Ok(GeneratedField::List), - "negated" => Ok(GeneratedField::Negated), + "catalog" => Ok(GeneratedField::Catalog), + "schema" => Ok(GeneratedField::Schema), + "table" => Ok(GeneratedField::Table), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -8329,52 +8864,52 @@ impl<'de> serde::Deserialize<'de> for InListNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = InListNode; + type Value = FullTableReference; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.InListNode") + formatter.write_str("struct datafusion.FullTableReference") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; - let mut list__ = None; - let mut negated__ = None; + let mut catalog__ = None; + let mut schema__ = None; + let mut table__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::Catalog => { + if catalog__.is_some() { + return Err(serde::de::Error::duplicate_field("catalog")); } - expr__ = map_.next_value()?; + catalog__ = Some(map_.next_value()?); } - GeneratedField::List => { - if list__.is_some() { - return Err(serde::de::Error::duplicate_field("list")); + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); } - list__ = Some(map_.next_value()?); + schema__ = Some(map_.next_value()?); } - GeneratedField::Negated => { - if negated__.is_some() { - return Err(serde::de::Error::duplicate_field("negated")); + GeneratedField::Table => { + if table__.is_some() { + return Err(serde::de::Error::duplicate_field("table")); } - negated__ = Some(map_.next_value()?); + table__ = Some(map_.next_value()?); } } } - Ok(InListNode { - expr: expr__, - list: list__.unwrap_or_default(), - negated: negated__.unwrap_or_default(), + Ok(FullTableReference { + catalog: catalog__.unwrap_or_default(), + schema: schema__.unwrap_or_default(), + table: table__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.InListNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.FullTableReference", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for IntervalMonthDayNanoValue { +impl serde::Serialize for GetIndexedField { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -8382,46 +8917,54 @@ impl serde::Serialize for IntervalMonthDayNanoValue { { use serde::ser::SerializeStruct; let mut len = 0; - if self.months != 0 { + if self.expr.is_some() { len += 1; } - if self.days != 0 { + if self.field.is_some() { len += 1; } - if self.nanos != 0 { - len += 1; + let mut struct_ser = serializer.serialize_struct("datafusion.GetIndexedField", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; } - let mut struct_ser = serializer.serialize_struct("datafusion.IntervalMonthDayNanoValue", len)?; - if self.months != 0 { - struct_ser.serialize_field("months", &self.months)?; - } - if self.days != 0 { - struct_ser.serialize_field("days", &self.days)?; - } - if self.nanos != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("nanos", ToString::to_string(&self.nanos).as_str())?; + if let Some(v) = self.field.as_ref() { + match v { + get_indexed_field::Field::NamedStructField(v) => { + struct_ser.serialize_field("namedStructField", v)?; + } + get_indexed_field::Field::ListIndex(v) => { + struct_ser.serialize_field("listIndex", v)?; + } + get_indexed_field::Field::ListRange(v) => { + struct_ser.serialize_field("listRange", v)?; + } + } } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for IntervalMonthDayNanoValue { +impl<'de> serde::Deserialize<'de> for GetIndexedField { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "months", - "days", - "nanos", + "expr", + "named_struct_field", + "namedStructField", + "list_index", + "listIndex", + "list_range", + "listRange", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Months, - Days, - Nanos, + Expr, + NamedStructField, + ListIndex, + ListRange, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -8443,9 +8986,10 @@ impl<'de> serde::Deserialize<'de> for IntervalMonthDayNanoValue { E: serde::de::Error, { match value { - "months" => Ok(GeneratedField::Months), - "days" => Ok(GeneratedField::Days), - "nanos" => Ok(GeneratedField::Nanos), + "expr" => Ok(GeneratedField::Expr), + "namedStructField" | "named_struct_field" => Ok(GeneratedField::NamedStructField), + "listIndex" | "list_index" => Ok(GeneratedField::ListIndex), + "listRange" | "list_range" => Ok(GeneratedField::ListRange), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -8455,132 +8999,59 @@ impl<'de> serde::Deserialize<'de> for IntervalMonthDayNanoValue { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = IntervalMonthDayNanoValue; + type Value = GetIndexedField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.IntervalMonthDayNanoValue") + formatter.write_str("struct datafusion.GetIndexedField") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut months__ = None; - let mut days__ = None; - let mut nanos__ = None; + let mut expr__ = None; + let mut field__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Months => { - if months__.is_some() { - return Err(serde::de::Error::duplicate_field("months")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - months__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + expr__ = map_.next_value()?; } - GeneratedField::Days => { - if days__.is_some() { - return Err(serde::de::Error::duplicate_field("days")); + GeneratedField::NamedStructField => { + if field__.is_some() { + return Err(serde::de::Error::duplicate_field("namedStructField")); } - days__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + field__ = map_.next_value::<::std::option::Option<_>>()?.map(get_indexed_field::Field::NamedStructField) +; } - GeneratedField::Nanos => { - if nanos__.is_some() { - return Err(serde::de::Error::duplicate_field("nanos")); + GeneratedField::ListIndex => { + if field__.is_some() { + return Err(serde::de::Error::duplicate_field("listIndex")); } - nanos__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + field__ = map_.next_value::<::std::option::Option<_>>()?.map(get_indexed_field::Field::ListIndex) +; + } + GeneratedField::ListRange => { + if field__.is_some() { + return Err(serde::de::Error::duplicate_field("listRange")); + } + field__ = map_.next_value::<::std::option::Option<_>>()?.map(get_indexed_field::Field::ListRange) +; } } } - Ok(IntervalMonthDayNanoValue { - months: months__.unwrap_or_default(), - days: days__.unwrap_or_default(), - nanos: nanos__.unwrap_or_default(), + Ok(GetIndexedField { + expr: expr__, + field: field__, }) } } - deserializer.deserialize_struct("datafusion.IntervalMonthDayNanoValue", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for IntervalUnit { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - let variant = match self { - Self::YearMonth => "YearMonth", - Self::DayTime => "DayTime", - Self::MonthDayNano => "MonthDayNano", - }; - serializer.serialize_str(variant) - } -} -impl<'de> serde::Deserialize<'de> for IntervalUnit { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "YearMonth", - "DayTime", - "MonthDayNano", - ]; - - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = IntervalUnit; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - fn visit_i64(self, v: i64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) - }) - } - - fn visit_u64(self, v: u64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) - }) - } - - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "YearMonth" => Ok(IntervalUnit::YearMonth), - "DayTime" => Ok(IntervalUnit::DayTime), - "MonthDayNano" => Ok(IntervalUnit::MonthDayNano), - _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), - } - } - } - deserializer.deserialize_any(GeneratedVisitor) + deserializer.deserialize_struct("datafusion.GetIndexedField", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for IsFalse { +impl serde::Serialize for GlobalLimitExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -8588,29 +9059,46 @@ impl serde::Serialize for IsFalse { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { + if self.input.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.IsFalse", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + if self.skip != 0 { + len += 1; + } + if self.fetch != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.GlobalLimitExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if self.skip != 0 { + struct_ser.serialize_field("skip", &self.skip)?; + } + if self.fetch != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("fetch", ToString::to_string(&self.fetch).as_str())?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for IsFalse { +impl<'de> serde::Deserialize<'de> for GlobalLimitExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", + "input", + "skip", + "fetch", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, + Input, + Skip, + Fetch, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -8632,7 +9120,9 @@ impl<'de> serde::Deserialize<'de> for IsFalse { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), + "input" => Ok(GeneratedField::Input), + "skip" => Ok(GeneratedField::Skip), + "fetch" => Ok(GeneratedField::Fetch), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -8642,36 +9132,56 @@ impl<'de> serde::Deserialize<'de> for IsFalse { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = IsFalse; + type Value = GlobalLimitExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.IsFalse") + formatter.write_str("struct datafusion.GlobalLimitExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; + let mut input__ = None; + let mut skip__ = None; + let mut fetch__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); } - expr__ = map_.next_value()?; + input__ = map_.next_value()?; + } + GeneratedField::Skip => { + if skip__.is_some() { + return Err(serde::de::Error::duplicate_field("skip")); + } + skip__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::Fetch => { + if fetch__.is_some() { + return Err(serde::de::Error::duplicate_field("fetch")); + } + fetch__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; } } } - Ok(IsFalse { - expr: expr__, - }) - } + Ok(GlobalLimitExecNode { + input: input__, + skip: skip__.unwrap_or_default(), + fetch: fetch__.unwrap_or_default(), + }) + } } - deserializer.deserialize_struct("datafusion.IsFalse", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.GlobalLimitExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for IsNotFalse { +impl serde::Serialize for GroupingSetNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -8679,17 +9189,17 @@ impl serde::Serialize for IsNotFalse { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { + if !self.expr.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.IsNotFalse", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.GroupingSetNode", len)?; + if !self.expr.is_empty() { + struct_ser.serialize_field("expr", &self.expr)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for IsNotFalse { +impl<'de> serde::Deserialize<'de> for GroupingSetNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where @@ -8733,13 +9243,13 @@ impl<'de> serde::Deserialize<'de> for IsNotFalse { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = IsNotFalse; + type Value = GroupingSetNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.IsNotFalse") + formatter.write_str("struct datafusion.GroupingSetNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { @@ -8750,19 +9260,19 @@ impl<'de> serde::Deserialize<'de> for IsNotFalse { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map_.next_value()?; + expr__ = Some(map_.next_value()?); } } } - Ok(IsNotFalse { - expr: expr__, + Ok(GroupingSetNode { + expr: expr__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.IsNotFalse", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.GroupingSetNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for IsNotNull { +impl serde::Serialize for HashJoinExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -8770,29 +9280,84 @@ impl serde::Serialize for IsNotNull { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { + if self.left.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.IsNotNull", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + if self.right.is_some() { + len += 1; + } + if !self.on.is_empty() { + len += 1; + } + if self.join_type != 0 { + len += 1; + } + if self.partition_mode != 0 { + len += 1; + } + if self.null_equals_null { + len += 1; + } + if self.filter.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.HashJoinExecNode", len)?; + if let Some(v) = self.left.as_ref() { + struct_ser.serialize_field("left", v)?; + } + if let Some(v) = self.right.as_ref() { + struct_ser.serialize_field("right", v)?; + } + if !self.on.is_empty() { + struct_ser.serialize_field("on", &self.on)?; + } + if self.join_type != 0 { + let v = JoinType::try_from(self.join_type) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.join_type)))?; + struct_ser.serialize_field("joinType", &v)?; + } + if self.partition_mode != 0 { + let v = PartitionMode::try_from(self.partition_mode) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.partition_mode)))?; + struct_ser.serialize_field("partitionMode", &v)?; + } + if self.null_equals_null { + struct_ser.serialize_field("nullEqualsNull", &self.null_equals_null)?; + } + if let Some(v) = self.filter.as_ref() { + struct_ser.serialize_field("filter", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for IsNotNull { +impl<'de> serde::Deserialize<'de> for HashJoinExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", + "left", + "right", + "on", + "join_type", + "joinType", + "partition_mode", + "partitionMode", + "null_equals_null", + "nullEqualsNull", + "filter", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, + Left, + Right, + On, + JoinType, + PartitionMode, + NullEqualsNull, + Filter, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -8814,7 +9379,13 @@ impl<'de> serde::Deserialize<'de> for IsNotNull { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), + "left" => Ok(GeneratedField::Left), + "right" => Ok(GeneratedField::Right), + "on" => Ok(GeneratedField::On), + "joinType" | "join_type" => Ok(GeneratedField::JoinType), + "partitionMode" | "partition_mode" => Ok(GeneratedField::PartitionMode), + "nullEqualsNull" | "null_equals_null" => Ok(GeneratedField::NullEqualsNull), + "filter" => Ok(GeneratedField::Filter), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -8824,36 +9395,84 @@ impl<'de> serde::Deserialize<'de> for IsNotNull { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = IsNotNull; + type Value = HashJoinExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.IsNotNull") + formatter.write_str("struct datafusion.HashJoinExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; + let mut left__ = None; + let mut right__ = None; + let mut on__ = None; + let mut join_type__ = None; + let mut partition_mode__ = None; + let mut null_equals_null__ = None; + let mut filter__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::Left => { + if left__.is_some() { + return Err(serde::de::Error::duplicate_field("left")); } - expr__ = map_.next_value()?; + left__ = map_.next_value()?; + } + GeneratedField::Right => { + if right__.is_some() { + return Err(serde::de::Error::duplicate_field("right")); + } + right__ = map_.next_value()?; + } + GeneratedField::On => { + if on__.is_some() { + return Err(serde::de::Error::duplicate_field("on")); + } + on__ = Some(map_.next_value()?); + } + GeneratedField::JoinType => { + if join_type__.is_some() { + return Err(serde::de::Error::duplicate_field("joinType")); + } + join_type__ = Some(map_.next_value::()? as i32); + } + GeneratedField::PartitionMode => { + if partition_mode__.is_some() { + return Err(serde::de::Error::duplicate_field("partitionMode")); + } + partition_mode__ = Some(map_.next_value::()? as i32); + } + GeneratedField::NullEqualsNull => { + if null_equals_null__.is_some() { + return Err(serde::de::Error::duplicate_field("nullEqualsNull")); + } + null_equals_null__ = Some(map_.next_value()?); + } + GeneratedField::Filter => { + if filter__.is_some() { + return Err(serde::de::Error::duplicate_field("filter")); + } + filter__ = map_.next_value()?; } } } - Ok(IsNotNull { - expr: expr__, + Ok(HashJoinExecNode { + left: left__, + right: right__, + on: on__.unwrap_or_default(), + join_type: join_type__.unwrap_or_default(), + partition_mode: partition_mode__.unwrap_or_default(), + null_equals_null: null_equals_null__.unwrap_or_default(), + filter: filter__, }) } } - deserializer.deserialize_struct("datafusion.IsNotNull", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.HashJoinExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for IsNotTrue { +impl serde::Serialize for HashRepartition { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -8861,29 +9480,40 @@ impl serde::Serialize for IsNotTrue { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { + if !self.hash_expr.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.IsNotTrue", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + if self.partition_count != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.HashRepartition", len)?; + if !self.hash_expr.is_empty() { + struct_ser.serialize_field("hashExpr", &self.hash_expr)?; + } + if self.partition_count != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("partitionCount", ToString::to_string(&self.partition_count).as_str())?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for IsNotTrue { +impl<'de> serde::Deserialize<'de> for HashRepartition { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", + "hash_expr", + "hashExpr", + "partition_count", + "partitionCount", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, + HashExpr, + PartitionCount, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -8905,7 +9535,8 @@ impl<'de> serde::Deserialize<'de> for IsNotTrue { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), + "hashExpr" | "hash_expr" => Ok(GeneratedField::HashExpr), + "partitionCount" | "partition_count" => Ok(GeneratedField::PartitionCount), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -8915,36 +9546,46 @@ impl<'de> serde::Deserialize<'de> for IsNotTrue { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = IsNotTrue; + type Value = HashRepartition; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.IsNotTrue") + formatter.write_str("struct datafusion.HashRepartition") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; + let mut hash_expr__ = None; + let mut partition_count__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::HashExpr => { + if hash_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("hashExpr")); } - expr__ = map_.next_value()?; + hash_expr__ = Some(map_.next_value()?); + } + GeneratedField::PartitionCount => { + if partition_count__.is_some() { + return Err(serde::de::Error::duplicate_field("partitionCount")); + } + partition_count__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; } } } - Ok(IsNotTrue { - expr: expr__, + Ok(HashRepartition { + hash_expr: hash_expr__.unwrap_or_default(), + partition_count: partition_count__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.IsNotTrue", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.HashRepartition", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for IsNotUnknown { +impl serde::Serialize for ILikeNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -8952,29 +9593,54 @@ impl serde::Serialize for IsNotUnknown { { use serde::ser::SerializeStruct; let mut len = 0; + if self.negated { + len += 1; + } if self.expr.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.IsNotUnknown", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + if self.pattern.is_some() { + len += 1; } - struct_ser.end() + if !self.escape_char.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ILikeNode", len)?; + if self.negated { + struct_ser.serialize_field("negated", &self.negated)?; + } + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; + } + if let Some(v) = self.pattern.as_ref() { + struct_ser.serialize_field("pattern", v)?; + } + if !self.escape_char.is_empty() { + struct_ser.serialize_field("escapeChar", &self.escape_char)?; + } + struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for IsNotUnknown { +impl<'de> serde::Deserialize<'de> for ILikeNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ + "negated", "expr", + "pattern", + "escape_char", + "escapeChar", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { + Negated, Expr, + Pattern, + EscapeChar, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -8996,7 +9662,10 @@ impl<'de> serde::Deserialize<'de> for IsNotUnknown { E: serde::de::Error, { match value { + "negated" => Ok(GeneratedField::Negated), "expr" => Ok(GeneratedField::Expr), + "pattern" => Ok(GeneratedField::Pattern), + "escapeChar" | "escape_char" => Ok(GeneratedField::EscapeChar), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -9006,36 +9675,60 @@ impl<'de> serde::Deserialize<'de> for IsNotUnknown { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = IsNotUnknown; + type Value = ILikeNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.IsNotUnknown") + formatter.write_str("struct datafusion.ILikeNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { + let mut negated__ = None; let mut expr__ = None; + let mut pattern__ = None; + let mut escape_char__ = None; while let Some(k) = map_.next_key()? { match k { + GeneratedField::Negated => { + if negated__.is_some() { + return Err(serde::de::Error::duplicate_field("negated")); + } + negated__ = Some(map_.next_value()?); + } GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } expr__ = map_.next_value()?; } + GeneratedField::Pattern => { + if pattern__.is_some() { + return Err(serde::de::Error::duplicate_field("pattern")); + } + pattern__ = map_.next_value()?; + } + GeneratedField::EscapeChar => { + if escape_char__.is_some() { + return Err(serde::de::Error::duplicate_field("escapeChar")); + } + escape_char__ = Some(map_.next_value()?); + } } } - Ok(IsNotUnknown { + Ok(ILikeNode { + negated: negated__.unwrap_or_default(), expr: expr__, + pattern: pattern__, + escape_char: escape_char__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.IsNotUnknown", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.ILikeNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for IsNull { +impl serde::Serialize for InListNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -9046,14 +9739,26 @@ impl serde::Serialize for IsNull { if self.expr.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.IsNull", len)?; + if !self.list.is_empty() { + len += 1; + } + if self.negated { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.InListNode", len)?; if let Some(v) = self.expr.as_ref() { struct_ser.serialize_field("expr", v)?; } + if !self.list.is_empty() { + struct_ser.serialize_field("list", &self.list)?; + } + if self.negated { + struct_ser.serialize_field("negated", &self.negated)?; + } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for IsNull { +impl<'de> serde::Deserialize<'de> for InListNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where @@ -9061,11 +9766,15 @@ impl<'de> serde::Deserialize<'de> for IsNull { { const FIELDS: &[&str] = &[ "expr", + "list", + "negated", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Expr, + List, + Negated, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -9088,6 +9797,8 @@ impl<'de> serde::Deserialize<'de> for IsNull { { match value { "expr" => Ok(GeneratedField::Expr), + "list" => Ok(GeneratedField::List), + "negated" => Ok(GeneratedField::Negated), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -9097,17 +9808,19 @@ impl<'de> serde::Deserialize<'de> for IsNull { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = IsNull; + type Value = InListNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.IsNull") + formatter.write_str("struct datafusion.InListNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; + let mut list__ = None; + let mut negated__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { @@ -9116,17 +9829,31 @@ impl<'de> serde::Deserialize<'de> for IsNull { } expr__ = map_.next_value()?; } + GeneratedField::List => { + if list__.is_some() { + return Err(serde::de::Error::duplicate_field("list")); + } + list__ = Some(map_.next_value()?); + } + GeneratedField::Negated => { + if negated__.is_some() { + return Err(serde::de::Error::duplicate_field("negated")); + } + negated__ = Some(map_.next_value()?); + } } } - Ok(IsNull { + Ok(InListNode { expr: expr__, + list: list__.unwrap_or_default(), + negated: negated__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.IsNull", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.InListNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for IsTrue { +impl serde::Serialize for InterleaveExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -9134,29 +9861,29 @@ impl serde::Serialize for IsTrue { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { + if !self.inputs.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.IsTrue", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.InterleaveExecNode", len)?; + if !self.inputs.is_empty() { + struct_ser.serialize_field("inputs", &self.inputs)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for IsTrue { +impl<'de> serde::Deserialize<'de> for InterleaveExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", + "inputs", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, + Inputs, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -9178,7 +9905,7 @@ impl<'de> serde::Deserialize<'de> for IsTrue { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), + "inputs" => Ok(GeneratedField::Inputs), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -9188,36 +9915,36 @@ impl<'de> serde::Deserialize<'de> for IsTrue { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = IsTrue; + type Value = InterleaveExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.IsTrue") + formatter.write_str("struct datafusion.InterleaveExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; + let mut inputs__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::Inputs => { + if inputs__.is_some() { + return Err(serde::de::Error::duplicate_field("inputs")); } - expr__ = map_.next_value()?; + inputs__ = Some(map_.next_value()?); } } } - Ok(IsTrue { - expr: expr__, + Ok(InterleaveExecNode { + inputs: inputs__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.IsTrue", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.InterleaveExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for IsUnknown { +impl serde::Serialize for IntervalMonthDayNanoValue { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -9225,29 +9952,46 @@ impl serde::Serialize for IsUnknown { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { + if self.months != 0 { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.IsUnknown", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + if self.days != 0 { + len += 1; + } + if self.nanos != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.IntervalMonthDayNanoValue", len)?; + if self.months != 0 { + struct_ser.serialize_field("months", &self.months)?; + } + if self.days != 0 { + struct_ser.serialize_field("days", &self.days)?; + } + if self.nanos != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("nanos", ToString::to_string(&self.nanos).as_str())?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for IsUnknown { +impl<'de> serde::Deserialize<'de> for IntervalMonthDayNanoValue { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", + "months", + "days", + "nanos", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, + Months, + Days, + Nanos, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -9269,7 +10013,9 @@ impl<'de> serde::Deserialize<'de> for IsUnknown { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), + "months" => Ok(GeneratedField::Months), + "days" => Ok(GeneratedField::Days), + "nanos" => Ok(GeneratedField::Nanos), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -9279,63 +10025,87 @@ impl<'de> serde::Deserialize<'de> for IsUnknown { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = IsUnknown; + type Value = IntervalMonthDayNanoValue; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.IsUnknown") + formatter.write_str("struct datafusion.IntervalMonthDayNanoValue") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; + let mut months__ = None; + let mut days__ = None; + let mut nanos__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::Months => { + if months__.is_some() { + return Err(serde::de::Error::duplicate_field("months")); } - expr__ = map_.next_value()?; + months__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::Days => { + if days__.is_some() { + return Err(serde::de::Error::duplicate_field("days")); + } + days__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::Nanos => { + if nanos__.is_some() { + return Err(serde::de::Error::duplicate_field("nanos")); + } + nanos__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; } } } - Ok(IsUnknown { - expr: expr__, + Ok(IntervalMonthDayNanoValue { + months: months__.unwrap_or_default(), + days: days__.unwrap_or_default(), + nanos: nanos__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.IsUnknown", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.IntervalMonthDayNanoValue", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for JoinConstraint { +impl serde::Serialize for IntervalUnit { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where S: serde::Serializer, { let variant = match self { - Self::On => "ON", - Self::Using => "USING", + Self::YearMonth => "YearMonth", + Self::DayTime => "DayTime", + Self::MonthDayNano => "MonthDayNano", }; serializer.serialize_str(variant) } } -impl<'de> serde::Deserialize<'de> for JoinConstraint { +impl<'de> serde::Deserialize<'de> for IntervalUnit { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "ON", - "USING", + "YearMonth", + "DayTime", + "MonthDayNano", ]; struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = JoinConstraint; + type Value = IntervalUnit; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(formatter, "expected one of: {:?}", &FIELDS) @@ -9370,8 +10140,9 @@ impl<'de> serde::Deserialize<'de> for JoinConstraint { E: serde::de::Error, { match value { - "ON" => Ok(JoinConstraint::On), - "USING" => Ok(JoinConstraint::Using), + "YearMonth" => Ok(IntervalUnit::YearMonth), + "DayTime" => Ok(IntervalUnit::DayTime), + "MonthDayNano" => Ok(IntervalUnit::MonthDayNano), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } @@ -9379,7 +10150,7 @@ impl<'de> serde::Deserialize<'de> for JoinConstraint { deserializer.deserialize_any(GeneratedVisitor) } } -impl serde::Serialize for JoinFilter { +impl serde::Serialize for IsFalse { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -9387,46 +10158,29 @@ impl serde::Serialize for JoinFilter { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expression.is_some() { + if self.expr.is_some() { len += 1; } - if !self.column_indices.is_empty() { - len += 1; - } - if self.schema.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.JoinFilter", len)?; - if let Some(v) = self.expression.as_ref() { - struct_ser.serialize_field("expression", v)?; - } - if !self.column_indices.is_empty() { - struct_ser.serialize_field("columnIndices", &self.column_indices)?; - } - if let Some(v) = self.schema.as_ref() { - struct_ser.serialize_field("schema", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.IsFalse", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for JoinFilter { +impl<'de> serde::Deserialize<'de> for IsFalse { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expression", - "column_indices", - "columnIndices", - "schema", + "expr", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expression, - ColumnIndices, - Schema, + Expr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -9448,9 +10202,7 @@ impl<'de> serde::Deserialize<'de> for JoinFilter { E: serde::de::Error, { match value { - "expression" => Ok(GeneratedField::Expression), - "columnIndices" | "column_indices" => Ok(GeneratedField::ColumnIndices), - "schema" => Ok(GeneratedField::Schema), + "expr" => Ok(GeneratedField::Expr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -9460,52 +10212,36 @@ impl<'de> serde::Deserialize<'de> for JoinFilter { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = JoinFilter; + type Value = IsFalse; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.JoinFilter") + formatter.write_str("struct datafusion.IsFalse") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expression__ = None; - let mut column_indices__ = None; - let mut schema__ = None; + let mut expr__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expression => { - if expression__.is_some() { - return Err(serde::de::Error::duplicate_field("expression")); - } - expression__ = map_.next_value()?; - } - GeneratedField::ColumnIndices => { - if column_indices__.is_some() { - return Err(serde::de::Error::duplicate_field("columnIndices")); - } - column_indices__ = Some(map_.next_value()?); - } - GeneratedField::Schema => { - if schema__.is_some() { - return Err(serde::de::Error::duplicate_field("schema")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - schema__ = map_.next_value()?; + expr__ = map_.next_value()?; } } } - Ok(JoinFilter { - expression: expression__, - column_indices: column_indices__.unwrap_or_default(), - schema: schema__, + Ok(IsFalse { + expr: expr__, }) } } - deserializer.deserialize_struct("datafusion.JoinFilter", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.IsFalse", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for JoinNode { +impl serde::Serialize for IsNotFalse { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -9513,94 +10249,29 @@ impl serde::Serialize for JoinNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.left.is_some() { - len += 1; - } - if self.right.is_some() { - len += 1; - } - if self.join_type != 0 { - len += 1; - } - if self.join_constraint != 0 { - len += 1; - } - if !self.left_join_key.is_empty() { - len += 1; - } - if !self.right_join_key.is_empty() { - len += 1; - } - if self.null_equals_null { - len += 1; - } - if self.filter.is_some() { + if self.expr.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.JoinNode", len)?; - if let Some(v) = self.left.as_ref() { - struct_ser.serialize_field("left", v)?; - } - if let Some(v) = self.right.as_ref() { - struct_ser.serialize_field("right", v)?; - } - if self.join_type != 0 { - let v = JoinType::try_from(self.join_type) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.join_type)))?; - struct_ser.serialize_field("joinType", &v)?; - } - if self.join_constraint != 0 { - let v = JoinConstraint::try_from(self.join_constraint) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.join_constraint)))?; - struct_ser.serialize_field("joinConstraint", &v)?; - } - if !self.left_join_key.is_empty() { - struct_ser.serialize_field("leftJoinKey", &self.left_join_key)?; - } - if !self.right_join_key.is_empty() { - struct_ser.serialize_field("rightJoinKey", &self.right_join_key)?; - } - if self.null_equals_null { - struct_ser.serialize_field("nullEqualsNull", &self.null_equals_null)?; - } - if let Some(v) = self.filter.as_ref() { - struct_ser.serialize_field("filter", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.IsNotFalse", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for JoinNode { +impl<'de> serde::Deserialize<'de> for IsNotFalse { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "left", - "right", - "join_type", - "joinType", - "join_constraint", - "joinConstraint", - "left_join_key", - "leftJoinKey", - "right_join_key", - "rightJoinKey", - "null_equals_null", - "nullEqualsNull", - "filter", + "expr", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Left, - Right, - JoinType, - JoinConstraint, - LeftJoinKey, - RightJoinKey, - NullEqualsNull, - Filter, + Expr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -9622,14 +10293,7 @@ impl<'de> serde::Deserialize<'de> for JoinNode { E: serde::de::Error, { match value { - "left" => Ok(GeneratedField::Left), - "right" => Ok(GeneratedField::Right), - "joinType" | "join_type" => Ok(GeneratedField::JoinType), - "joinConstraint" | "join_constraint" => Ok(GeneratedField::JoinConstraint), - "leftJoinKey" | "left_join_key" => Ok(GeneratedField::LeftJoinKey), - "rightJoinKey" | "right_join_key" => Ok(GeneratedField::RightJoinKey), - "nullEqualsNull" | "null_equals_null" => Ok(GeneratedField::NullEqualsNull), - "filter" => Ok(GeneratedField::Filter), + "expr" => Ok(GeneratedField::Expr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -9639,92 +10303,36 @@ impl<'de> serde::Deserialize<'de> for JoinNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = JoinNode; + type Value = IsNotFalse; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.JoinNode") + formatter.write_str("struct datafusion.IsNotFalse") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut left__ = None; - let mut right__ = None; - let mut join_type__ = None; - let mut join_constraint__ = None; - let mut left_join_key__ = None; - let mut right_join_key__ = None; - let mut null_equals_null__ = None; - let mut filter__ = None; + let mut expr__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Left => { - if left__.is_some() { - return Err(serde::de::Error::duplicate_field("left")); - } - left__ = map_.next_value()?; - } - GeneratedField::Right => { - if right__.is_some() { - return Err(serde::de::Error::duplicate_field("right")); - } - right__ = map_.next_value()?; - } - GeneratedField::JoinType => { - if join_type__.is_some() { - return Err(serde::de::Error::duplicate_field("joinType")); - } - join_type__ = Some(map_.next_value::()? as i32); - } - GeneratedField::JoinConstraint => { - if join_constraint__.is_some() { - return Err(serde::de::Error::duplicate_field("joinConstraint")); - } - join_constraint__ = Some(map_.next_value::()? as i32); - } - GeneratedField::LeftJoinKey => { - if left_join_key__.is_some() { - return Err(serde::de::Error::duplicate_field("leftJoinKey")); - } - left_join_key__ = Some(map_.next_value()?); - } - GeneratedField::RightJoinKey => { - if right_join_key__.is_some() { - return Err(serde::de::Error::duplicate_field("rightJoinKey")); - } - right_join_key__ = Some(map_.next_value()?); - } - GeneratedField::NullEqualsNull => { - if null_equals_null__.is_some() { - return Err(serde::de::Error::duplicate_field("nullEqualsNull")); - } - null_equals_null__ = Some(map_.next_value()?); - } - GeneratedField::Filter => { - if filter__.is_some() { - return Err(serde::de::Error::duplicate_field("filter")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - filter__ = map_.next_value()?; + expr__ = map_.next_value()?; } } } - Ok(JoinNode { - left: left__, - right: right__, - join_type: join_type__.unwrap_or_default(), - join_constraint: join_constraint__.unwrap_or_default(), - left_join_key: left_join_key__.unwrap_or_default(), - right_join_key: right_join_key__.unwrap_or_default(), - null_equals_null: null_equals_null__.unwrap_or_default(), - filter: filter__, + Ok(IsNotFalse { + expr: expr__, }) } } - deserializer.deserialize_struct("datafusion.JoinNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.IsNotFalse", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for JoinOn { +impl serde::Serialize for IsNotNull { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -9732,37 +10340,29 @@ impl serde::Serialize for JoinOn { { use serde::ser::SerializeStruct; let mut len = 0; - if self.left.is_some() { - len += 1; - } - if self.right.is_some() { + if self.expr.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.JoinOn", len)?; - if let Some(v) = self.left.as_ref() { - struct_ser.serialize_field("left", v)?; - } - if let Some(v) = self.right.as_ref() { - struct_ser.serialize_field("right", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.IsNotNull", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for JoinOn { +impl<'de> serde::Deserialize<'de> for IsNotNull { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "left", - "right", + "expr", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Left, - Right, + Expr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -9784,8 +10384,7 @@ impl<'de> serde::Deserialize<'de> for JoinOn { E: serde::de::Error, { match value { - "left" => Ok(GeneratedField::Left), - "right" => Ok(GeneratedField::Right), + "expr" => Ok(GeneratedField::Expr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -9795,204 +10394,127 @@ impl<'de> serde::Deserialize<'de> for JoinOn { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = JoinOn; + type Value = IsNotNull; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.JoinOn") + formatter.write_str("struct datafusion.IsNotNull") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut left__ = None; - let mut right__ = None; + let mut expr__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Left => { - if left__.is_some() { - return Err(serde::de::Error::duplicate_field("left")); - } - left__ = map_.next_value()?; - } - GeneratedField::Right => { - if right__.is_some() { - return Err(serde::de::Error::duplicate_field("right")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - right__ = map_.next_value()?; + expr__ = map_.next_value()?; } } } - Ok(JoinOn { - left: left__, - right: right__, + Ok(IsNotNull { + expr: expr__, }) } } - deserializer.deserialize_struct("datafusion.JoinOn", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.IsNotNull", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for JoinSide { +impl serde::Serialize for IsNotTrue { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where S: serde::Serializer, { - let variant = match self { - Self::LeftSide => "LEFT_SIDE", - Self::RightSide => "RIGHT_SIDE", - }; - serializer.serialize_str(variant) + use serde::ser::SerializeStruct; + let mut len = 0; + if self.expr.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.IsNotTrue", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; + } + struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for JoinSide { +impl<'de> serde::Deserialize<'de> for IsNotTrue { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "LEFT_SIDE", - "RIGHT_SIDE", + "expr", ]; - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = JoinSide; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - fn visit_i64(self, v: i64) -> std::result::Result + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Expr, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result where - E: serde::de::Error, + D: serde::Deserializer<'de>, { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) - }) - } + struct GeneratedVisitor; - fn visit_u64(self, v: u64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) - }) - } + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "LEFT_SIDE" => Ok(JoinSide::LeftSide), - "RIGHT_SIDE" => Ok(JoinSide::RightSide), - _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "expr" => Ok(GeneratedField::Expr), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } } + deserializer.deserialize_identifier(GeneratedVisitor) } } - deserializer.deserialize_any(GeneratedVisitor) - } -} -impl serde::Serialize for JoinType { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - let variant = match self { - Self::Inner => "INNER", - Self::Left => "LEFT", - Self::Right => "RIGHT", - Self::Full => "FULL", - Self::Leftsemi => "LEFTSEMI", - Self::Leftanti => "LEFTANTI", - Self::Rightsemi => "RIGHTSEMI", - Self::Rightanti => "RIGHTANTI", - }; - serializer.serialize_str(variant) - } -} -impl<'de> serde::Deserialize<'de> for JoinType { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "INNER", - "LEFT", - "RIGHT", - "FULL", - "LEFTSEMI", - "LEFTANTI", - "RIGHTSEMI", - "RIGHTANTI", - ]; - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = JoinType; + type Value = IsNotTrue; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - fn visit_i64(self, v: i64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) - }) - } - - fn visit_u64(self, v: u64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) - }) + formatter.write_str("struct datafusion.IsNotTrue") } - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, { - match value { - "INNER" => Ok(JoinType::Inner), - "LEFT" => Ok(JoinType::Left), - "RIGHT" => Ok(JoinType::Right), - "FULL" => Ok(JoinType::Full), - "LEFTSEMI" => Ok(JoinType::Leftsemi), - "LEFTANTI" => Ok(JoinType::Leftanti), - "RIGHTSEMI" => Ok(JoinType::Rightsemi), - "RIGHTANTI" => Ok(JoinType::Rightanti), - _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + let mut expr__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); + } + expr__ = map_.next_value()?; + } + } } + Ok(IsNotTrue { + expr: expr__, + }) } } - deserializer.deserialize_any(GeneratedVisitor) + deserializer.deserialize_struct("datafusion.IsNotTrue", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for LikeNode { +impl serde::Serialize for IsNotUnknown { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -10000,54 +10522,29 @@ impl serde::Serialize for LikeNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.negated { - len += 1; - } if self.expr.is_some() { len += 1; } - if self.pattern.is_some() { - len += 1; - } - if !self.escape_char.is_empty() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.LikeNode", len)?; - if self.negated { - struct_ser.serialize_field("negated", &self.negated)?; - } + let mut struct_ser = serializer.serialize_struct("datafusion.IsNotUnknown", len)?; if let Some(v) = self.expr.as_ref() { struct_ser.serialize_field("expr", v)?; } - if let Some(v) = self.pattern.as_ref() { - struct_ser.serialize_field("pattern", v)?; - } - if !self.escape_char.is_empty() { - struct_ser.serialize_field("escapeChar", &self.escape_char)?; - } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for LikeNode { +impl<'de> serde::Deserialize<'de> for IsNotUnknown { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "negated", "expr", - "pattern", - "escape_char", - "escapeChar", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Negated, Expr, - Pattern, - EscapeChar, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -10069,10 +10566,7 @@ impl<'de> serde::Deserialize<'de> for LikeNode { E: serde::de::Error, { match value { - "negated" => Ok(GeneratedField::Negated), "expr" => Ok(GeneratedField::Expr), - "pattern" => Ok(GeneratedField::Pattern), - "escapeChar" | "escape_char" => Ok(GeneratedField::EscapeChar), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -10082,60 +10576,36 @@ impl<'de> serde::Deserialize<'de> for LikeNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = LikeNode; + type Value = IsNotUnknown; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.LikeNode") + formatter.write_str("struct datafusion.IsNotUnknown") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut negated__ = None; let mut expr__ = None; - let mut pattern__ = None; - let mut escape_char__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Negated => { - if negated__.is_some() { - return Err(serde::de::Error::duplicate_field("negated")); - } - negated__ = Some(map_.next_value()?); - } GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } expr__ = map_.next_value()?; } - GeneratedField::Pattern => { - if pattern__.is_some() { - return Err(serde::de::Error::duplicate_field("pattern")); - } - pattern__ = map_.next_value()?; - } - GeneratedField::EscapeChar => { - if escape_char__.is_some() { - return Err(serde::de::Error::duplicate_field("escapeChar")); - } - escape_char__ = Some(map_.next_value()?); - } } } - Ok(LikeNode { - negated: negated__.unwrap_or_default(), + Ok(IsNotUnknown { expr: expr__, - pattern: pattern__, - escape_char: escape_char__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.LikeNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.IsNotUnknown", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for LimitNode { +impl serde::Serialize for IsNull { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -10143,47 +10613,29 @@ impl serde::Serialize for LimitNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.input.is_some() { - len += 1; - } - if self.skip != 0 { - len += 1; - } - if self.fetch != 0 { + if self.expr.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.LimitNode", len)?; - if let Some(v) = self.input.as_ref() { - struct_ser.serialize_field("input", v)?; - } - if self.skip != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("skip", ToString::to_string(&self.skip).as_str())?; - } - if self.fetch != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("fetch", ToString::to_string(&self.fetch).as_str())?; + let mut struct_ser = serializer.serialize_struct("datafusion.IsNull", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for LimitNode { +impl<'de> serde::Deserialize<'de> for IsNull { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "input", - "skip", - "fetch", + "expr", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Input, - Skip, - Fetch, + Expr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -10205,9 +10657,7 @@ impl<'de> serde::Deserialize<'de> for LimitNode { E: serde::de::Error, { match value { - "input" => Ok(GeneratedField::Input), - "skip" => Ok(GeneratedField::Skip), - "fetch" => Ok(GeneratedField::Fetch), + "expr" => Ok(GeneratedField::Expr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -10217,56 +10667,36 @@ impl<'de> serde::Deserialize<'de> for LimitNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = LimitNode; + type Value = IsNull; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.LimitNode") + formatter.write_str("struct datafusion.IsNull") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut input__ = None; - let mut skip__ = None; - let mut fetch__ = None; + let mut expr__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Input => { - if input__.is_some() { - return Err(serde::de::Error::duplicate_field("input")); - } - input__ = map_.next_value()?; - } - GeneratedField::Skip => { - if skip__.is_some() { - return Err(serde::de::Error::duplicate_field("skip")); - } - skip__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - GeneratedField::Fetch => { - if fetch__.is_some() { - return Err(serde::de::Error::duplicate_field("fetch")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - fetch__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + expr__ = map_.next_value()?; } } } - Ok(LimitNode { - input: input__, - skip: skip__.unwrap_or_default(), - fetch: fetch__.unwrap_or_default(), + Ok(IsNull { + expr: expr__, }) } } - deserializer.deserialize_struct("datafusion.LimitNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.IsNull", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for List { +impl serde::Serialize for IsTrue { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -10274,30 +10704,29 @@ impl serde::Serialize for List { { use serde::ser::SerializeStruct; let mut len = 0; - if self.field_type.is_some() { + if self.expr.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.List", len)?; - if let Some(v) = self.field_type.as_ref() { - struct_ser.serialize_field("fieldType", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.IsTrue", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for List { +impl<'de> serde::Deserialize<'de> for IsTrue { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "field_type", - "fieldType", + "expr", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - FieldType, + Expr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -10319,7 +10748,7 @@ impl<'de> serde::Deserialize<'de> for List { E: serde::de::Error, { match value { - "fieldType" | "field_type" => Ok(GeneratedField::FieldType), + "expr" => Ok(GeneratedField::Expr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -10329,36 +10758,36 @@ impl<'de> serde::Deserialize<'de> for List { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = List; + type Value = IsTrue; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.List") + formatter.write_str("struct datafusion.IsTrue") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut field_type__ = None; + let mut expr__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::FieldType => { - if field_type__.is_some() { - return Err(serde::de::Error::duplicate_field("fieldType")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - field_type__ = map_.next_value()?; + expr__ = map_.next_value()?; } } } - Ok(List { - field_type: field_type__, + Ok(IsTrue { + expr: expr__, }) } } - deserializer.deserialize_struct("datafusion.List", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.IsTrue", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ListIndex { +impl serde::Serialize for IsUnknown { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -10366,29 +10795,29 @@ impl serde::Serialize for ListIndex { { use serde::ser::SerializeStruct; let mut len = 0; - if self.key.is_some() { + if self.expr.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.ListIndex", len)?; - if let Some(v) = self.key.as_ref() { - struct_ser.serialize_field("key", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.IsUnknown", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for ListIndex { +impl<'de> serde::Deserialize<'de> for IsUnknown { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "key", + "expr", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Key, + Expr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -10410,7 +10839,7 @@ impl<'de> serde::Deserialize<'de> for ListIndex { E: serde::de::Error, { match value { - "key" => Ok(GeneratedField::Key), + "expr" => Ok(GeneratedField::Expr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -10420,127 +10849,107 @@ impl<'de> serde::Deserialize<'de> for ListIndex { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ListIndex; + type Value = IsUnknown; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ListIndex") + formatter.write_str("struct datafusion.IsUnknown") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut key__ = None; + let mut expr__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Key => { - if key__.is_some() { - return Err(serde::de::Error::duplicate_field("key")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - key__ = map_.next_value()?; + expr__ = map_.next_value()?; } } } - Ok(ListIndex { - key: key__, + Ok(IsUnknown { + expr: expr__, }) } } - deserializer.deserialize_struct("datafusion.ListIndex", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.IsUnknown", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ListIndexExpr { +impl serde::Serialize for JoinConstraint { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where S: serde::Serializer, { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.key.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.ListIndexExpr", len)?; - if let Some(v) = self.key.as_ref() { - struct_ser.serialize_field("key", v)?; - } - struct_ser.end() + let variant = match self { + Self::On => "ON", + Self::Using => "USING", + }; + serializer.serialize_str(variant) } } -impl<'de> serde::Deserialize<'de> for ListIndexExpr { +impl<'de> serde::Deserialize<'de> for JoinConstraint { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "key", + "ON", + "USING", ]; - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Key, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "key" => Ok(GeneratedField::Key), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ListIndexExpr; + type Value = JoinConstraint; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ListIndexExpr") + write!(formatter, "expected one of: {:?}", &FIELDS) } - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, { - let mut key__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Key => { - if key__.is_some() { - return Err(serde::de::Error::duplicate_field("key")); - } - key__ = map_.next_value()?; - } - } + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "ON" => Ok(JoinConstraint::On), + "USING" => Ok(JoinConstraint::Using), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } - Ok(ListIndexExpr { - key: key__, - }) } } - deserializer.deserialize_struct("datafusion.ListIndexExpr", FIELDS, GeneratedVisitor) + deserializer.deserialize_any(GeneratedVisitor) } } -impl serde::Serialize for ListRange { +impl serde::Serialize for JoinFilter { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -10548,37 +10957,46 @@ impl serde::Serialize for ListRange { { use serde::ser::SerializeStruct; let mut len = 0; - if self.start.is_some() { + if self.expression.is_some() { len += 1; } - if self.stop.is_some() { + if !self.column_indices.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.ListRange", len)?; - if let Some(v) = self.start.as_ref() { - struct_ser.serialize_field("start", v)?; + if self.schema.is_some() { + len += 1; } - if let Some(v) = self.stop.as_ref() { - struct_ser.serialize_field("stop", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.JoinFilter", len)?; + if let Some(v) = self.expression.as_ref() { + struct_ser.serialize_field("expression", v)?; + } + if !self.column_indices.is_empty() { + struct_ser.serialize_field("columnIndices", &self.column_indices)?; + } + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for ListRange { +impl<'de> serde::Deserialize<'de> for JoinFilter { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "start", - "stop", + "expression", + "column_indices", + "columnIndices", + "schema", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Start, - Stop, + Expression, + ColumnIndices, + Schema, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -10600,8 +11018,9 @@ impl<'de> serde::Deserialize<'de> for ListRange { E: serde::de::Error, { match value { - "start" => Ok(GeneratedField::Start), - "stop" => Ok(GeneratedField::Stop), + "expression" => Ok(GeneratedField::Expression), + "columnIndices" | "column_indices" => Ok(GeneratedField::ColumnIndices), + "schema" => Ok(GeneratedField::Schema), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -10611,44 +11030,52 @@ impl<'de> serde::Deserialize<'de> for ListRange { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ListRange; + type Value = JoinFilter; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ListRange") + formatter.write_str("struct datafusion.JoinFilter") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut start__ = None; - let mut stop__ = None; + let mut expression__ = None; + let mut column_indices__ = None; + let mut schema__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Start => { - if start__.is_some() { - return Err(serde::de::Error::duplicate_field("start")); + GeneratedField::Expression => { + if expression__.is_some() { + return Err(serde::de::Error::duplicate_field("expression")); } - start__ = map_.next_value()?; + expression__ = map_.next_value()?; } - GeneratedField::Stop => { - if stop__.is_some() { - return Err(serde::de::Error::duplicate_field("stop")); + GeneratedField::ColumnIndices => { + if column_indices__.is_some() { + return Err(serde::de::Error::duplicate_field("columnIndices")); } - stop__ = map_.next_value()?; + column_indices__ = Some(map_.next_value()?); + } + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); + } + schema__ = map_.next_value()?; } } } - Ok(ListRange { - start: start__, - stop: stop__, + Ok(JoinFilter { + expression: expression__, + column_indices: column_indices__.unwrap_or_default(), + schema: schema__, }) } } - deserializer.deserialize_struct("datafusion.ListRange", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.JoinFilter", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ListRangeExpr { +impl serde::Serialize for JoinNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -10656,37 +11083,94 @@ impl serde::Serialize for ListRangeExpr { { use serde::ser::SerializeStruct; let mut len = 0; - if self.start.is_some() { + if self.left.is_some() { len += 1; } - if self.stop.is_some() { + if self.right.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.ListRangeExpr", len)?; - if let Some(v) = self.start.as_ref() { - struct_ser.serialize_field("start", v)?; + if self.join_type != 0 { + len += 1; } - if let Some(v) = self.stop.as_ref() { - struct_ser.serialize_field("stop", v)?; + if self.join_constraint != 0 { + len += 1; + } + if !self.left_join_key.is_empty() { + len += 1; + } + if !self.right_join_key.is_empty() { + len += 1; + } + if self.null_equals_null { + len += 1; + } + if self.filter.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.JoinNode", len)?; + if let Some(v) = self.left.as_ref() { + struct_ser.serialize_field("left", v)?; + } + if let Some(v) = self.right.as_ref() { + struct_ser.serialize_field("right", v)?; + } + if self.join_type != 0 { + let v = JoinType::try_from(self.join_type) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.join_type)))?; + struct_ser.serialize_field("joinType", &v)?; + } + if self.join_constraint != 0 { + let v = JoinConstraint::try_from(self.join_constraint) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.join_constraint)))?; + struct_ser.serialize_field("joinConstraint", &v)?; + } + if !self.left_join_key.is_empty() { + struct_ser.serialize_field("leftJoinKey", &self.left_join_key)?; + } + if !self.right_join_key.is_empty() { + struct_ser.serialize_field("rightJoinKey", &self.right_join_key)?; + } + if self.null_equals_null { + struct_ser.serialize_field("nullEqualsNull", &self.null_equals_null)?; + } + if let Some(v) = self.filter.as_ref() { + struct_ser.serialize_field("filter", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for ListRangeExpr { +impl<'de> serde::Deserialize<'de> for JoinNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "start", - "stop", + "left", + "right", + "join_type", + "joinType", + "join_constraint", + "joinConstraint", + "left_join_key", + "leftJoinKey", + "right_join_key", + "rightJoinKey", + "null_equals_null", + "nullEqualsNull", + "filter", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Start, - Stop, + Left, + Right, + JoinType, + JoinConstraint, + LeftJoinKey, + RightJoinKey, + NullEqualsNull, + Filter, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -10708,8 +11192,14 @@ impl<'de> serde::Deserialize<'de> for ListRangeExpr { E: serde::de::Error, { match value { - "start" => Ok(GeneratedField::Start), - "stop" => Ok(GeneratedField::Stop), + "left" => Ok(GeneratedField::Left), + "right" => Ok(GeneratedField::Right), + "joinType" | "join_type" => Ok(GeneratedField::JoinType), + "joinConstraint" | "join_constraint" => Ok(GeneratedField::JoinConstraint), + "leftJoinKey" | "left_join_key" => Ok(GeneratedField::LeftJoinKey), + "rightJoinKey" | "right_join_key" => Ok(GeneratedField::RightJoinKey), + "nullEqualsNull" | "null_equals_null" => Ok(GeneratedField::NullEqualsNull), + "filter" => Ok(GeneratedField::Filter), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -10719,44 +11209,92 @@ impl<'de> serde::Deserialize<'de> for ListRangeExpr { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ListRangeExpr; + type Value = JoinNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ListRangeExpr") + formatter.write_str("struct datafusion.JoinNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut start__ = None; - let mut stop__ = None; + let mut left__ = None; + let mut right__ = None; + let mut join_type__ = None; + let mut join_constraint__ = None; + let mut left_join_key__ = None; + let mut right_join_key__ = None; + let mut null_equals_null__ = None; + let mut filter__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Start => { - if start__.is_some() { - return Err(serde::de::Error::duplicate_field("start")); + GeneratedField::Left => { + if left__.is_some() { + return Err(serde::de::Error::duplicate_field("left")); } - start__ = map_.next_value()?; + left__ = map_.next_value()?; } - GeneratedField::Stop => { - if stop__.is_some() { - return Err(serde::de::Error::duplicate_field("stop")); + GeneratedField::Right => { + if right__.is_some() { + return Err(serde::de::Error::duplicate_field("right")); } - stop__ = map_.next_value()?; + right__ = map_.next_value()?; } - } - } - Ok(ListRangeExpr { - start: start__, - stop: stop__, - }) + GeneratedField::JoinType => { + if join_type__.is_some() { + return Err(serde::de::Error::duplicate_field("joinType")); + } + join_type__ = Some(map_.next_value::()? as i32); + } + GeneratedField::JoinConstraint => { + if join_constraint__.is_some() { + return Err(serde::de::Error::duplicate_field("joinConstraint")); + } + join_constraint__ = Some(map_.next_value::()? as i32); + } + GeneratedField::LeftJoinKey => { + if left_join_key__.is_some() { + return Err(serde::de::Error::duplicate_field("leftJoinKey")); + } + left_join_key__ = Some(map_.next_value()?); + } + GeneratedField::RightJoinKey => { + if right_join_key__.is_some() { + return Err(serde::de::Error::duplicate_field("rightJoinKey")); + } + right_join_key__ = Some(map_.next_value()?); + } + GeneratedField::NullEqualsNull => { + if null_equals_null__.is_some() { + return Err(serde::de::Error::duplicate_field("nullEqualsNull")); + } + null_equals_null__ = Some(map_.next_value()?); + } + GeneratedField::Filter => { + if filter__.is_some() { + return Err(serde::de::Error::duplicate_field("filter")); + } + filter__ = map_.next_value()?; + } + } + } + Ok(JoinNode { + left: left__, + right: right__, + join_type: join_type__.unwrap_or_default(), + join_constraint: join_constraint__.unwrap_or_default(), + left_join_key: left_join_key__.unwrap_or_default(), + right_join_key: right_join_key__.unwrap_or_default(), + null_equals_null: null_equals_null__.unwrap_or_default(), + filter: filter__, + }) } } - deserializer.deserialize_struct("datafusion.ListRangeExpr", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.JoinNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ListingTableScanNode { +impl serde::Serialize for JoinOn { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -10764,129 +11302,37 @@ impl serde::Serialize for ListingTableScanNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.table_name.is_some() { - len += 1; - } - if !self.paths.is_empty() { - len += 1; - } - if !self.file_extension.is_empty() { - len += 1; - } - if self.projection.is_some() { - len += 1; - } - if self.schema.is_some() { - len += 1; - } - if !self.filters.is_empty() { - len += 1; - } - if !self.table_partition_cols.is_empty() { - len += 1; - } - if self.collect_stat { - len += 1; - } - if self.target_partitions != 0 { - len += 1; - } - if !self.file_sort_order.is_empty() { + if self.left.is_some() { len += 1; } - if self.file_format_type.is_some() { + if self.right.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.ListingTableScanNode", len)?; - if let Some(v) = self.table_name.as_ref() { - struct_ser.serialize_field("tableName", v)?; - } - if !self.paths.is_empty() { - struct_ser.serialize_field("paths", &self.paths)?; - } - if !self.file_extension.is_empty() { - struct_ser.serialize_field("fileExtension", &self.file_extension)?; - } - if let Some(v) = self.projection.as_ref() { - struct_ser.serialize_field("projection", v)?; - } - if let Some(v) = self.schema.as_ref() { - struct_ser.serialize_field("schema", v)?; - } - if !self.filters.is_empty() { - struct_ser.serialize_field("filters", &self.filters)?; - } - if !self.table_partition_cols.is_empty() { - struct_ser.serialize_field("tablePartitionCols", &self.table_partition_cols)?; - } - if self.collect_stat { - struct_ser.serialize_field("collectStat", &self.collect_stat)?; - } - if self.target_partitions != 0 { - struct_ser.serialize_field("targetPartitions", &self.target_partitions)?; - } - if !self.file_sort_order.is_empty() { - struct_ser.serialize_field("fileSortOrder", &self.file_sort_order)?; + let mut struct_ser = serializer.serialize_struct("datafusion.JoinOn", len)?; + if let Some(v) = self.left.as_ref() { + struct_ser.serialize_field("left", v)?; } - if let Some(v) = self.file_format_type.as_ref() { - match v { - listing_table_scan_node::FileFormatType::Csv(v) => { - struct_ser.serialize_field("csv", v)?; - } - listing_table_scan_node::FileFormatType::Parquet(v) => { - struct_ser.serialize_field("parquet", v)?; - } - listing_table_scan_node::FileFormatType::Avro(v) => { - struct_ser.serialize_field("avro", v)?; - } - } + if let Some(v) = self.right.as_ref() { + struct_ser.serialize_field("right", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for ListingTableScanNode { +impl<'de> serde::Deserialize<'de> for JoinOn { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "table_name", - "tableName", - "paths", - "file_extension", - "fileExtension", - "projection", - "schema", - "filters", - "table_partition_cols", - "tablePartitionCols", - "collect_stat", - "collectStat", - "target_partitions", - "targetPartitions", - "file_sort_order", - "fileSortOrder", - "csv", - "parquet", - "avro", + "left", + "right", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - TableName, - Paths, - FileExtension, - Projection, - Schema, - Filters, - TablePartitionCols, - CollectStat, - TargetPartitions, - FileSortOrder, - Csv, - Parquet, - Avro, + Left, + Right, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -10908,19 +11354,8 @@ impl<'de> serde::Deserialize<'de> for ListingTableScanNode { E: serde::de::Error, { match value { - "tableName" | "table_name" => Ok(GeneratedField::TableName), - "paths" => Ok(GeneratedField::Paths), - "fileExtension" | "file_extension" => Ok(GeneratedField::FileExtension), - "projection" => Ok(GeneratedField::Projection), - "schema" => Ok(GeneratedField::Schema), - "filters" => Ok(GeneratedField::Filters), - "tablePartitionCols" | "table_partition_cols" => Ok(GeneratedField::TablePartitionCols), - "collectStat" | "collect_stat" => Ok(GeneratedField::CollectStat), - "targetPartitions" | "target_partitions" => Ok(GeneratedField::TargetPartitions), - "fileSortOrder" | "file_sort_order" => Ok(GeneratedField::FileSortOrder), - "csv" => Ok(GeneratedField::Csv), - "parquet" => Ok(GeneratedField::Parquet), - "avro" => Ok(GeneratedField::Avro), + "left" => Ok(GeneratedField::Left), + "right" => Ok(GeneratedField::Right), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -10930,133 +11365,204 @@ impl<'de> serde::Deserialize<'de> for ListingTableScanNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ListingTableScanNode; + type Value = JoinOn; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ListingTableScanNode") + formatter.write_str("struct datafusion.JoinOn") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut table_name__ = None; - let mut paths__ = None; - let mut file_extension__ = None; - let mut projection__ = None; - let mut schema__ = None; - let mut filters__ = None; - let mut table_partition_cols__ = None; - let mut collect_stat__ = None; - let mut target_partitions__ = None; - let mut file_sort_order__ = None; - let mut file_format_type__ = None; + let mut left__ = None; + let mut right__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::TableName => { - if table_name__.is_some() { - return Err(serde::de::Error::duplicate_field("tableName")); - } - table_name__ = map_.next_value()?; - } - GeneratedField::Paths => { - if paths__.is_some() { - return Err(serde::de::Error::duplicate_field("paths")); - } - paths__ = Some(map_.next_value()?); - } - GeneratedField::FileExtension => { - if file_extension__.is_some() { - return Err(serde::de::Error::duplicate_field("fileExtension")); - } - file_extension__ = Some(map_.next_value()?); - } - GeneratedField::Projection => { - if projection__.is_some() { - return Err(serde::de::Error::duplicate_field("projection")); - } - projection__ = map_.next_value()?; - } - GeneratedField::Schema => { - if schema__.is_some() { - return Err(serde::de::Error::duplicate_field("schema")); - } - schema__ = map_.next_value()?; - } - GeneratedField::Filters => { - if filters__.is_some() { - return Err(serde::de::Error::duplicate_field("filters")); - } - filters__ = Some(map_.next_value()?); - } - GeneratedField::TablePartitionCols => { - if table_partition_cols__.is_some() { - return Err(serde::de::Error::duplicate_field("tablePartitionCols")); + GeneratedField::Left => { + if left__.is_some() { + return Err(serde::de::Error::duplicate_field("left")); } - table_partition_cols__ = Some(map_.next_value()?); + left__ = map_.next_value()?; } - GeneratedField::CollectStat => { - if collect_stat__.is_some() { - return Err(serde::de::Error::duplicate_field("collectStat")); + GeneratedField::Right => { + if right__.is_some() { + return Err(serde::de::Error::duplicate_field("right")); } - collect_stat__ = Some(map_.next_value()?); - } - GeneratedField::TargetPartitions => { - if target_partitions__.is_some() { - return Err(serde::de::Error::duplicate_field("targetPartitions")); - } - target_partitions__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - GeneratedField::FileSortOrder => { - if file_sort_order__.is_some() { - return Err(serde::de::Error::duplicate_field("fileSortOrder")); - } - file_sort_order__ = Some(map_.next_value()?); - } - GeneratedField::Csv => { - if file_format_type__.is_some() { - return Err(serde::de::Error::duplicate_field("csv")); - } - file_format_type__ = map_.next_value::<::std::option::Option<_>>()?.map(listing_table_scan_node::FileFormatType::Csv) -; - } - GeneratedField::Parquet => { - if file_format_type__.is_some() { - return Err(serde::de::Error::duplicate_field("parquet")); - } - file_format_type__ = map_.next_value::<::std::option::Option<_>>()?.map(listing_table_scan_node::FileFormatType::Parquet) -; - } - GeneratedField::Avro => { - if file_format_type__.is_some() { - return Err(serde::de::Error::duplicate_field("avro")); - } - file_format_type__ = map_.next_value::<::std::option::Option<_>>()?.map(listing_table_scan_node::FileFormatType::Avro) -; + right__ = map_.next_value()?; } } } - Ok(ListingTableScanNode { - table_name: table_name__, - paths: paths__.unwrap_or_default(), - file_extension: file_extension__.unwrap_or_default(), - projection: projection__, - schema: schema__, - filters: filters__.unwrap_or_default(), - table_partition_cols: table_partition_cols__.unwrap_or_default(), - collect_stat: collect_stat__.unwrap_or_default(), - target_partitions: target_partitions__.unwrap_or_default(), - file_sort_order: file_sort_order__.unwrap_or_default(), - file_format_type: file_format_type__, + Ok(JoinOn { + left: left__, + right: right__, }) } } - deserializer.deserialize_struct("datafusion.ListingTableScanNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.JoinOn", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for LocalLimitExecNode { +impl serde::Serialize for JoinSide { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::LeftSide => "LEFT_SIDE", + Self::RightSide => "RIGHT_SIDE", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for JoinSide { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "LEFT_SIDE", + "RIGHT_SIDE", + ]; + + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = JoinSide; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "LEFT_SIDE" => Ok(JoinSide::LeftSide), + "RIGHT_SIDE" => Ok(JoinSide::RightSide), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} +impl serde::Serialize for JoinType { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::Inner => "INNER", + Self::Left => "LEFT", + Self::Right => "RIGHT", + Self::Full => "FULL", + Self::Leftsemi => "LEFTSEMI", + Self::Leftanti => "LEFTANTI", + Self::Rightsemi => "RIGHTSEMI", + Self::Rightanti => "RIGHTANTI", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for JoinType { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "INNER", + "LEFT", + "RIGHT", + "FULL", + "LEFTSEMI", + "LEFTANTI", + "RIGHTSEMI", + "RIGHTANTI", + ]; + + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = JoinType; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "INNER" => Ok(JoinType::Inner), + "LEFT" => Ok(JoinType::Left), + "RIGHT" => Ok(JoinType::Right), + "FULL" => Ok(JoinType::Full), + "LEFTSEMI" => Ok(JoinType::Leftsemi), + "LEFTANTI" => Ok(JoinType::Leftanti), + "RIGHTSEMI" => Ok(JoinType::Rightsemi), + "RIGHTANTI" => Ok(JoinType::Rightanti), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} +impl serde::Serialize for JsonSink { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -11064,37 +11570,29 @@ impl serde::Serialize for LocalLimitExecNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.input.is_some() { + if self.config.is_some() { len += 1; } - if self.fetch != 0 { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.LocalLimitExecNode", len)?; - if let Some(v) = self.input.as_ref() { - struct_ser.serialize_field("input", v)?; - } - if self.fetch != 0 { - struct_ser.serialize_field("fetch", &self.fetch)?; + let mut struct_ser = serializer.serialize_struct("datafusion.JsonSink", len)?; + if let Some(v) = self.config.as_ref() { + struct_ser.serialize_field("config", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for LocalLimitExecNode { +impl<'de> serde::Deserialize<'de> for JsonSink { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "input", - "fetch", + "config", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Input, - Fetch, + Config, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -11116,8 +11614,7 @@ impl<'de> serde::Deserialize<'de> for LocalLimitExecNode { E: serde::de::Error, { match value { - "input" => Ok(GeneratedField::Input), - "fetch" => Ok(GeneratedField::Fetch), + "config" => Ok(GeneratedField::Config), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -11127,46 +11624,36 @@ impl<'de> serde::Deserialize<'de> for LocalLimitExecNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = LocalLimitExecNode; + type Value = JsonSink; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.LocalLimitExecNode") + formatter.write_str("struct datafusion.JsonSink") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut input__ = None; - let mut fetch__ = None; + let mut config__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Input => { - if input__.is_some() { - return Err(serde::de::Error::duplicate_field("input")); - } - input__ = map_.next_value()?; - } - GeneratedField::Fetch => { - if fetch__.is_some() { - return Err(serde::de::Error::duplicate_field("fetch")); + GeneratedField::Config => { + if config__.is_some() { + return Err(serde::de::Error::duplicate_field("config")); } - fetch__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + config__ = map_.next_value()?; } } } - Ok(LocalLimitExecNode { - input: input__, - fetch: fetch__.unwrap_or_default(), + Ok(JsonSink { + config: config__, }) } } - deserializer.deserialize_struct("datafusion.LocalLimitExecNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.JsonSink", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for LogicalExprList { +impl serde::Serialize for JsonSinkExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -11174,29 +11661,55 @@ impl serde::Serialize for LogicalExprList { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.expr.is_empty() { + if self.input.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.LogicalExprList", len)?; - if !self.expr.is_empty() { - struct_ser.serialize_field("expr", &self.expr)?; + if self.sink.is_some() { + len += 1; + } + if self.sink_schema.is_some() { + len += 1; + } + if self.sort_order.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.JsonSinkExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if let Some(v) = self.sink.as_ref() { + struct_ser.serialize_field("sink", v)?; + } + if let Some(v) = self.sink_schema.as_ref() { + struct_ser.serialize_field("sinkSchema", v)?; + } + if let Some(v) = self.sort_order.as_ref() { + struct_ser.serialize_field("sortOrder", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for LogicalExprList { +impl<'de> serde::Deserialize<'de> for JsonSinkExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", + "input", + "sink", + "sink_schema", + "sinkSchema", + "sort_order", + "sortOrder", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, + Input, + Sink, + SinkSchema, + SortOrder, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -11218,7 +11731,10 @@ impl<'de> serde::Deserialize<'de> for LogicalExprList { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), + "input" => Ok(GeneratedField::Input), + "sink" => Ok(GeneratedField::Sink), + "sinkSchema" | "sink_schema" => Ok(GeneratedField::SinkSchema), + "sortOrder" | "sort_order" => Ok(GeneratedField::SortOrder), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -11228,36 +11744,60 @@ impl<'de> serde::Deserialize<'de> for LogicalExprList { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = LogicalExprList; + type Value = JsonSinkExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.LogicalExprList") + formatter.write_str("struct datafusion.JsonSinkExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; + let mut input__ = None; + let mut sink__ = None; + let mut sink_schema__ = None; + let mut sort_order__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); } - expr__ = Some(map_.next_value()?); + input__ = map_.next_value()?; + } + GeneratedField::Sink => { + if sink__.is_some() { + return Err(serde::de::Error::duplicate_field("sink")); + } + sink__ = map_.next_value()?; + } + GeneratedField::SinkSchema => { + if sink_schema__.is_some() { + return Err(serde::de::Error::duplicate_field("sinkSchema")); + } + sink_schema__ = map_.next_value()?; + } + GeneratedField::SortOrder => { + if sort_order__.is_some() { + return Err(serde::de::Error::duplicate_field("sortOrder")); + } + sort_order__ = map_.next_value()?; } } } - Ok(LogicalExprList { - expr: expr__.unwrap_or_default(), + Ok(JsonSinkExecNode { + input: input__, + sink: sink__, + sink_schema: sink_schema__, + sort_order: sort_order__, }) } } - deserializer.deserialize_struct("datafusion.LogicalExprList", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.JsonSinkExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for LogicalExprNode { +impl serde::Serialize for JsonWriterOptions { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -11265,219 +11805,31 @@ impl serde::Serialize for LogicalExprNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr_type.is_some() { + if self.compression != 0 { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.LogicalExprNode", len)?; - if let Some(v) = self.expr_type.as_ref() { - match v { - logical_expr_node::ExprType::Column(v) => { - struct_ser.serialize_field("column", v)?; - } - logical_expr_node::ExprType::Alias(v) => { - struct_ser.serialize_field("alias", v)?; - } - logical_expr_node::ExprType::Literal(v) => { - struct_ser.serialize_field("literal", v)?; - } - logical_expr_node::ExprType::BinaryExpr(v) => { - struct_ser.serialize_field("binaryExpr", v)?; - } - logical_expr_node::ExprType::AggregateExpr(v) => { - struct_ser.serialize_field("aggregateExpr", v)?; - } - logical_expr_node::ExprType::IsNullExpr(v) => { - struct_ser.serialize_field("isNullExpr", v)?; - } - logical_expr_node::ExprType::IsNotNullExpr(v) => { - struct_ser.serialize_field("isNotNullExpr", v)?; - } - logical_expr_node::ExprType::NotExpr(v) => { - struct_ser.serialize_field("notExpr", v)?; - } - logical_expr_node::ExprType::Between(v) => { - struct_ser.serialize_field("between", v)?; - } - logical_expr_node::ExprType::Case(v) => { - struct_ser.serialize_field("case", v)?; - } - logical_expr_node::ExprType::Cast(v) => { - struct_ser.serialize_field("cast", v)?; - } - logical_expr_node::ExprType::Sort(v) => { - struct_ser.serialize_field("sort", v)?; - } - logical_expr_node::ExprType::Negative(v) => { - struct_ser.serialize_field("negative", v)?; - } - logical_expr_node::ExprType::InList(v) => { - struct_ser.serialize_field("inList", v)?; - } - logical_expr_node::ExprType::Wildcard(v) => { - struct_ser.serialize_field("wildcard", v)?; - } - logical_expr_node::ExprType::ScalarFunction(v) => { - struct_ser.serialize_field("scalarFunction", v)?; - } - logical_expr_node::ExprType::TryCast(v) => { - struct_ser.serialize_field("tryCast", v)?; - } - logical_expr_node::ExprType::WindowExpr(v) => { - struct_ser.serialize_field("windowExpr", v)?; - } - logical_expr_node::ExprType::AggregateUdfExpr(v) => { - struct_ser.serialize_field("aggregateUdfExpr", v)?; - } - logical_expr_node::ExprType::ScalarUdfExpr(v) => { - struct_ser.serialize_field("scalarUdfExpr", v)?; - } - logical_expr_node::ExprType::GetIndexedField(v) => { - struct_ser.serialize_field("getIndexedField", v)?; - } - logical_expr_node::ExprType::GroupingSet(v) => { - struct_ser.serialize_field("groupingSet", v)?; - } - logical_expr_node::ExprType::Cube(v) => { - struct_ser.serialize_field("cube", v)?; - } - logical_expr_node::ExprType::Rollup(v) => { - struct_ser.serialize_field("rollup", v)?; - } - logical_expr_node::ExprType::IsTrue(v) => { - struct_ser.serialize_field("isTrue", v)?; - } - logical_expr_node::ExprType::IsFalse(v) => { - struct_ser.serialize_field("isFalse", v)?; - } - logical_expr_node::ExprType::IsUnknown(v) => { - struct_ser.serialize_field("isUnknown", v)?; - } - logical_expr_node::ExprType::IsNotTrue(v) => { - struct_ser.serialize_field("isNotTrue", v)?; - } - logical_expr_node::ExprType::IsNotFalse(v) => { - struct_ser.serialize_field("isNotFalse", v)?; - } - logical_expr_node::ExprType::IsNotUnknown(v) => { - struct_ser.serialize_field("isNotUnknown", v)?; - } - logical_expr_node::ExprType::Like(v) => { - struct_ser.serialize_field("like", v)?; - } - logical_expr_node::ExprType::Ilike(v) => { - struct_ser.serialize_field("ilike", v)?; - } - logical_expr_node::ExprType::SimilarTo(v) => { - struct_ser.serialize_field("similarTo", v)?; - } - logical_expr_node::ExprType::Placeholder(v) => { - struct_ser.serialize_field("placeholder", v)?; - } - } + let mut struct_ser = serializer.serialize_struct("datafusion.JsonWriterOptions", len)?; + if self.compression != 0 { + let v = CompressionTypeVariant::try_from(self.compression) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.compression)))?; + struct_ser.serialize_field("compression", &v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for LogicalExprNode { +impl<'de> serde::Deserialize<'de> for JsonWriterOptions { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "column", - "alias", - "literal", - "binary_expr", - "binaryExpr", - "aggregate_expr", - "aggregateExpr", - "is_null_expr", - "isNullExpr", - "is_not_null_expr", - "isNotNullExpr", - "not_expr", - "notExpr", - "between", - "case_", - "case", - "cast", - "sort", - "negative", - "in_list", - "inList", - "wildcard", - "scalar_function", - "scalarFunction", - "try_cast", - "tryCast", - "window_expr", - "windowExpr", - "aggregate_udf_expr", - "aggregateUdfExpr", - "scalar_udf_expr", - "scalarUdfExpr", - "get_indexed_field", - "getIndexedField", - "grouping_set", - "groupingSet", - "cube", - "rollup", - "is_true", - "isTrue", - "is_false", - "isFalse", - "is_unknown", - "isUnknown", - "is_not_true", - "isNotTrue", - "is_not_false", - "isNotFalse", - "is_not_unknown", - "isNotUnknown", - "like", - "ilike", - "similar_to", - "similarTo", - "placeholder", + "compression", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Column, - Alias, - Literal, - BinaryExpr, - AggregateExpr, - IsNullExpr, - IsNotNullExpr, - NotExpr, - Between, - Case, - Cast, - Sort, - Negative, - InList, - Wildcard, - ScalarFunction, - TryCast, - WindowExpr, - AggregateUdfExpr, - ScalarUdfExpr, - GetIndexedField, - GroupingSet, - Cube, - Rollup, - IsTrue, - IsFalse, - IsUnknown, - IsNotTrue, - IsNotFalse, - IsNotUnknown, - Like, - Ilike, - SimilarTo, - Placeholder, + Compression, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -11499,40 +11851,7 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { E: serde::de::Error, { match value { - "column" => Ok(GeneratedField::Column), - "alias" => Ok(GeneratedField::Alias), - "literal" => Ok(GeneratedField::Literal), - "binaryExpr" | "binary_expr" => Ok(GeneratedField::BinaryExpr), - "aggregateExpr" | "aggregate_expr" => Ok(GeneratedField::AggregateExpr), - "isNullExpr" | "is_null_expr" => Ok(GeneratedField::IsNullExpr), - "isNotNullExpr" | "is_not_null_expr" => Ok(GeneratedField::IsNotNullExpr), - "notExpr" | "not_expr" => Ok(GeneratedField::NotExpr), - "between" => Ok(GeneratedField::Between), - "case" | "case_" => Ok(GeneratedField::Case), - "cast" => Ok(GeneratedField::Cast), - "sort" => Ok(GeneratedField::Sort), - "negative" => Ok(GeneratedField::Negative), - "inList" | "in_list" => Ok(GeneratedField::InList), - "wildcard" => Ok(GeneratedField::Wildcard), - "scalarFunction" | "scalar_function" => Ok(GeneratedField::ScalarFunction), - "tryCast" | "try_cast" => Ok(GeneratedField::TryCast), - "windowExpr" | "window_expr" => Ok(GeneratedField::WindowExpr), - "aggregateUdfExpr" | "aggregate_udf_expr" => Ok(GeneratedField::AggregateUdfExpr), - "scalarUdfExpr" | "scalar_udf_expr" => Ok(GeneratedField::ScalarUdfExpr), - "getIndexedField" | "get_indexed_field" => Ok(GeneratedField::GetIndexedField), - "groupingSet" | "grouping_set" => Ok(GeneratedField::GroupingSet), - "cube" => Ok(GeneratedField::Cube), - "rollup" => Ok(GeneratedField::Rollup), - "isTrue" | "is_true" => Ok(GeneratedField::IsTrue), - "isFalse" | "is_false" => Ok(GeneratedField::IsFalse), - "isUnknown" | "is_unknown" => Ok(GeneratedField::IsUnknown), - "isNotTrue" | "is_not_true" => Ok(GeneratedField::IsNotTrue), - "isNotFalse" | "is_not_false" => Ok(GeneratedField::IsNotFalse), - "isNotUnknown" | "is_not_unknown" => Ok(GeneratedField::IsNotUnknown), - "like" => Ok(GeneratedField::Like), - "ilike" => Ok(GeneratedField::Ilike), - "similarTo" | "similar_to" => Ok(GeneratedField::SimilarTo), - "placeholder" => Ok(GeneratedField::Placeholder), + "compression" => Ok(GeneratedField::Compression), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -11542,267 +11861,2894 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = LogicalExprNode; + type Value = JsonWriterOptions; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.LogicalExprNode") + formatter.write_str("struct datafusion.JsonWriterOptions") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr_type__ = None; + let mut compression__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Column => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("column")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Column) -; - } - GeneratedField::Alias => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("alias")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Alias) -; - } - GeneratedField::Literal => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("literal")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Literal) -; - } - GeneratedField::BinaryExpr => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("binaryExpr")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::BinaryExpr) -; - } - GeneratedField::AggregateExpr => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("aggregateExpr")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::AggregateExpr) -; - } - GeneratedField::IsNullExpr => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("isNullExpr")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsNullExpr) -; - } - GeneratedField::IsNotNullExpr => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("isNotNullExpr")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsNotNullExpr) -; - } - GeneratedField::NotExpr => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("notExpr")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::NotExpr) -; - } - GeneratedField::Between => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("between")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Between) -; - } - GeneratedField::Case => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("case")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Case) -; - } - GeneratedField::Cast => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("cast")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Cast) -; - } - GeneratedField::Sort => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("sort")); + GeneratedField::Compression => { + if compression__.is_some() { + return Err(serde::de::Error::duplicate_field("compression")); } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Sort) -; + compression__ = Some(map_.next_value::()? as i32); } - GeneratedField::Negative => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("negative")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Negative) -; + } + } + Ok(JsonWriterOptions { + compression: compression__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.JsonWriterOptions", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for LikeNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.negated { + len += 1; + } + if self.expr.is_some() { + len += 1; + } + if self.pattern.is_some() { + len += 1; + } + if !self.escape_char.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.LikeNode", len)?; + if self.negated { + struct_ser.serialize_field("negated", &self.negated)?; + } + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; + } + if let Some(v) = self.pattern.as_ref() { + struct_ser.serialize_field("pattern", v)?; + } + if !self.escape_char.is_empty() { + struct_ser.serialize_field("escapeChar", &self.escape_char)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for LikeNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "negated", + "expr", + "pattern", + "escape_char", + "escapeChar", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Negated, + Expr, + Pattern, + EscapeChar, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "negated" => Ok(GeneratedField::Negated), + "expr" => Ok(GeneratedField::Expr), + "pattern" => Ok(GeneratedField::Pattern), + "escapeChar" | "escape_char" => Ok(GeneratedField::EscapeChar), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } - GeneratedField::InList => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("inList")); + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = LikeNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.LikeNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut negated__ = None; + let mut expr__ = None; + let mut pattern__ = None; + let mut escape_char__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Negated => { + if negated__.is_some() { + return Err(serde::de::Error::duplicate_field("negated")); } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::InList) -; + negated__ = Some(map_.next_value()?); } - GeneratedField::Wildcard => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("wildcard")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Wildcard); + expr__ = map_.next_value()?; } - GeneratedField::ScalarFunction => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("scalarFunction")); + GeneratedField::Pattern => { + if pattern__.is_some() { + return Err(serde::de::Error::duplicate_field("pattern")); } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::ScalarFunction) -; + pattern__ = map_.next_value()?; } - GeneratedField::TryCast => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("tryCast")); + GeneratedField::EscapeChar => { + if escape_char__.is_some() { + return Err(serde::de::Error::duplicate_field("escapeChar")); } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::TryCast) -; + escape_char__ = Some(map_.next_value()?); } - GeneratedField::WindowExpr => { - if expr_type__.is_some() { + } + } + Ok(LikeNode { + negated: negated__.unwrap_or_default(), + expr: expr__, + pattern: pattern__, + escape_char: escape_char__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.LikeNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for LimitNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.input.is_some() { + len += 1; + } + if self.skip != 0 { + len += 1; + } + if self.fetch != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.LimitNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if self.skip != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("skip", ToString::to_string(&self.skip).as_str())?; + } + if self.fetch != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("fetch", ToString::to_string(&self.fetch).as_str())?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for LimitNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "input", + "skip", + "fetch", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Input, + Skip, + Fetch, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "input" => Ok(GeneratedField::Input), + "skip" => Ok(GeneratedField::Skip), + "fetch" => Ok(GeneratedField::Fetch), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = LimitNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.LimitNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut input__ = None; + let mut skip__ = None; + let mut fetch__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::Skip => { + if skip__.is_some() { + return Err(serde::de::Error::duplicate_field("skip")); + } + skip__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::Fetch => { + if fetch__.is_some() { + return Err(serde::de::Error::duplicate_field("fetch")); + } + fetch__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(LimitNode { + input: input__, + skip: skip__.unwrap_or_default(), + fetch: fetch__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.LimitNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for List { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.field_type.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.List", len)?; + if let Some(v) = self.field_type.as_ref() { + struct_ser.serialize_field("fieldType", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for List { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "field_type", + "fieldType", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + FieldType, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "fieldType" | "field_type" => Ok(GeneratedField::FieldType), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = List; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.List") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut field_type__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::FieldType => { + if field_type__.is_some() { + return Err(serde::de::Error::duplicate_field("fieldType")); + } + field_type__ = map_.next_value()?; + } + } + } + Ok(List { + field_type: field_type__, + }) + } + } + deserializer.deserialize_struct("datafusion.List", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ListIndex { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.key.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ListIndex", len)?; + if let Some(v) = self.key.as_ref() { + struct_ser.serialize_field("key", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ListIndex { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "key", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Key, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "key" => Ok(GeneratedField::Key), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ListIndex; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ListIndex") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut key__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Key => { + if key__.is_some() { + return Err(serde::de::Error::duplicate_field("key")); + } + key__ = map_.next_value()?; + } + } + } + Ok(ListIndex { + key: key__, + }) + } + } + deserializer.deserialize_struct("datafusion.ListIndex", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ListIndexExpr { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.key.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ListIndexExpr", len)?; + if let Some(v) = self.key.as_ref() { + struct_ser.serialize_field("key", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ListIndexExpr { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "key", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Key, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "key" => Ok(GeneratedField::Key), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ListIndexExpr; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ListIndexExpr") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut key__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Key => { + if key__.is_some() { + return Err(serde::de::Error::duplicate_field("key")); + } + key__ = map_.next_value()?; + } + } + } + Ok(ListIndexExpr { + key: key__, + }) + } + } + deserializer.deserialize_struct("datafusion.ListIndexExpr", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ListRange { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.start.is_some() { + len += 1; + } + if self.stop.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ListRange", len)?; + if let Some(v) = self.start.as_ref() { + struct_ser.serialize_field("start", v)?; + } + if let Some(v) = self.stop.as_ref() { + struct_ser.serialize_field("stop", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ListRange { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "start", + "stop", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Start, + Stop, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "start" => Ok(GeneratedField::Start), + "stop" => Ok(GeneratedField::Stop), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ListRange; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ListRange") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut start__ = None; + let mut stop__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Start => { + if start__.is_some() { + return Err(serde::de::Error::duplicate_field("start")); + } + start__ = map_.next_value()?; + } + GeneratedField::Stop => { + if stop__.is_some() { + return Err(serde::de::Error::duplicate_field("stop")); + } + stop__ = map_.next_value()?; + } + } + } + Ok(ListRange { + start: start__, + stop: stop__, + }) + } + } + deserializer.deserialize_struct("datafusion.ListRange", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ListRangeExpr { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.start.is_some() { + len += 1; + } + if self.stop.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ListRangeExpr", len)?; + if let Some(v) = self.start.as_ref() { + struct_ser.serialize_field("start", v)?; + } + if let Some(v) = self.stop.as_ref() { + struct_ser.serialize_field("stop", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ListRangeExpr { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "start", + "stop", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Start, + Stop, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "start" => Ok(GeneratedField::Start), + "stop" => Ok(GeneratedField::Stop), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ListRangeExpr; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ListRangeExpr") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut start__ = None; + let mut stop__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Start => { + if start__.is_some() { + return Err(serde::de::Error::duplicate_field("start")); + } + start__ = map_.next_value()?; + } + GeneratedField::Stop => { + if stop__.is_some() { + return Err(serde::de::Error::duplicate_field("stop")); + } + stop__ = map_.next_value()?; + } + } + } + Ok(ListRangeExpr { + start: start__, + stop: stop__, + }) + } + } + deserializer.deserialize_struct("datafusion.ListRangeExpr", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ListingTableScanNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.table_name.is_some() { + len += 1; + } + if !self.paths.is_empty() { + len += 1; + } + if !self.file_extension.is_empty() { + len += 1; + } + if self.projection.is_some() { + len += 1; + } + if self.schema.is_some() { + len += 1; + } + if !self.filters.is_empty() { + len += 1; + } + if !self.table_partition_cols.is_empty() { + len += 1; + } + if self.collect_stat { + len += 1; + } + if self.target_partitions != 0 { + len += 1; + } + if !self.file_sort_order.is_empty() { + len += 1; + } + if self.file_format_type.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ListingTableScanNode", len)?; + if let Some(v) = self.table_name.as_ref() { + struct_ser.serialize_field("tableName", v)?; + } + if !self.paths.is_empty() { + struct_ser.serialize_field("paths", &self.paths)?; + } + if !self.file_extension.is_empty() { + struct_ser.serialize_field("fileExtension", &self.file_extension)?; + } + if let Some(v) = self.projection.as_ref() { + struct_ser.serialize_field("projection", v)?; + } + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; + } + if !self.filters.is_empty() { + struct_ser.serialize_field("filters", &self.filters)?; + } + if !self.table_partition_cols.is_empty() { + struct_ser.serialize_field("tablePartitionCols", &self.table_partition_cols)?; + } + if self.collect_stat { + struct_ser.serialize_field("collectStat", &self.collect_stat)?; + } + if self.target_partitions != 0 { + struct_ser.serialize_field("targetPartitions", &self.target_partitions)?; + } + if !self.file_sort_order.is_empty() { + struct_ser.serialize_field("fileSortOrder", &self.file_sort_order)?; + } + if let Some(v) = self.file_format_type.as_ref() { + match v { + listing_table_scan_node::FileFormatType::Csv(v) => { + struct_ser.serialize_field("csv", v)?; + } + listing_table_scan_node::FileFormatType::Parquet(v) => { + struct_ser.serialize_field("parquet", v)?; + } + listing_table_scan_node::FileFormatType::Avro(v) => { + struct_ser.serialize_field("avro", v)?; + } + } + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ListingTableScanNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "table_name", + "tableName", + "paths", + "file_extension", + "fileExtension", + "projection", + "schema", + "filters", + "table_partition_cols", + "tablePartitionCols", + "collect_stat", + "collectStat", + "target_partitions", + "targetPartitions", + "file_sort_order", + "fileSortOrder", + "csv", + "parquet", + "avro", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + TableName, + Paths, + FileExtension, + Projection, + Schema, + Filters, + TablePartitionCols, + CollectStat, + TargetPartitions, + FileSortOrder, + Csv, + Parquet, + Avro, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "tableName" | "table_name" => Ok(GeneratedField::TableName), + "paths" => Ok(GeneratedField::Paths), + "fileExtension" | "file_extension" => Ok(GeneratedField::FileExtension), + "projection" => Ok(GeneratedField::Projection), + "schema" => Ok(GeneratedField::Schema), + "filters" => Ok(GeneratedField::Filters), + "tablePartitionCols" | "table_partition_cols" => Ok(GeneratedField::TablePartitionCols), + "collectStat" | "collect_stat" => Ok(GeneratedField::CollectStat), + "targetPartitions" | "target_partitions" => Ok(GeneratedField::TargetPartitions), + "fileSortOrder" | "file_sort_order" => Ok(GeneratedField::FileSortOrder), + "csv" => Ok(GeneratedField::Csv), + "parquet" => Ok(GeneratedField::Parquet), + "avro" => Ok(GeneratedField::Avro), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ListingTableScanNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ListingTableScanNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut table_name__ = None; + let mut paths__ = None; + let mut file_extension__ = None; + let mut projection__ = None; + let mut schema__ = None; + let mut filters__ = None; + let mut table_partition_cols__ = None; + let mut collect_stat__ = None; + let mut target_partitions__ = None; + let mut file_sort_order__ = None; + let mut file_format_type__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::TableName => { + if table_name__.is_some() { + return Err(serde::de::Error::duplicate_field("tableName")); + } + table_name__ = map_.next_value()?; + } + GeneratedField::Paths => { + if paths__.is_some() { + return Err(serde::de::Error::duplicate_field("paths")); + } + paths__ = Some(map_.next_value()?); + } + GeneratedField::FileExtension => { + if file_extension__.is_some() { + return Err(serde::de::Error::duplicate_field("fileExtension")); + } + file_extension__ = Some(map_.next_value()?); + } + GeneratedField::Projection => { + if projection__.is_some() { + return Err(serde::de::Error::duplicate_field("projection")); + } + projection__ = map_.next_value()?; + } + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); + } + schema__ = map_.next_value()?; + } + GeneratedField::Filters => { + if filters__.is_some() { + return Err(serde::de::Error::duplicate_field("filters")); + } + filters__ = Some(map_.next_value()?); + } + GeneratedField::TablePartitionCols => { + if table_partition_cols__.is_some() { + return Err(serde::de::Error::duplicate_field("tablePartitionCols")); + } + table_partition_cols__ = Some(map_.next_value()?); + } + GeneratedField::CollectStat => { + if collect_stat__.is_some() { + return Err(serde::de::Error::duplicate_field("collectStat")); + } + collect_stat__ = Some(map_.next_value()?); + } + GeneratedField::TargetPartitions => { + if target_partitions__.is_some() { + return Err(serde::de::Error::duplicate_field("targetPartitions")); + } + target_partitions__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::FileSortOrder => { + if file_sort_order__.is_some() { + return Err(serde::de::Error::duplicate_field("fileSortOrder")); + } + file_sort_order__ = Some(map_.next_value()?); + } + GeneratedField::Csv => { + if file_format_type__.is_some() { + return Err(serde::de::Error::duplicate_field("csv")); + } + file_format_type__ = map_.next_value::<::std::option::Option<_>>()?.map(listing_table_scan_node::FileFormatType::Csv) +; + } + GeneratedField::Parquet => { + if file_format_type__.is_some() { + return Err(serde::de::Error::duplicate_field("parquet")); + } + file_format_type__ = map_.next_value::<::std::option::Option<_>>()?.map(listing_table_scan_node::FileFormatType::Parquet) +; + } + GeneratedField::Avro => { + if file_format_type__.is_some() { + return Err(serde::de::Error::duplicate_field("avro")); + } + file_format_type__ = map_.next_value::<::std::option::Option<_>>()?.map(listing_table_scan_node::FileFormatType::Avro) +; + } + } + } + Ok(ListingTableScanNode { + table_name: table_name__, + paths: paths__.unwrap_or_default(), + file_extension: file_extension__.unwrap_or_default(), + projection: projection__, + schema: schema__, + filters: filters__.unwrap_or_default(), + table_partition_cols: table_partition_cols__.unwrap_or_default(), + collect_stat: collect_stat__.unwrap_or_default(), + target_partitions: target_partitions__.unwrap_or_default(), + file_sort_order: file_sort_order__.unwrap_or_default(), + file_format_type: file_format_type__, + }) + } + } + deserializer.deserialize_struct("datafusion.ListingTableScanNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for LocalLimitExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.input.is_some() { + len += 1; + } + if self.fetch != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.LocalLimitExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if self.fetch != 0 { + struct_ser.serialize_field("fetch", &self.fetch)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for LocalLimitExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "input", + "fetch", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Input, + Fetch, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "input" => Ok(GeneratedField::Input), + "fetch" => Ok(GeneratedField::Fetch), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = LocalLimitExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.LocalLimitExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut input__ = None; + let mut fetch__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::Fetch => { + if fetch__.is_some() { + return Err(serde::de::Error::duplicate_field("fetch")); + } + fetch__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(LocalLimitExecNode { + input: input__, + fetch: fetch__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.LocalLimitExecNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for LogicalExprList { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.expr.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.LogicalExprList", len)?; + if !self.expr.is_empty() { + struct_ser.serialize_field("expr", &self.expr)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for LogicalExprList { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "expr", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Expr, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "expr" => Ok(GeneratedField::Expr), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = LogicalExprList; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.LogicalExprList") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut expr__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); + } + expr__ = Some(map_.next_value()?); + } + } + } + Ok(LogicalExprList { + expr: expr__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.LogicalExprList", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for LogicalExprNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.expr_type.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.LogicalExprNode", len)?; + if let Some(v) = self.expr_type.as_ref() { + match v { + logical_expr_node::ExprType::Column(v) => { + struct_ser.serialize_field("column", v)?; + } + logical_expr_node::ExprType::Alias(v) => { + struct_ser.serialize_field("alias", v)?; + } + logical_expr_node::ExprType::Literal(v) => { + struct_ser.serialize_field("literal", v)?; + } + logical_expr_node::ExprType::BinaryExpr(v) => { + struct_ser.serialize_field("binaryExpr", v)?; + } + logical_expr_node::ExprType::AggregateExpr(v) => { + struct_ser.serialize_field("aggregateExpr", v)?; + } + logical_expr_node::ExprType::IsNullExpr(v) => { + struct_ser.serialize_field("isNullExpr", v)?; + } + logical_expr_node::ExprType::IsNotNullExpr(v) => { + struct_ser.serialize_field("isNotNullExpr", v)?; + } + logical_expr_node::ExprType::NotExpr(v) => { + struct_ser.serialize_field("notExpr", v)?; + } + logical_expr_node::ExprType::Between(v) => { + struct_ser.serialize_field("between", v)?; + } + logical_expr_node::ExprType::Case(v) => { + struct_ser.serialize_field("case", v)?; + } + logical_expr_node::ExprType::Cast(v) => { + struct_ser.serialize_field("cast", v)?; + } + logical_expr_node::ExprType::Sort(v) => { + struct_ser.serialize_field("sort", v)?; + } + logical_expr_node::ExprType::Negative(v) => { + struct_ser.serialize_field("negative", v)?; + } + logical_expr_node::ExprType::InList(v) => { + struct_ser.serialize_field("inList", v)?; + } + logical_expr_node::ExprType::Wildcard(v) => { + struct_ser.serialize_field("wildcard", v)?; + } + logical_expr_node::ExprType::ScalarFunction(v) => { + struct_ser.serialize_field("scalarFunction", v)?; + } + logical_expr_node::ExprType::TryCast(v) => { + struct_ser.serialize_field("tryCast", v)?; + } + logical_expr_node::ExprType::WindowExpr(v) => { + struct_ser.serialize_field("windowExpr", v)?; + } + logical_expr_node::ExprType::AggregateUdfExpr(v) => { + struct_ser.serialize_field("aggregateUdfExpr", v)?; + } + logical_expr_node::ExprType::ScalarUdfExpr(v) => { + struct_ser.serialize_field("scalarUdfExpr", v)?; + } + logical_expr_node::ExprType::GetIndexedField(v) => { + struct_ser.serialize_field("getIndexedField", v)?; + } + logical_expr_node::ExprType::GroupingSet(v) => { + struct_ser.serialize_field("groupingSet", v)?; + } + logical_expr_node::ExprType::Cube(v) => { + struct_ser.serialize_field("cube", v)?; + } + logical_expr_node::ExprType::Rollup(v) => { + struct_ser.serialize_field("rollup", v)?; + } + logical_expr_node::ExprType::IsTrue(v) => { + struct_ser.serialize_field("isTrue", v)?; + } + logical_expr_node::ExprType::IsFalse(v) => { + struct_ser.serialize_field("isFalse", v)?; + } + logical_expr_node::ExprType::IsUnknown(v) => { + struct_ser.serialize_field("isUnknown", v)?; + } + logical_expr_node::ExprType::IsNotTrue(v) => { + struct_ser.serialize_field("isNotTrue", v)?; + } + logical_expr_node::ExprType::IsNotFalse(v) => { + struct_ser.serialize_field("isNotFalse", v)?; + } + logical_expr_node::ExprType::IsNotUnknown(v) => { + struct_ser.serialize_field("isNotUnknown", v)?; + } + logical_expr_node::ExprType::Like(v) => { + struct_ser.serialize_field("like", v)?; + } + logical_expr_node::ExprType::Ilike(v) => { + struct_ser.serialize_field("ilike", v)?; + } + logical_expr_node::ExprType::SimilarTo(v) => { + struct_ser.serialize_field("similarTo", v)?; + } + logical_expr_node::ExprType::Placeholder(v) => { + struct_ser.serialize_field("placeholder", v)?; + } + } + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for LogicalExprNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "column", + "alias", + "literal", + "binary_expr", + "binaryExpr", + "aggregate_expr", + "aggregateExpr", + "is_null_expr", + "isNullExpr", + "is_not_null_expr", + "isNotNullExpr", + "not_expr", + "notExpr", + "between", + "case_", + "case", + "cast", + "sort", + "negative", + "in_list", + "inList", + "wildcard", + "scalar_function", + "scalarFunction", + "try_cast", + "tryCast", + "window_expr", + "windowExpr", + "aggregate_udf_expr", + "aggregateUdfExpr", + "scalar_udf_expr", + "scalarUdfExpr", + "get_indexed_field", + "getIndexedField", + "grouping_set", + "groupingSet", + "cube", + "rollup", + "is_true", + "isTrue", + "is_false", + "isFalse", + "is_unknown", + "isUnknown", + "is_not_true", + "isNotTrue", + "is_not_false", + "isNotFalse", + "is_not_unknown", + "isNotUnknown", + "like", + "ilike", + "similar_to", + "similarTo", + "placeholder", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Column, + Alias, + Literal, + BinaryExpr, + AggregateExpr, + IsNullExpr, + IsNotNullExpr, + NotExpr, + Between, + Case, + Cast, + Sort, + Negative, + InList, + Wildcard, + ScalarFunction, + TryCast, + WindowExpr, + AggregateUdfExpr, + ScalarUdfExpr, + GetIndexedField, + GroupingSet, + Cube, + Rollup, + IsTrue, + IsFalse, + IsUnknown, + IsNotTrue, + IsNotFalse, + IsNotUnknown, + Like, + Ilike, + SimilarTo, + Placeholder, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "column" => Ok(GeneratedField::Column), + "alias" => Ok(GeneratedField::Alias), + "literal" => Ok(GeneratedField::Literal), + "binaryExpr" | "binary_expr" => Ok(GeneratedField::BinaryExpr), + "aggregateExpr" | "aggregate_expr" => Ok(GeneratedField::AggregateExpr), + "isNullExpr" | "is_null_expr" => Ok(GeneratedField::IsNullExpr), + "isNotNullExpr" | "is_not_null_expr" => Ok(GeneratedField::IsNotNullExpr), + "notExpr" | "not_expr" => Ok(GeneratedField::NotExpr), + "between" => Ok(GeneratedField::Between), + "case" | "case_" => Ok(GeneratedField::Case), + "cast" => Ok(GeneratedField::Cast), + "sort" => Ok(GeneratedField::Sort), + "negative" => Ok(GeneratedField::Negative), + "inList" | "in_list" => Ok(GeneratedField::InList), + "wildcard" => Ok(GeneratedField::Wildcard), + "scalarFunction" | "scalar_function" => Ok(GeneratedField::ScalarFunction), + "tryCast" | "try_cast" => Ok(GeneratedField::TryCast), + "windowExpr" | "window_expr" => Ok(GeneratedField::WindowExpr), + "aggregateUdfExpr" | "aggregate_udf_expr" => Ok(GeneratedField::AggregateUdfExpr), + "scalarUdfExpr" | "scalar_udf_expr" => Ok(GeneratedField::ScalarUdfExpr), + "getIndexedField" | "get_indexed_field" => Ok(GeneratedField::GetIndexedField), + "groupingSet" | "grouping_set" => Ok(GeneratedField::GroupingSet), + "cube" => Ok(GeneratedField::Cube), + "rollup" => Ok(GeneratedField::Rollup), + "isTrue" | "is_true" => Ok(GeneratedField::IsTrue), + "isFalse" | "is_false" => Ok(GeneratedField::IsFalse), + "isUnknown" | "is_unknown" => Ok(GeneratedField::IsUnknown), + "isNotTrue" | "is_not_true" => Ok(GeneratedField::IsNotTrue), + "isNotFalse" | "is_not_false" => Ok(GeneratedField::IsNotFalse), + "isNotUnknown" | "is_not_unknown" => Ok(GeneratedField::IsNotUnknown), + "like" => Ok(GeneratedField::Like), + "ilike" => Ok(GeneratedField::Ilike), + "similarTo" | "similar_to" => Ok(GeneratedField::SimilarTo), + "placeholder" => Ok(GeneratedField::Placeholder), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = LogicalExprNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.LogicalExprNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut expr_type__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Column => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("column")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Column) +; + } + GeneratedField::Alias => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("alias")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Alias) +; + } + GeneratedField::Literal => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("literal")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Literal) +; + } + GeneratedField::BinaryExpr => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("binaryExpr")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::BinaryExpr) +; + } + GeneratedField::AggregateExpr => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("aggregateExpr")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::AggregateExpr) +; + } + GeneratedField::IsNullExpr => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("isNullExpr")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsNullExpr) +; + } + GeneratedField::IsNotNullExpr => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("isNotNullExpr")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsNotNullExpr) +; + } + GeneratedField::NotExpr => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("notExpr")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::NotExpr) +; + } + GeneratedField::Between => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("between")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Between) +; + } + GeneratedField::Case => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("case")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Case) +; + } + GeneratedField::Cast => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("cast")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Cast) +; + } + GeneratedField::Sort => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("sort")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Sort) +; + } + GeneratedField::Negative => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("negative")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Negative) +; + } + GeneratedField::InList => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("inList")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::InList) +; + } + GeneratedField::Wildcard => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("wildcard")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Wildcard) +; + } + GeneratedField::ScalarFunction => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("scalarFunction")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::ScalarFunction) +; + } + GeneratedField::TryCast => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("tryCast")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::TryCast) +; + } + GeneratedField::WindowExpr => { + if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("windowExpr")); } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::WindowExpr) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::WindowExpr) +; + } + GeneratedField::AggregateUdfExpr => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("aggregateUdfExpr")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::AggregateUdfExpr) +; + } + GeneratedField::ScalarUdfExpr => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("scalarUdfExpr")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::ScalarUdfExpr) +; + } + GeneratedField::GetIndexedField => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("getIndexedField")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::GetIndexedField) +; + } + GeneratedField::GroupingSet => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("groupingSet")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::GroupingSet) +; + } + GeneratedField::Cube => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("cube")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Cube) +; + } + GeneratedField::Rollup => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("rollup")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Rollup) +; + } + GeneratedField::IsTrue => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("isTrue")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsTrue) +; + } + GeneratedField::IsFalse => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("isFalse")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsFalse) +; + } + GeneratedField::IsUnknown => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("isUnknown")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsUnknown) +; + } + GeneratedField::IsNotTrue => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("isNotTrue")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsNotTrue) +; + } + GeneratedField::IsNotFalse => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("isNotFalse")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsNotFalse) +; + } + GeneratedField::IsNotUnknown => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("isNotUnknown")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsNotUnknown) +; + } + GeneratedField::Like => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("like")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Like) +; + } + GeneratedField::Ilike => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("ilike")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Ilike) +; + } + GeneratedField::SimilarTo => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("similarTo")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::SimilarTo) +; + } + GeneratedField::Placeholder => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("placeholder")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Placeholder) +; + } + } + } + Ok(LogicalExprNode { + expr_type: expr_type__, + }) + } + } + deserializer.deserialize_struct("datafusion.LogicalExprNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for LogicalExprNodeCollection { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.logical_expr_nodes.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.LogicalExprNodeCollection", len)?; + if !self.logical_expr_nodes.is_empty() { + struct_ser.serialize_field("logicalExprNodes", &self.logical_expr_nodes)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for LogicalExprNodeCollection { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "logical_expr_nodes", + "logicalExprNodes", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + LogicalExprNodes, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "logicalExprNodes" | "logical_expr_nodes" => Ok(GeneratedField::LogicalExprNodes), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = LogicalExprNodeCollection; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.LogicalExprNodeCollection") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut logical_expr_nodes__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::LogicalExprNodes => { + if logical_expr_nodes__.is_some() { + return Err(serde::de::Error::duplicate_field("logicalExprNodes")); + } + logical_expr_nodes__ = Some(map_.next_value()?); + } + } + } + Ok(LogicalExprNodeCollection { + logical_expr_nodes: logical_expr_nodes__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.LogicalExprNodeCollection", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for LogicalExtensionNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.node.is_empty() { + len += 1; + } + if !self.inputs.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.LogicalExtensionNode", len)?; + if !self.node.is_empty() { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("node", pbjson::private::base64::encode(&self.node).as_str())?; + } + if !self.inputs.is_empty() { + struct_ser.serialize_field("inputs", &self.inputs)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for LogicalExtensionNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "node", + "inputs", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Node, + Inputs, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "node" => Ok(GeneratedField::Node), + "inputs" => Ok(GeneratedField::Inputs), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = LogicalExtensionNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.LogicalExtensionNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut node__ = None; + let mut inputs__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Node => { + if node__.is_some() { + return Err(serde::de::Error::duplicate_field("node")); + } + node__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } + GeneratedField::Inputs => { + if inputs__.is_some() { + return Err(serde::de::Error::duplicate_field("inputs")); + } + inputs__ = Some(map_.next_value()?); + } + } + } + Ok(LogicalExtensionNode { + node: node__.unwrap_or_default(), + inputs: inputs__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.LogicalExtensionNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for LogicalPlanNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.logical_plan_type.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.LogicalPlanNode", len)?; + if let Some(v) = self.logical_plan_type.as_ref() { + match v { + logical_plan_node::LogicalPlanType::ListingScan(v) => { + struct_ser.serialize_field("listingScan", v)?; + } + logical_plan_node::LogicalPlanType::Projection(v) => { + struct_ser.serialize_field("projection", v)?; + } + logical_plan_node::LogicalPlanType::Selection(v) => { + struct_ser.serialize_field("selection", v)?; + } + logical_plan_node::LogicalPlanType::Limit(v) => { + struct_ser.serialize_field("limit", v)?; + } + logical_plan_node::LogicalPlanType::Aggregate(v) => { + struct_ser.serialize_field("aggregate", v)?; + } + logical_plan_node::LogicalPlanType::Join(v) => { + struct_ser.serialize_field("join", v)?; + } + logical_plan_node::LogicalPlanType::Sort(v) => { + struct_ser.serialize_field("sort", v)?; + } + logical_plan_node::LogicalPlanType::Repartition(v) => { + struct_ser.serialize_field("repartition", v)?; + } + logical_plan_node::LogicalPlanType::EmptyRelation(v) => { + struct_ser.serialize_field("emptyRelation", v)?; + } + logical_plan_node::LogicalPlanType::CreateExternalTable(v) => { + struct_ser.serialize_field("createExternalTable", v)?; + } + logical_plan_node::LogicalPlanType::Explain(v) => { + struct_ser.serialize_field("explain", v)?; + } + logical_plan_node::LogicalPlanType::Window(v) => { + struct_ser.serialize_field("window", v)?; + } + logical_plan_node::LogicalPlanType::Analyze(v) => { + struct_ser.serialize_field("analyze", v)?; + } + logical_plan_node::LogicalPlanType::CrossJoin(v) => { + struct_ser.serialize_field("crossJoin", v)?; + } + logical_plan_node::LogicalPlanType::Values(v) => { + struct_ser.serialize_field("values", v)?; + } + logical_plan_node::LogicalPlanType::Extension(v) => { + struct_ser.serialize_field("extension", v)?; + } + logical_plan_node::LogicalPlanType::CreateCatalogSchema(v) => { + struct_ser.serialize_field("createCatalogSchema", v)?; + } + logical_plan_node::LogicalPlanType::Union(v) => { + struct_ser.serialize_field("union", v)?; + } + logical_plan_node::LogicalPlanType::CreateCatalog(v) => { + struct_ser.serialize_field("createCatalog", v)?; + } + logical_plan_node::LogicalPlanType::SubqueryAlias(v) => { + struct_ser.serialize_field("subqueryAlias", v)?; + } + logical_plan_node::LogicalPlanType::CreateView(v) => { + struct_ser.serialize_field("createView", v)?; + } + logical_plan_node::LogicalPlanType::Distinct(v) => { + struct_ser.serialize_field("distinct", v)?; + } + logical_plan_node::LogicalPlanType::ViewScan(v) => { + struct_ser.serialize_field("viewScan", v)?; + } + logical_plan_node::LogicalPlanType::CustomScan(v) => { + struct_ser.serialize_field("customScan", v)?; + } + logical_plan_node::LogicalPlanType::Prepare(v) => { + struct_ser.serialize_field("prepare", v)?; + } + logical_plan_node::LogicalPlanType::DropView(v) => { + struct_ser.serialize_field("dropView", v)?; + } + logical_plan_node::LogicalPlanType::DistinctOn(v) => { + struct_ser.serialize_field("distinctOn", v)?; + } + logical_plan_node::LogicalPlanType::CopyTo(v) => { + struct_ser.serialize_field("copyTo", v)?; + } + } + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for LogicalPlanNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "listing_scan", + "listingScan", + "projection", + "selection", + "limit", + "aggregate", + "join", + "sort", + "repartition", + "empty_relation", + "emptyRelation", + "create_external_table", + "createExternalTable", + "explain", + "window", + "analyze", + "cross_join", + "crossJoin", + "values", + "extension", + "create_catalog_schema", + "createCatalogSchema", + "union", + "create_catalog", + "createCatalog", + "subquery_alias", + "subqueryAlias", + "create_view", + "createView", + "distinct", + "view_scan", + "viewScan", + "custom_scan", + "customScan", + "prepare", + "drop_view", + "dropView", + "distinct_on", + "distinctOn", + "copy_to", + "copyTo", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + ListingScan, + Projection, + Selection, + Limit, + Aggregate, + Join, + Sort, + Repartition, + EmptyRelation, + CreateExternalTable, + Explain, + Window, + Analyze, + CrossJoin, + Values, + Extension, + CreateCatalogSchema, + Union, + CreateCatalog, + SubqueryAlias, + CreateView, + Distinct, + ViewScan, + CustomScan, + Prepare, + DropView, + DistinctOn, + CopyTo, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "listingScan" | "listing_scan" => Ok(GeneratedField::ListingScan), + "projection" => Ok(GeneratedField::Projection), + "selection" => Ok(GeneratedField::Selection), + "limit" => Ok(GeneratedField::Limit), + "aggregate" => Ok(GeneratedField::Aggregate), + "join" => Ok(GeneratedField::Join), + "sort" => Ok(GeneratedField::Sort), + "repartition" => Ok(GeneratedField::Repartition), + "emptyRelation" | "empty_relation" => Ok(GeneratedField::EmptyRelation), + "createExternalTable" | "create_external_table" => Ok(GeneratedField::CreateExternalTable), + "explain" => Ok(GeneratedField::Explain), + "window" => Ok(GeneratedField::Window), + "analyze" => Ok(GeneratedField::Analyze), + "crossJoin" | "cross_join" => Ok(GeneratedField::CrossJoin), + "values" => Ok(GeneratedField::Values), + "extension" => Ok(GeneratedField::Extension), + "createCatalogSchema" | "create_catalog_schema" => Ok(GeneratedField::CreateCatalogSchema), + "union" => Ok(GeneratedField::Union), + "createCatalog" | "create_catalog" => Ok(GeneratedField::CreateCatalog), + "subqueryAlias" | "subquery_alias" => Ok(GeneratedField::SubqueryAlias), + "createView" | "create_view" => Ok(GeneratedField::CreateView), + "distinct" => Ok(GeneratedField::Distinct), + "viewScan" | "view_scan" => Ok(GeneratedField::ViewScan), + "customScan" | "custom_scan" => Ok(GeneratedField::CustomScan), + "prepare" => Ok(GeneratedField::Prepare), + "dropView" | "drop_view" => Ok(GeneratedField::DropView), + "distinctOn" | "distinct_on" => Ok(GeneratedField::DistinctOn), + "copyTo" | "copy_to" => Ok(GeneratedField::CopyTo), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = LogicalPlanNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.LogicalPlanNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut logical_plan_type__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::ListingScan => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("listingScan")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::ListingScan) +; + } + GeneratedField::Projection => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("projection")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Projection) +; + } + GeneratedField::Selection => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("selection")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Selection) +; + } + GeneratedField::Limit => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("limit")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Limit) +; + } + GeneratedField::Aggregate => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("aggregate")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Aggregate) +; + } + GeneratedField::Join => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("join")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Join) +; + } + GeneratedField::Sort => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("sort")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Sort) +; + } + GeneratedField::Repartition => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("repartition")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Repartition) +; + } + GeneratedField::EmptyRelation => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("emptyRelation")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::EmptyRelation) +; + } + GeneratedField::CreateExternalTable => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("createExternalTable")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CreateExternalTable) +; + } + GeneratedField::Explain => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("explain")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Explain) +; + } + GeneratedField::Window => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("window")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Window) +; + } + GeneratedField::Analyze => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("analyze")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Analyze) +; + } + GeneratedField::CrossJoin => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("crossJoin")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CrossJoin) ; } - GeneratedField::AggregateUdfExpr => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("aggregateUdfExpr")); + GeneratedField::Values => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("values")); } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::AggregateUdfExpr) + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Values) ; } - GeneratedField::ScalarUdfExpr => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("scalarUdfExpr")); + GeneratedField::Extension => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("extension")); } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::ScalarUdfExpr) + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Extension) ; } - GeneratedField::GetIndexedField => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("getIndexedField")); + GeneratedField::CreateCatalogSchema => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("createCatalogSchema")); } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::GetIndexedField) + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CreateCatalogSchema) ; } - GeneratedField::GroupingSet => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("groupingSet")); + GeneratedField::Union => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("union")); } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::GroupingSet) + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Union) ; } - GeneratedField::Cube => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("cube")); + GeneratedField::CreateCatalog => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("createCatalog")); } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Cube) + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CreateCatalog) ; } - GeneratedField::Rollup => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("rollup")); + GeneratedField::SubqueryAlias => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("subqueryAlias")); } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Rollup) + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::SubqueryAlias) ; } - GeneratedField::IsTrue => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("isTrue")); + GeneratedField::CreateView => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("createView")); } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsTrue) + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CreateView) ; } - GeneratedField::IsFalse => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("isFalse")); + GeneratedField::Distinct => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("distinct")); } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsFalse) + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Distinct) ; } - GeneratedField::IsUnknown => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("isUnknown")); + GeneratedField::ViewScan => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("viewScan")); } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsUnknown) + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::ViewScan) ; } - GeneratedField::IsNotTrue => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("isNotTrue")); + GeneratedField::CustomScan => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("customScan")); } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsNotTrue) + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CustomScan) ; } - GeneratedField::IsNotFalse => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("isNotFalse")); + GeneratedField::Prepare => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("prepare")); } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsNotFalse) + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Prepare) +; + } + GeneratedField::DropView => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("dropView")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::DropView) +; + } + GeneratedField::DistinctOn => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("distinctOn")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::DistinctOn) +; + } + GeneratedField::CopyTo => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("copyTo")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CopyTo) ; } - GeneratedField::IsNotUnknown => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("isNotUnknown")); + } + } + Ok(LogicalPlanNode { + logical_plan_type: logical_plan_type__, + }) + } + } + deserializer.deserialize_struct("datafusion.LogicalPlanNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for Map { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.field_type.is_some() { + len += 1; + } + if self.keys_sorted { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.Map", len)?; + if let Some(v) = self.field_type.as_ref() { + struct_ser.serialize_field("fieldType", v)?; + } + if self.keys_sorted { + struct_ser.serialize_field("keysSorted", &self.keys_sorted)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for Map { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "field_type", + "fieldType", + "keys_sorted", + "keysSorted", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + FieldType, + KeysSorted, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "fieldType" | "field_type" => Ok(GeneratedField::FieldType), + "keysSorted" | "keys_sorted" => Ok(GeneratedField::KeysSorted), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = Map; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.Map") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut field_type__ = None; + let mut keys_sorted__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::FieldType => { + if field_type__.is_some() { + return Err(serde::de::Error::duplicate_field("fieldType")); + } + field_type__ = map_.next_value()?; + } + GeneratedField::KeysSorted => { + if keys_sorted__.is_some() { + return Err(serde::de::Error::duplicate_field("keysSorted")); + } + keys_sorted__ = Some(map_.next_value()?); + } + } + } + Ok(Map { + field_type: field_type__, + keys_sorted: keys_sorted__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.Map", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for MaybeFilter { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.expr.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.MaybeFilter", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for MaybeFilter { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "expr", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Expr, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "expr" => Ok(GeneratedField::Expr), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = MaybeFilter; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.MaybeFilter") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut expr__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsNotUnknown) -; + expr__ = map_.next_value()?; } - GeneratedField::Like => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("like")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Like) -; + } + } + Ok(MaybeFilter { + expr: expr__, + }) + } + } + deserializer.deserialize_struct("datafusion.MaybeFilter", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for MaybePhysicalSortExprs { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.sort_expr.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.MaybePhysicalSortExprs", len)?; + if !self.sort_expr.is_empty() { + struct_ser.serialize_field("sortExpr", &self.sort_expr)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for MaybePhysicalSortExprs { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "sort_expr", + "sortExpr", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + SortExpr, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "sortExpr" | "sort_expr" => Ok(GeneratedField::SortExpr), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } - GeneratedField::Ilike => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("ilike")); + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = MaybePhysicalSortExprs; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.MaybePhysicalSortExprs") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut sort_expr__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::SortExpr => { + if sort_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("sortExpr")); } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Ilike) -; + sort_expr__ = Some(map_.next_value()?); } - GeneratedField::SimilarTo => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("similarTo")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::SimilarTo) -; + } + } + Ok(MaybePhysicalSortExprs { + sort_expr: sort_expr__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.MaybePhysicalSortExprs", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for NamedStructField { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.name.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.NamedStructField", len)?; + if let Some(v) = self.name.as_ref() { + struct_ser.serialize_field("name", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for NamedStructField { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "name", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Name, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "name" => Ok(GeneratedField::Name), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } - GeneratedField::Placeholder => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("placeholder")); + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = NamedStructField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.NamedStructField") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut name__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Placeholder) -; + name__ = map_.next_value()?; } } - } - Ok(LogicalExprNode { - expr_type: expr_type__, + } + Ok(NamedStructField { + name: name__, }) } } - deserializer.deserialize_struct("datafusion.LogicalExprNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.NamedStructField", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for LogicalExprNodeCollection { +impl serde::Serialize for NamedStructFieldExpr { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -11810,30 +14756,29 @@ impl serde::Serialize for LogicalExprNodeCollection { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.logical_expr_nodes.is_empty() { + if self.name.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.LogicalExprNodeCollection", len)?; - if !self.logical_expr_nodes.is_empty() { - struct_ser.serialize_field("logicalExprNodes", &self.logical_expr_nodes)?; + let mut struct_ser = serializer.serialize_struct("datafusion.NamedStructFieldExpr", len)?; + if let Some(v) = self.name.as_ref() { + struct_ser.serialize_field("name", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for LogicalExprNodeCollection { +impl<'de> serde::Deserialize<'de> for NamedStructFieldExpr { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "logical_expr_nodes", - "logicalExprNodes", + "name", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - LogicalExprNodes, + Name, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -11855,7 +14800,7 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNodeCollection { E: serde::de::Error, { match value { - "logicalExprNodes" | "logical_expr_nodes" => Ok(GeneratedField::LogicalExprNodes), + "name" => Ok(GeneratedField::Name), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -11865,36 +14810,36 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNodeCollection { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = LogicalExprNodeCollection; + type Value = NamedStructFieldExpr; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.LogicalExprNodeCollection") + formatter.write_str("struct datafusion.NamedStructFieldExpr") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut logical_expr_nodes__ = None; + let mut name__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::LogicalExprNodes => { - if logical_expr_nodes__.is_some() { - return Err(serde::de::Error::duplicate_field("logicalExprNodes")); + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); } - logical_expr_nodes__ = Some(map_.next_value()?); + name__ = map_.next_value()?; } } } - Ok(LogicalExprNodeCollection { - logical_expr_nodes: logical_expr_nodes__.unwrap_or_default(), + Ok(NamedStructFieldExpr { + name: name__, }) } } - deserializer.deserialize_struct("datafusion.LogicalExprNodeCollection", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.NamedStructFieldExpr", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for LogicalExtensionNode { +impl serde::Serialize for NegativeNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -11902,38 +14847,29 @@ impl serde::Serialize for LogicalExtensionNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.node.is_empty() { - len += 1; - } - if !self.inputs.is_empty() { + if self.expr.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.LogicalExtensionNode", len)?; - if !self.node.is_empty() { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("node", pbjson::private::base64::encode(&self.node).as_str())?; - } - if !self.inputs.is_empty() { - struct_ser.serialize_field("inputs", &self.inputs)?; + let mut struct_ser = serializer.serialize_struct("datafusion.NegativeNode", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for LogicalExtensionNode { +impl<'de> serde::Deserialize<'de> for NegativeNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "node", - "inputs", + "expr", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Node, - Inputs, + Expr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -11955,8 +14891,7 @@ impl<'de> serde::Deserialize<'de> for LogicalExtensionNode { E: serde::de::Error, { match value { - "node" => Ok(GeneratedField::Node), - "inputs" => Ok(GeneratedField::Inputs), + "expr" => Ok(GeneratedField::Expr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -11966,46 +14901,36 @@ impl<'de> serde::Deserialize<'de> for LogicalExtensionNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = LogicalExtensionNode; + type Value = NegativeNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.LogicalExtensionNode") + formatter.write_str("struct datafusion.NegativeNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut node__ = None; - let mut inputs__ = None; + let mut expr__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Node => { - if node__.is_some() { - return Err(serde::de::Error::duplicate_field("node")); - } - node__ = - Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) - ; - } - GeneratedField::Inputs => { - if inputs__.is_some() { - return Err(serde::de::Error::duplicate_field("inputs")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - inputs__ = Some(map_.next_value()?); + expr__ = map_.next_value()?; } } } - Ok(LogicalExtensionNode { - node: node__.unwrap_or_default(), - inputs: inputs__.unwrap_or_default(), + Ok(NegativeNode { + expr: expr__, }) } } - deserializer.deserialize_struct("datafusion.LogicalExtensionNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.NegativeNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for LogicalPlanNode { +impl serde::Serialize for NestedLoopJoinExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -12013,169 +14938,56 @@ impl serde::Serialize for LogicalPlanNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.logical_plan_type.is_some() { + if self.left.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.LogicalPlanNode", len)?; - if let Some(v) = self.logical_plan_type.as_ref() { - match v { - logical_plan_node::LogicalPlanType::ListingScan(v) => { - struct_ser.serialize_field("listingScan", v)?; - } - logical_plan_node::LogicalPlanType::Projection(v) => { - struct_ser.serialize_field("projection", v)?; - } - logical_plan_node::LogicalPlanType::Selection(v) => { - struct_ser.serialize_field("selection", v)?; - } - logical_plan_node::LogicalPlanType::Limit(v) => { - struct_ser.serialize_field("limit", v)?; - } - logical_plan_node::LogicalPlanType::Aggregate(v) => { - struct_ser.serialize_field("aggregate", v)?; - } - logical_plan_node::LogicalPlanType::Join(v) => { - struct_ser.serialize_field("join", v)?; - } - logical_plan_node::LogicalPlanType::Sort(v) => { - struct_ser.serialize_field("sort", v)?; - } - logical_plan_node::LogicalPlanType::Repartition(v) => { - struct_ser.serialize_field("repartition", v)?; - } - logical_plan_node::LogicalPlanType::EmptyRelation(v) => { - struct_ser.serialize_field("emptyRelation", v)?; - } - logical_plan_node::LogicalPlanType::CreateExternalTable(v) => { - struct_ser.serialize_field("createExternalTable", v)?; - } - logical_plan_node::LogicalPlanType::Explain(v) => { - struct_ser.serialize_field("explain", v)?; - } - logical_plan_node::LogicalPlanType::Window(v) => { - struct_ser.serialize_field("window", v)?; - } - logical_plan_node::LogicalPlanType::Analyze(v) => { - struct_ser.serialize_field("analyze", v)?; - } - logical_plan_node::LogicalPlanType::CrossJoin(v) => { - struct_ser.serialize_field("crossJoin", v)?; - } - logical_plan_node::LogicalPlanType::Values(v) => { - struct_ser.serialize_field("values", v)?; - } - logical_plan_node::LogicalPlanType::Extension(v) => { - struct_ser.serialize_field("extension", v)?; - } - logical_plan_node::LogicalPlanType::CreateCatalogSchema(v) => { - struct_ser.serialize_field("createCatalogSchema", v)?; - } - logical_plan_node::LogicalPlanType::Union(v) => { - struct_ser.serialize_field("union", v)?; - } - logical_plan_node::LogicalPlanType::CreateCatalog(v) => { - struct_ser.serialize_field("createCatalog", v)?; - } - logical_plan_node::LogicalPlanType::SubqueryAlias(v) => { - struct_ser.serialize_field("subqueryAlias", v)?; - } - logical_plan_node::LogicalPlanType::CreateView(v) => { - struct_ser.serialize_field("createView", v)?; - } - logical_plan_node::LogicalPlanType::Distinct(v) => { - struct_ser.serialize_field("distinct", v)?; - } - logical_plan_node::LogicalPlanType::ViewScan(v) => { - struct_ser.serialize_field("viewScan", v)?; - } - logical_plan_node::LogicalPlanType::CustomScan(v) => { - struct_ser.serialize_field("customScan", v)?; - } - logical_plan_node::LogicalPlanType::Prepare(v) => { - struct_ser.serialize_field("prepare", v)?; - } - logical_plan_node::LogicalPlanType::DropView(v) => { - struct_ser.serialize_field("dropView", v)?; - } - } + if self.right.is_some() { + len += 1; + } + if self.join_type != 0 { + len += 1; + } + if self.filter.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.NestedLoopJoinExecNode", len)?; + if let Some(v) = self.left.as_ref() { + struct_ser.serialize_field("left", v)?; + } + if let Some(v) = self.right.as_ref() { + struct_ser.serialize_field("right", v)?; + } + if self.join_type != 0 { + let v = JoinType::try_from(self.join_type) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.join_type)))?; + struct_ser.serialize_field("joinType", &v)?; + } + if let Some(v) = self.filter.as_ref() { + struct_ser.serialize_field("filter", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for LogicalPlanNode { +impl<'de> serde::Deserialize<'de> for NestedLoopJoinExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "listing_scan", - "listingScan", - "projection", - "selection", - "limit", - "aggregate", - "join", - "sort", - "repartition", - "empty_relation", - "emptyRelation", - "create_external_table", - "createExternalTable", - "explain", - "window", - "analyze", - "cross_join", - "crossJoin", - "values", - "extension", - "create_catalog_schema", - "createCatalogSchema", - "union", - "create_catalog", - "createCatalog", - "subquery_alias", - "subqueryAlias", - "create_view", - "createView", - "distinct", - "view_scan", - "viewScan", - "custom_scan", - "customScan", - "prepare", - "drop_view", - "dropView", + { + const FIELDS: &[&str] = &[ + "left", + "right", + "join_type", + "joinType", + "filter", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - ListingScan, - Projection, - Selection, - Limit, - Aggregate, - Join, - Sort, - Repartition, - EmptyRelation, - CreateExternalTable, - Explain, - Window, - Analyze, - CrossJoin, - Values, - Extension, - CreateCatalogSchema, - Union, - CreateCatalog, - SubqueryAlias, - CreateView, - Distinct, - ViewScan, - CustomScan, - Prepare, - DropView, + Left, + Right, + JoinType, + Filter, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -12197,32 +15009,10 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { E: serde::de::Error, { match value { - "listingScan" | "listing_scan" => Ok(GeneratedField::ListingScan), - "projection" => Ok(GeneratedField::Projection), - "selection" => Ok(GeneratedField::Selection), - "limit" => Ok(GeneratedField::Limit), - "aggregate" => Ok(GeneratedField::Aggregate), - "join" => Ok(GeneratedField::Join), - "sort" => Ok(GeneratedField::Sort), - "repartition" => Ok(GeneratedField::Repartition), - "emptyRelation" | "empty_relation" => Ok(GeneratedField::EmptyRelation), - "createExternalTable" | "create_external_table" => Ok(GeneratedField::CreateExternalTable), - "explain" => Ok(GeneratedField::Explain), - "window" => Ok(GeneratedField::Window), - "analyze" => Ok(GeneratedField::Analyze), - "crossJoin" | "cross_join" => Ok(GeneratedField::CrossJoin), - "values" => Ok(GeneratedField::Values), - "extension" => Ok(GeneratedField::Extension), - "createCatalogSchema" | "create_catalog_schema" => Ok(GeneratedField::CreateCatalogSchema), - "union" => Ok(GeneratedField::Union), - "createCatalog" | "create_catalog" => Ok(GeneratedField::CreateCatalog), - "subqueryAlias" | "subquery_alias" => Ok(GeneratedField::SubqueryAlias), - "createView" | "create_view" => Ok(GeneratedField::CreateView), - "distinct" => Ok(GeneratedField::Distinct), - "viewScan" | "view_scan" => Ok(GeneratedField::ViewScan), - "customScan" | "custom_scan" => Ok(GeneratedField::CustomScan), - "prepare" => Ok(GeneratedField::Prepare), - "dropView" | "drop_view" => Ok(GeneratedField::DropView), + "left" => Ok(GeneratedField::Left), + "right" => Ok(GeneratedField::Right), + "joinType" | "join_type" => Ok(GeneratedField::JoinType), + "filter" => Ok(GeneratedField::Filter), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -12232,212 +15022,60 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = LogicalPlanNode; + type Value = NestedLoopJoinExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.LogicalPlanNode") + formatter.write_str("struct datafusion.NestedLoopJoinExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut logical_plan_type__ = None; + let mut left__ = None; + let mut right__ = None; + let mut join_type__ = None; + let mut filter__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::ListingScan => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("listingScan")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::ListingScan) -; - } - GeneratedField::Projection => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("projection")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Projection) -; - } - GeneratedField::Selection => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("selection")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Selection) -; - } - GeneratedField::Limit => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("limit")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Limit) -; - } - GeneratedField::Aggregate => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("aggregate")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Aggregate) -; - } - GeneratedField::Join => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("join")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Join) -; - } - GeneratedField::Sort => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("sort")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Sort) -; - } - GeneratedField::Repartition => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("repartition")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Repartition) -; - } - GeneratedField::EmptyRelation => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("emptyRelation")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::EmptyRelation) -; - } - GeneratedField::CreateExternalTable => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("createExternalTable")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CreateExternalTable) -; - } - GeneratedField::Explain => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("explain")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Explain) -; - } - GeneratedField::Window => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("window")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Window) -; - } - GeneratedField::Analyze => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("analyze")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Analyze) -; - } - GeneratedField::CrossJoin => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("crossJoin")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CrossJoin) -; - } - GeneratedField::Values => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("values")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Values) -; - } - GeneratedField::Extension => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("extension")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Extension) -; - } - GeneratedField::CreateCatalogSchema => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("createCatalogSchema")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CreateCatalogSchema) -; - } - GeneratedField::Union => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("union")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Union) -; - } - GeneratedField::CreateCatalog => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("createCatalog")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CreateCatalog) -; - } - GeneratedField::SubqueryAlias => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("subqueryAlias")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::SubqueryAlias) -; - } - GeneratedField::CreateView => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("createView")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CreateView) -; - } - GeneratedField::Distinct => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("distinct")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Distinct) -; - } - GeneratedField::ViewScan => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("viewScan")); + GeneratedField::Left => { + if left__.is_some() { + return Err(serde::de::Error::duplicate_field("left")); } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::ViewScan) -; + left__ = map_.next_value()?; } - GeneratedField::CustomScan => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("customScan")); + GeneratedField::Right => { + if right__.is_some() { + return Err(serde::de::Error::duplicate_field("right")); } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CustomScan) -; + right__ = map_.next_value()?; } - GeneratedField::Prepare => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("prepare")); + GeneratedField::JoinType => { + if join_type__.is_some() { + return Err(serde::de::Error::duplicate_field("joinType")); } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Prepare) -; + join_type__ = Some(map_.next_value::()? as i32); } - GeneratedField::DropView => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("dropView")); + GeneratedField::Filter => { + if filter__.is_some() { + return Err(serde::de::Error::duplicate_field("filter")); } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::DropView) -; + filter__ = map_.next_value()?; } } } - Ok(LogicalPlanNode { - logical_plan_type: logical_plan_type__, + Ok(NestedLoopJoinExecNode { + left: left__, + right: right__, + join_type: join_type__.unwrap_or_default(), + filter: filter__, }) } } - deserializer.deserialize_struct("datafusion.LogicalPlanNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.NestedLoopJoinExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for Map { +impl serde::Serialize for Not { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -12445,39 +15083,29 @@ impl serde::Serialize for Map { { use serde::ser::SerializeStruct; let mut len = 0; - if self.field_type.is_some() { - len += 1; - } - if self.keys_sorted { + if self.expr.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.Map", len)?; - if let Some(v) = self.field_type.as_ref() { - struct_ser.serialize_field("fieldType", v)?; - } - if self.keys_sorted { - struct_ser.serialize_field("keysSorted", &self.keys_sorted)?; + let mut struct_ser = serializer.serialize_struct("datafusion.Not", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for Map { +impl<'de> serde::Deserialize<'de> for Not { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "field_type", - "fieldType", - "keys_sorted", - "keysSorted", + "expr", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - FieldType, - KeysSorted, + Expr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -12499,8 +15127,7 @@ impl<'de> serde::Deserialize<'de> for Map { E: serde::de::Error, { match value { - "fieldType" | "field_type" => Ok(GeneratedField::FieldType), - "keysSorted" | "keys_sorted" => Ok(GeneratedField::KeysSorted), + "expr" => Ok(GeneratedField::Expr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -12510,44 +15137,128 @@ impl<'de> serde::Deserialize<'de> for Map { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = Map; + type Value = Not; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.Map") + formatter.write_str("struct datafusion.Not") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut field_type__ = None; - let mut keys_sorted__ = None; + let mut expr__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::FieldType => { - if field_type__.is_some() { - return Err(serde::de::Error::duplicate_field("fieldType")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - field_type__ = map_.next_value()?; + expr__ = map_.next_value()?; } - GeneratedField::KeysSorted => { - if keys_sorted__.is_some() { - return Err(serde::de::Error::duplicate_field("keysSorted")); + } + } + Ok(Not { + expr: expr__, + }) + } + } + deserializer.deserialize_struct("datafusion.Not", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for OptimizedLogicalPlanType { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.optimizer_name.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.OptimizedLogicalPlanType", len)?; + if !self.optimizer_name.is_empty() { + struct_ser.serialize_field("optimizerName", &self.optimizer_name)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for OptimizedLogicalPlanType { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "optimizer_name", + "optimizerName", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + OptimizerName, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "optimizerName" | "optimizer_name" => Ok(GeneratedField::OptimizerName), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = OptimizedLogicalPlanType; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.OptimizedLogicalPlanType") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut optimizer_name__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::OptimizerName => { + if optimizer_name__.is_some() { + return Err(serde::de::Error::duplicate_field("optimizerName")); } - keys_sorted__ = Some(map_.next_value()?); + optimizer_name__ = Some(map_.next_value()?); } } } - Ok(Map { - field_type: field_type__, - keys_sorted: keys_sorted__.unwrap_or_default(), + Ok(OptimizedLogicalPlanType { + optimizer_name: optimizer_name__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.Map", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.OptimizedLogicalPlanType", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for MaybeFilter { +impl serde::Serialize for OptimizedPhysicalPlanType { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -12555,29 +15266,30 @@ impl serde::Serialize for MaybeFilter { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { + if !self.optimizer_name.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.MaybeFilter", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.OptimizedPhysicalPlanType", len)?; + if !self.optimizer_name.is_empty() { + struct_ser.serialize_field("optimizerName", &self.optimizer_name)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for MaybeFilter { +impl<'de> serde::Deserialize<'de> for OptimizedPhysicalPlanType { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", + "optimizer_name", + "optimizerName", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, + OptimizerName, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -12599,7 +15311,7 @@ impl<'de> serde::Deserialize<'de> for MaybeFilter { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), + "optimizerName" | "optimizer_name" => Ok(GeneratedField::OptimizerName), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -12609,36 +15321,36 @@ impl<'de> serde::Deserialize<'de> for MaybeFilter { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = MaybeFilter; + type Value = OptimizedPhysicalPlanType; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.MaybeFilter") + formatter.write_str("struct datafusion.OptimizedPhysicalPlanType") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; + let mut optimizer_name__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::OptimizerName => { + if optimizer_name__.is_some() { + return Err(serde::de::Error::duplicate_field("optimizerName")); } - expr__ = map_.next_value()?; + optimizer_name__ = Some(map_.next_value()?); } } } - Ok(MaybeFilter { - expr: expr__, + Ok(OptimizedPhysicalPlanType { + optimizer_name: optimizer_name__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.MaybeFilter", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.OptimizedPhysicalPlanType", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for MaybePhysicalSortExprs { +impl serde::Serialize for OwnedTableReference { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -12646,30 +15358,43 @@ impl serde::Serialize for MaybePhysicalSortExprs { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.sort_expr.is_empty() { + if self.table_reference_enum.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.MaybePhysicalSortExprs", len)?; - if !self.sort_expr.is_empty() { - struct_ser.serialize_field("sortExpr", &self.sort_expr)?; + let mut struct_ser = serializer.serialize_struct("datafusion.OwnedTableReference", len)?; + if let Some(v) = self.table_reference_enum.as_ref() { + match v { + owned_table_reference::TableReferenceEnum::Bare(v) => { + struct_ser.serialize_field("bare", v)?; + } + owned_table_reference::TableReferenceEnum::Partial(v) => { + struct_ser.serialize_field("partial", v)?; + } + owned_table_reference::TableReferenceEnum::Full(v) => { + struct_ser.serialize_field("full", v)?; + } + } } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for MaybePhysicalSortExprs { +impl<'de> serde::Deserialize<'de> for OwnedTableReference { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "sort_expr", - "sortExpr", + "bare", + "partial", + "full", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - SortExpr, + Bare, + Partial, + Full, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -12691,7 +15416,9 @@ impl<'de> serde::Deserialize<'de> for MaybePhysicalSortExprs { E: serde::de::Error, { match value { - "sortExpr" | "sort_expr" => Ok(GeneratedField::SortExpr), + "bare" => Ok(GeneratedField::Bare), + "partial" => Ok(GeneratedField::Partial), + "full" => Ok(GeneratedField::Full), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -12701,66 +15428,73 @@ impl<'de> serde::Deserialize<'de> for MaybePhysicalSortExprs { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = MaybePhysicalSortExprs; + type Value = OwnedTableReference; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.MaybePhysicalSortExprs") + formatter.write_str("struct datafusion.OwnedTableReference") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut sort_expr__ = None; + let mut table_reference_enum__ = None; while let Some(k) = map_.next_key()? { - match k { - GeneratedField::SortExpr => { - if sort_expr__.is_some() { - return Err(serde::de::Error::duplicate_field("sortExpr")); + match k { + GeneratedField::Bare => { + if table_reference_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("bare")); } - sort_expr__ = Some(map_.next_value()?); + table_reference_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(owned_table_reference::TableReferenceEnum::Bare) +; + } + GeneratedField::Partial => { + if table_reference_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("partial")); + } + table_reference_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(owned_table_reference::TableReferenceEnum::Partial) +; + } + GeneratedField::Full => { + if table_reference_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("full")); + } + table_reference_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(owned_table_reference::TableReferenceEnum::Full) +; } } } - Ok(MaybePhysicalSortExprs { - sort_expr: sort_expr__.unwrap_or_default(), + Ok(OwnedTableReference { + table_reference_enum: table_reference_enum__, }) } } - deserializer.deserialize_struct("datafusion.MaybePhysicalSortExprs", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.OwnedTableReference", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for NamedStructField { +impl serde::Serialize for ParquetFormat { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where S: serde::Serializer, { use serde::ser::SerializeStruct; - let mut len = 0; - if self.name.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.NamedStructField", len)?; - if let Some(v) = self.name.as_ref() { - struct_ser.serialize_field("name", v)?; - } + let len = 0; + let struct_ser = serializer.serialize_struct("datafusion.ParquetFormat", len)?; struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for NamedStructField { +impl<'de> serde::Deserialize<'de> for ParquetFormat { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "name", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Name, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -12781,10 +15515,7 @@ impl<'de> serde::Deserialize<'de> for NamedStructField { where E: serde::de::Error, { - match value { - "name" => Ok(GeneratedField::Name), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } + Err(serde::de::Error::unknown_field(value, FIELDS)) } } deserializer.deserialize_identifier(GeneratedVisitor) @@ -12792,36 +15523,27 @@ impl<'de> serde::Deserialize<'de> for NamedStructField { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = NamedStructField; + type Value = ParquetFormat; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.NamedStructField") + formatter.write_str("struct datafusion.ParquetFormat") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut name__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Name => { - if name__.is_some() { - return Err(serde::de::Error::duplicate_field("name")); - } - name__ = map_.next_value()?; - } - } + while map_.next_key::()?.is_some() { + let _ = map_.next_value::()?; } - Ok(NamedStructField { - name: name__, + Ok(ParquetFormat { }) } } - deserializer.deserialize_struct("datafusion.NamedStructField", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.ParquetFormat", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for NamedStructFieldExpr { +impl serde::Serialize for ParquetScanExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -12829,29 +15551,38 @@ impl serde::Serialize for NamedStructFieldExpr { { use serde::ser::SerializeStruct; let mut len = 0; - if self.name.is_some() { + if self.base_conf.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.NamedStructFieldExpr", len)?; - if let Some(v) = self.name.as_ref() { - struct_ser.serialize_field("name", v)?; + if self.predicate.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ParquetScanExecNode", len)?; + if let Some(v) = self.base_conf.as_ref() { + struct_ser.serialize_field("baseConf", v)?; + } + if let Some(v) = self.predicate.as_ref() { + struct_ser.serialize_field("predicate", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for NamedStructFieldExpr { +impl<'de> serde::Deserialize<'de> for ParquetScanExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "name", + "base_conf", + "baseConf", + "predicate", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Name, + BaseConf, + Predicate, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -12873,7 +15604,8 @@ impl<'de> serde::Deserialize<'de> for NamedStructFieldExpr { E: serde::de::Error, { match value { - "name" => Ok(GeneratedField::Name), + "baseConf" | "base_conf" => Ok(GeneratedField::BaseConf), + "predicate" => Ok(GeneratedField::Predicate), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -12883,36 +15615,44 @@ impl<'de> serde::Deserialize<'de> for NamedStructFieldExpr { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = NamedStructFieldExpr; + type Value = ParquetScanExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.NamedStructFieldExpr") + formatter.write_str("struct datafusion.ParquetScanExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut name__ = None; + let mut base_conf__ = None; + let mut predicate__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Name => { - if name__.is_some() { - return Err(serde::de::Error::duplicate_field("name")); + GeneratedField::BaseConf => { + if base_conf__.is_some() { + return Err(serde::de::Error::duplicate_field("baseConf")); } - name__ = map_.next_value()?; + base_conf__ = map_.next_value()?; + } + GeneratedField::Predicate => { + if predicate__.is_some() { + return Err(serde::de::Error::duplicate_field("predicate")); + } + predicate__ = map_.next_value()?; } } } - Ok(NamedStructFieldExpr { - name: name__, + Ok(ParquetScanExecNode { + base_conf: base_conf__, + predicate: predicate__, }) } } - deserializer.deserialize_struct("datafusion.NamedStructFieldExpr", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.ParquetScanExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for NegativeNode { +impl serde::Serialize for ParquetSink { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -12920,29 +15660,29 @@ impl serde::Serialize for NegativeNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { + if self.config.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.NegativeNode", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.ParquetSink", len)?; + if let Some(v) = self.config.as_ref() { + struct_ser.serialize_field("config", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for NegativeNode { +impl<'de> serde::Deserialize<'de> for ParquetSink { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", + "config", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, + Config, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -12964,7 +15704,7 @@ impl<'de> serde::Deserialize<'de> for NegativeNode { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), + "config" => Ok(GeneratedField::Config), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -12974,36 +15714,36 @@ impl<'de> serde::Deserialize<'de> for NegativeNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = NegativeNode; + type Value = ParquetSink; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.NegativeNode") + formatter.write_str("struct datafusion.ParquetSink") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; + let mut config__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::Config => { + if config__.is_some() { + return Err(serde::de::Error::duplicate_field("config")); } - expr__ = map_.next_value()?; + config__ = map_.next_value()?; } } } - Ok(NegativeNode { - expr: expr__, + Ok(ParquetSink { + config: config__, }) } } - deserializer.deserialize_struct("datafusion.NegativeNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.ParquetSink", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for NestedLoopJoinExecNode { +impl serde::Serialize for ParquetSinkExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -13011,56 +15751,55 @@ impl serde::Serialize for NestedLoopJoinExecNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.left.is_some() { + if self.input.is_some() { len += 1; } - if self.right.is_some() { + if self.sink.is_some() { len += 1; } - if self.join_type != 0 { + if self.sink_schema.is_some() { len += 1; } - if self.filter.is_some() { + if self.sort_order.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.NestedLoopJoinExecNode", len)?; - if let Some(v) = self.left.as_ref() { - struct_ser.serialize_field("left", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.ParquetSinkExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; } - if let Some(v) = self.right.as_ref() { - struct_ser.serialize_field("right", v)?; + if let Some(v) = self.sink.as_ref() { + struct_ser.serialize_field("sink", v)?; } - if self.join_type != 0 { - let v = JoinType::try_from(self.join_type) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.join_type)))?; - struct_ser.serialize_field("joinType", &v)?; + if let Some(v) = self.sink_schema.as_ref() { + struct_ser.serialize_field("sinkSchema", v)?; } - if let Some(v) = self.filter.as_ref() { - struct_ser.serialize_field("filter", v)?; + if let Some(v) = self.sort_order.as_ref() { + struct_ser.serialize_field("sortOrder", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for NestedLoopJoinExecNode { +impl<'de> serde::Deserialize<'de> for ParquetSinkExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "left", - "right", - "join_type", - "joinType", - "filter", + "input", + "sink", + "sink_schema", + "sinkSchema", + "sort_order", + "sortOrder", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Left, - Right, - JoinType, - Filter, + Input, + Sink, + SinkSchema, + SortOrder, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -13082,10 +15821,10 @@ impl<'de> serde::Deserialize<'de> for NestedLoopJoinExecNode { E: serde::de::Error, { match value { - "left" => Ok(GeneratedField::Left), - "right" => Ok(GeneratedField::Right), - "joinType" | "join_type" => Ok(GeneratedField::JoinType), - "filter" => Ok(GeneratedField::Filter), + "input" => Ok(GeneratedField::Input), + "sink" => Ok(GeneratedField::Sink), + "sinkSchema" | "sink_schema" => Ok(GeneratedField::SinkSchema), + "sortOrder" | "sort_order" => Ok(GeneratedField::SortOrder), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -13095,60 +15834,60 @@ impl<'de> serde::Deserialize<'de> for NestedLoopJoinExecNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = NestedLoopJoinExecNode; + type Value = ParquetSinkExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.NestedLoopJoinExecNode") + formatter.write_str("struct datafusion.ParquetSinkExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut left__ = None; - let mut right__ = None; - let mut join_type__ = None; - let mut filter__ = None; + let mut input__ = None; + let mut sink__ = None; + let mut sink_schema__ = None; + let mut sort_order__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Left => { - if left__.is_some() { - return Err(serde::de::Error::duplicate_field("left")); + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); } - left__ = map_.next_value()?; + input__ = map_.next_value()?; } - GeneratedField::Right => { - if right__.is_some() { - return Err(serde::de::Error::duplicate_field("right")); + GeneratedField::Sink => { + if sink__.is_some() { + return Err(serde::de::Error::duplicate_field("sink")); } - right__ = map_.next_value()?; + sink__ = map_.next_value()?; } - GeneratedField::JoinType => { - if join_type__.is_some() { - return Err(serde::de::Error::duplicate_field("joinType")); + GeneratedField::SinkSchema => { + if sink_schema__.is_some() { + return Err(serde::de::Error::duplicate_field("sinkSchema")); } - join_type__ = Some(map_.next_value::()? as i32); + sink_schema__ = map_.next_value()?; } - GeneratedField::Filter => { - if filter__.is_some() { - return Err(serde::de::Error::duplicate_field("filter")); + GeneratedField::SortOrder => { + if sort_order__.is_some() { + return Err(serde::de::Error::duplicate_field("sortOrder")); } - filter__ = map_.next_value()?; + sort_order__ = map_.next_value()?; } } } - Ok(NestedLoopJoinExecNode { - left: left__, - right: right__, - join_type: join_type__.unwrap_or_default(), - filter: filter__, + Ok(ParquetSinkExecNode { + input: input__, + sink: sink__, + sink_schema: sink_schema__, + sort_order: sort_order__, }) } } - deserializer.deserialize_struct("datafusion.NestedLoopJoinExecNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.ParquetSinkExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for Not { +impl serde::Serialize for ParquetWriterOptions { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -13156,29 +15895,30 @@ impl serde::Serialize for Not { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { + if self.writer_properties.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.Not", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.ParquetWriterOptions", len)?; + if let Some(v) = self.writer_properties.as_ref() { + struct_ser.serialize_field("writerProperties", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for Not { +impl<'de> serde::Deserialize<'de> for ParquetWriterOptions { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", + "writer_properties", + "writerProperties", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, + WriterProperties, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -13200,7 +15940,7 @@ impl<'de> serde::Deserialize<'de> for Not { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), + "writerProperties" | "writer_properties" => Ok(GeneratedField::WriterProperties), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -13210,36 +15950,36 @@ impl<'de> serde::Deserialize<'de> for Not { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = Not; + type Value = ParquetWriterOptions; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.Not") + formatter.write_str("struct datafusion.ParquetWriterOptions") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; + let mut writer_properties__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::WriterProperties => { + if writer_properties__.is_some() { + return Err(serde::de::Error::duplicate_field("writerProperties")); } - expr__ = map_.next_value()?; + writer_properties__ = map_.next_value()?; } } } - Ok(Not { - expr: expr__, + Ok(ParquetWriterOptions { + writer_properties: writer_properties__, }) } } - deserializer.deserialize_struct("datafusion.Not", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.ParquetWriterOptions", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for OptimizedLogicalPlanType { +impl serde::Serialize for PartialTableReference { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -13247,30 +15987,37 @@ impl serde::Serialize for OptimizedLogicalPlanType { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.optimizer_name.is_empty() { + if !self.schema.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.OptimizedLogicalPlanType", len)?; - if !self.optimizer_name.is_empty() { - struct_ser.serialize_field("optimizerName", &self.optimizer_name)?; + if !self.table.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PartialTableReference", len)?; + if !self.schema.is_empty() { + struct_ser.serialize_field("schema", &self.schema)?; + } + if !self.table.is_empty() { + struct_ser.serialize_field("table", &self.table)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for OptimizedLogicalPlanType { +impl<'de> serde::Deserialize<'de> for PartialTableReference { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "optimizer_name", - "optimizerName", + "schema", + "table", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - OptimizerName, + Schema, + Table, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -13292,7 +16039,8 @@ impl<'de> serde::Deserialize<'de> for OptimizedLogicalPlanType { E: serde::de::Error, { match value { - "optimizerName" | "optimizer_name" => Ok(GeneratedField::OptimizerName), + "schema" => Ok(GeneratedField::Schema), + "table" => Ok(GeneratedField::Table), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -13302,36 +16050,44 @@ impl<'de> serde::Deserialize<'de> for OptimizedLogicalPlanType { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = OptimizedLogicalPlanType; + type Value = PartialTableReference; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.OptimizedLogicalPlanType") + formatter.write_str("struct datafusion.PartialTableReference") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut optimizer_name__ = None; + let mut schema__ = None; + let mut table__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::OptimizerName => { - if optimizer_name__.is_some() { - return Err(serde::de::Error::duplicate_field("optimizerName")); + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); } - optimizer_name__ = Some(map_.next_value()?); + schema__ = Some(map_.next_value()?); + } + GeneratedField::Table => { + if table__.is_some() { + return Err(serde::de::Error::duplicate_field("table")); + } + table__ = Some(map_.next_value()?); } } } - Ok(OptimizedLogicalPlanType { - optimizer_name: optimizer_name__.unwrap_or_default(), + Ok(PartialTableReference { + schema: schema__.unwrap_or_default(), + table: table__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.OptimizedLogicalPlanType", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PartialTableReference", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for OptimizedPhysicalPlanType { +impl serde::Serialize for PartiallySortedInputOrderMode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -13339,30 +16095,29 @@ impl serde::Serialize for OptimizedPhysicalPlanType { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.optimizer_name.is_empty() { + if !self.columns.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.OptimizedPhysicalPlanType", len)?; - if !self.optimizer_name.is_empty() { - struct_ser.serialize_field("optimizerName", &self.optimizer_name)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PartiallySortedInputOrderMode", len)?; + if !self.columns.is_empty() { + struct_ser.serialize_field("columns", &self.columns.iter().map(ToString::to_string).collect::>())?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for OptimizedPhysicalPlanType { +impl<'de> serde::Deserialize<'de> for PartiallySortedInputOrderMode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "optimizer_name", - "optimizerName", + "columns", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - OptimizerName, + Columns, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -13384,7 +16139,7 @@ impl<'de> serde::Deserialize<'de> for OptimizedPhysicalPlanType { E: serde::de::Error, { match value { - "optimizerName" | "optimizer_name" => Ok(GeneratedField::OptimizerName), + "columns" => Ok(GeneratedField::Columns), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -13394,36 +16149,39 @@ impl<'de> serde::Deserialize<'de> for OptimizedPhysicalPlanType { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = OptimizedPhysicalPlanType; + type Value = PartiallySortedInputOrderMode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.OptimizedPhysicalPlanType") + formatter.write_str("struct datafusion.PartiallySortedInputOrderMode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut optimizer_name__ = None; + let mut columns__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::OptimizerName => { - if optimizer_name__.is_some() { - return Err(serde::de::Error::duplicate_field("optimizerName")); + GeneratedField::Columns => { + if columns__.is_some() { + return Err(serde::de::Error::duplicate_field("columns")); } - optimizer_name__ = Some(map_.next_value()?); + columns__ = + Some(map_.next_value::>>()? + .into_iter().map(|x| x.0).collect()) + ; } } } - Ok(OptimizedPhysicalPlanType { - optimizer_name: optimizer_name__.unwrap_or_default(), + Ok(PartiallySortedInputOrderMode { + columns: columns__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.OptimizedPhysicalPlanType", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PartiallySortedInputOrderMode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for OwnedTableReference { +impl serde::Serialize for PartitionColumn { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -13431,43 +16189,38 @@ impl serde::Serialize for OwnedTableReference { { use serde::ser::SerializeStruct; let mut len = 0; - if self.table_reference_enum.is_some() { + if !self.name.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.OwnedTableReference", len)?; - if let Some(v) = self.table_reference_enum.as_ref() { - match v { - owned_table_reference::TableReferenceEnum::Bare(v) => { - struct_ser.serialize_field("bare", v)?; - } - owned_table_reference::TableReferenceEnum::Partial(v) => { - struct_ser.serialize_field("partial", v)?; - } - owned_table_reference::TableReferenceEnum::Full(v) => { - struct_ser.serialize_field("full", v)?; - } - } + if self.arrow_type.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PartitionColumn", len)?; + if !self.name.is_empty() { + struct_ser.serialize_field("name", &self.name)?; + } + if let Some(v) = self.arrow_type.as_ref() { + struct_ser.serialize_field("arrowType", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for OwnedTableReference { +impl<'de> serde::Deserialize<'de> for PartitionColumn { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "bare", - "partial", - "full", + "name", + "arrow_type", + "arrowType", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Bare, - Partial, - Full, + Name, + ArrowType, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -13489,9 +16242,8 @@ impl<'de> serde::Deserialize<'de> for OwnedTableReference { E: serde::de::Error, { match value { - "bare" => Ok(GeneratedField::Bare), - "partial" => Ok(GeneratedField::Partial), - "full" => Ok(GeneratedField::Full), + "name" => Ok(GeneratedField::Name), + "arrowType" | "arrow_type" => Ok(GeneratedField::ArrowType), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -13501,122 +16253,118 @@ impl<'de> serde::Deserialize<'de> for OwnedTableReference { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = OwnedTableReference; + type Value = PartitionColumn; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.OwnedTableReference") + formatter.write_str("struct datafusion.PartitionColumn") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut table_reference_enum__ = None; + let mut name__ = None; + let mut arrow_type__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Bare => { - if table_reference_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("bare")); - } - table_reference_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(owned_table_reference::TableReferenceEnum::Bare) -; - } - GeneratedField::Partial => { - if table_reference_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("partial")); + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); } - table_reference_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(owned_table_reference::TableReferenceEnum::Partial) -; + name__ = Some(map_.next_value()?); } - GeneratedField::Full => { - if table_reference_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("full")); + GeneratedField::ArrowType => { + if arrow_type__.is_some() { + return Err(serde::de::Error::duplicate_field("arrowType")); } - table_reference_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(owned_table_reference::TableReferenceEnum::Full) -; + arrow_type__ = map_.next_value()?; } } } - Ok(OwnedTableReference { - table_reference_enum: table_reference_enum__, + Ok(PartitionColumn { + name: name__.unwrap_or_default(), + arrow_type: arrow_type__, }) } } - deserializer.deserialize_struct("datafusion.OwnedTableReference", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PartitionColumn", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ParquetFormat { +impl serde::Serialize for PartitionMode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where S: serde::Serializer, { - use serde::ser::SerializeStruct; - let len = 0; - let struct_ser = serializer.serialize_struct("datafusion.ParquetFormat", len)?; - struct_ser.end() + let variant = match self { + Self::CollectLeft => "COLLECT_LEFT", + Self::Partitioned => "PARTITIONED", + Self::Auto => "AUTO", + }; + serializer.serialize_str(variant) } } -impl<'de> serde::Deserialize<'de> for ParquetFormat { +impl<'de> serde::Deserialize<'de> for PartitionMode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ + "COLLECT_LEFT", + "PARTITIONED", + "AUTO", ]; - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; + struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = PartitionMode; - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - Err(serde::de::Error::unknown_field(value, FIELDS)) - } - } - deserializer.deserialize_identifier(GeneratedVisitor) + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ParquetFormat; - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ParquetFormat") + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) } - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, { - while map_.next_key::()?.is_some() { - let _ = map_.next_value::()?; + match value { + "COLLECT_LEFT" => Ok(PartitionMode::CollectLeft), + "PARTITIONED" => Ok(PartitionMode::Partitioned), + "AUTO" => Ok(PartitionMode::Auto), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } - Ok(ParquetFormat { - }) } } - deserializer.deserialize_struct("datafusion.ParquetFormat", FIELDS, GeneratedVisitor) + deserializer.deserialize_any(GeneratedVisitor) } } -impl serde::Serialize for ParquetScanExecNode { +impl serde::Serialize for PartitionStats { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -13624,38 +16372,60 @@ impl serde::Serialize for ParquetScanExecNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.base_conf.is_some() { + if self.num_rows != 0 { len += 1; } - if self.predicate.is_some() { + if self.num_batches != 0 { + len += 1; + } + if self.num_bytes != 0 { + len += 1; + } + if !self.column_stats.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.ParquetScanExecNode", len)?; - if let Some(v) = self.base_conf.as_ref() { - struct_ser.serialize_field("baseConf", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PartitionStats", len)?; + if self.num_rows != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("numRows", ToString::to_string(&self.num_rows).as_str())?; } - if let Some(v) = self.predicate.as_ref() { - struct_ser.serialize_field("predicate", v)?; + if self.num_batches != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("numBatches", ToString::to_string(&self.num_batches).as_str())?; + } + if self.num_bytes != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("numBytes", ToString::to_string(&self.num_bytes).as_str())?; + } + if !self.column_stats.is_empty() { + struct_ser.serialize_field("columnStats", &self.column_stats)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for ParquetScanExecNode { +impl<'de> serde::Deserialize<'de> for PartitionStats { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "base_conf", - "baseConf", - "predicate", + "num_rows", + "numRows", + "num_batches", + "numBatches", + "num_bytes", + "numBytes", + "column_stats", + "columnStats", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - BaseConf, - Predicate, + NumRows, + NumBatches, + NumBytes, + ColumnStats, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -13677,8 +16447,10 @@ impl<'de> serde::Deserialize<'de> for ParquetScanExecNode { E: serde::de::Error, { match value { - "baseConf" | "base_conf" => Ok(GeneratedField::BaseConf), - "predicate" => Ok(GeneratedField::Predicate), + "numRows" | "num_rows" => Ok(GeneratedField::NumRows), + "numBatches" | "num_batches" => Ok(GeneratedField::NumBatches), + "numBytes" | "num_bytes" => Ok(GeneratedField::NumBytes), + "columnStats" | "column_stats" => Ok(GeneratedField::ColumnStats), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -13688,44 +16460,66 @@ impl<'de> serde::Deserialize<'de> for ParquetScanExecNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ParquetScanExecNode; + type Value = PartitionStats; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ParquetScanExecNode") + formatter.write_str("struct datafusion.PartitionStats") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut base_conf__ = None; - let mut predicate__ = None; + let mut num_rows__ = None; + let mut num_batches__ = None; + let mut num_bytes__ = None; + let mut column_stats__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::BaseConf => { - if base_conf__.is_some() { - return Err(serde::de::Error::duplicate_field("baseConf")); + GeneratedField::NumRows => { + if num_rows__.is_some() { + return Err(serde::de::Error::duplicate_field("numRows")); } - base_conf__ = map_.next_value()?; + num_rows__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; } - GeneratedField::Predicate => { - if predicate__.is_some() { - return Err(serde::de::Error::duplicate_field("predicate")); + GeneratedField::NumBatches => { + if num_batches__.is_some() { + return Err(serde::de::Error::duplicate_field("numBatches")); } - predicate__ = map_.next_value()?; + num_batches__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::NumBytes => { + if num_bytes__.is_some() { + return Err(serde::de::Error::duplicate_field("numBytes")); + } + num_bytes__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::ColumnStats => { + if column_stats__.is_some() { + return Err(serde::de::Error::duplicate_field("columnStats")); + } + column_stats__ = Some(map_.next_value()?); } } } - Ok(ParquetScanExecNode { - base_conf: base_conf__, - predicate: predicate__, + Ok(PartitionStats { + num_rows: num_rows__.unwrap_or_default(), + num_batches: num_batches__.unwrap_or_default(), + num_bytes: num_bytes__.unwrap_or_default(), + column_stats: column_stats__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.ParquetScanExecNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PartitionStats", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PartialTableReference { +impl serde::Serialize for PartitionedFile { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -13733,37 +16527,65 @@ impl serde::Serialize for PartialTableReference { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.schema.is_empty() { + if !self.path.is_empty() { len += 1; } - if !self.table.is_empty() { + if self.size != 0 { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PartialTableReference", len)?; - if !self.schema.is_empty() { - struct_ser.serialize_field("schema", &self.schema)?; + if self.last_modified_ns != 0 { + len += 1; } - if !self.table.is_empty() { - struct_ser.serialize_field("table", &self.table)?; + if !self.partition_values.is_empty() { + len += 1; + } + if self.range.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PartitionedFile", len)?; + if !self.path.is_empty() { + struct_ser.serialize_field("path", &self.path)?; + } + if self.size != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("size", ToString::to_string(&self.size).as_str())?; + } + if self.last_modified_ns != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("lastModifiedNs", ToString::to_string(&self.last_modified_ns).as_str())?; + } + if !self.partition_values.is_empty() { + struct_ser.serialize_field("partitionValues", &self.partition_values)?; + } + if let Some(v) = self.range.as_ref() { + struct_ser.serialize_field("range", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PartialTableReference { +impl<'de> serde::Deserialize<'de> for PartitionedFile { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "schema", - "table", + "path", + "size", + "last_modified_ns", + "lastModifiedNs", + "partition_values", + "partitionValues", + "range", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Schema, - Table, + Path, + Size, + LastModifiedNs, + PartitionValues, + Range, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -13785,8 +16607,11 @@ impl<'de> serde::Deserialize<'de> for PartialTableReference { E: serde::de::Error, { match value { - "schema" => Ok(GeneratedField::Schema), - "table" => Ok(GeneratedField::Table), + "path" => Ok(GeneratedField::Path), + "size" => Ok(GeneratedField::Size), + "lastModifiedNs" | "last_modified_ns" => Ok(GeneratedField::LastModifiedNs), + "partitionValues" | "partition_values" => Ok(GeneratedField::PartitionValues), + "range" => Ok(GeneratedField::Range), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -13796,44 +16621,72 @@ impl<'de> serde::Deserialize<'de> for PartialTableReference { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PartialTableReference; + type Value = PartitionedFile; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PartialTableReference") + formatter.write_str("struct datafusion.PartitionedFile") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut schema__ = None; - let mut table__ = None; + let mut path__ = None; + let mut size__ = None; + let mut last_modified_ns__ = None; + let mut partition_values__ = None; + let mut range__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Schema => { - if schema__.is_some() { - return Err(serde::de::Error::duplicate_field("schema")); + GeneratedField::Path => { + if path__.is_some() { + return Err(serde::de::Error::duplicate_field("path")); } - schema__ = Some(map_.next_value()?); + path__ = Some(map_.next_value()?); } - GeneratedField::Table => { - if table__.is_some() { - return Err(serde::de::Error::duplicate_field("table")); + GeneratedField::Size => { + if size__.is_some() { + return Err(serde::de::Error::duplicate_field("size")); } - table__ = Some(map_.next_value()?); + size__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::LastModifiedNs => { + if last_modified_ns__.is_some() { + return Err(serde::de::Error::duplicate_field("lastModifiedNs")); + } + last_modified_ns__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::PartitionValues => { + if partition_values__.is_some() { + return Err(serde::de::Error::duplicate_field("partitionValues")); + } + partition_values__ = Some(map_.next_value()?); + } + GeneratedField::Range => { + if range__.is_some() { + return Err(serde::de::Error::duplicate_field("range")); + } + range__ = map_.next_value()?; } } } - Ok(PartialTableReference { - schema: schema__.unwrap_or_default(), - table: table__.unwrap_or_default(), + Ok(PartitionedFile { + path: path__.unwrap_or_default(), + size: size__.unwrap_or_default(), + last_modified_ns: last_modified_ns__.unwrap_or_default(), + partition_values: partition_values__.unwrap_or_default(), + range: range__, }) } } - deserializer.deserialize_struct("datafusion.PartialTableReference", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PartitionedFile", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PartiallySortedPartitionSearchMode { +impl serde::Serialize for PhysicalAggregateExprNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -13841,29 +16694,67 @@ impl serde::Serialize for PartiallySortedPartitionSearchMode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.columns.is_empty() { + if !self.expr.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PartiallySortedPartitionSearchMode", len)?; - if !self.columns.is_empty() { - struct_ser.serialize_field("columns", &self.columns.iter().map(ToString::to_string).collect::>())?; + if !self.ordering_req.is_empty() { + len += 1; + } + if self.distinct { + len += 1; + } + if self.aggregate_function.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalAggregateExprNode", len)?; + if !self.expr.is_empty() { + struct_ser.serialize_field("expr", &self.expr)?; + } + if !self.ordering_req.is_empty() { + struct_ser.serialize_field("orderingReq", &self.ordering_req)?; + } + if self.distinct { + struct_ser.serialize_field("distinct", &self.distinct)?; + } + if let Some(v) = self.aggregate_function.as_ref() { + match v { + physical_aggregate_expr_node::AggregateFunction::AggrFunction(v) => { + let v = AggregateFunction::try_from(*v) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; + struct_ser.serialize_field("aggrFunction", &v)?; + } + physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(v) => { + struct_ser.serialize_field("userDefinedAggrFunction", v)?; + } + } } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PartiallySortedPartitionSearchMode { +impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "columns", + "expr", + "ordering_req", + "orderingReq", + "distinct", + "aggr_function", + "aggrFunction", + "user_defined_aggr_function", + "userDefinedAggrFunction", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Columns, + Expr, + OrderingReq, + Distinct, + AggrFunction, + UserDefinedAggrFunction, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -13885,7 +16776,11 @@ impl<'de> serde::Deserialize<'de> for PartiallySortedPartitionSearchMode { E: serde::de::Error, { match value { - "columns" => Ok(GeneratedField::Columns), + "expr" => Ok(GeneratedField::Expr), + "orderingReq" | "ordering_req" => Ok(GeneratedField::OrderingReq), + "distinct" => Ok(GeneratedField::Distinct), + "aggrFunction" | "aggr_function" => Ok(GeneratedField::AggrFunction), + "userDefinedAggrFunction" | "user_defined_aggr_function" => Ok(GeneratedField::UserDefinedAggrFunction), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -13895,113 +16790,66 @@ impl<'de> serde::Deserialize<'de> for PartiallySortedPartitionSearchMode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PartiallySortedPartitionSearchMode; + type Value = PhysicalAggregateExprNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PartiallySortedPartitionSearchMode") + formatter.write_str("struct datafusion.PhysicalAggregateExprNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut columns__ = None; + let mut expr__ = None; + let mut ordering_req__ = None; + let mut distinct__ = None; + let mut aggregate_function__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Columns => { - if columns__.is_some() { - return Err(serde::de::Error::duplicate_field("columns")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - columns__ = - Some(map_.next_value::>>()? - .into_iter().map(|x| x.0).collect()) - ; + expr__ = Some(map_.next_value()?); + } + GeneratedField::OrderingReq => { + if ordering_req__.is_some() { + return Err(serde::de::Error::duplicate_field("orderingReq")); + } + ordering_req__ = Some(map_.next_value()?); + } + GeneratedField::Distinct => { + if distinct__.is_some() { + return Err(serde::de::Error::duplicate_field("distinct")); + } + distinct__ = Some(map_.next_value()?); + } + GeneratedField::AggrFunction => { + if aggregate_function__.is_some() { + return Err(serde::de::Error::duplicate_field("aggrFunction")); + } + aggregate_function__ = map_.next_value::<::std::option::Option>()?.map(|x| physical_aggregate_expr_node::AggregateFunction::AggrFunction(x as i32)); + } + GeneratedField::UserDefinedAggrFunction => { + if aggregate_function__.is_some() { + return Err(serde::de::Error::duplicate_field("userDefinedAggrFunction")); + } + aggregate_function__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction); } } } - Ok(PartiallySortedPartitionSearchMode { - columns: columns__.unwrap_or_default(), + Ok(PhysicalAggregateExprNode { + expr: expr__.unwrap_or_default(), + ordering_req: ordering_req__.unwrap_or_default(), + distinct: distinct__.unwrap_or_default(), + aggregate_function: aggregate_function__, }) } } - deserializer.deserialize_struct("datafusion.PartiallySortedPartitionSearchMode", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for PartitionMode { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - let variant = match self { - Self::CollectLeft => "COLLECT_LEFT", - Self::Partitioned => "PARTITIONED", - Self::Auto => "AUTO", - }; - serializer.serialize_str(variant) - } -} -impl<'de> serde::Deserialize<'de> for PartitionMode { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "COLLECT_LEFT", - "PARTITIONED", - "AUTO", - ]; - - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PartitionMode; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - fn visit_i64(self, v: i64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) - }) - } - - fn visit_u64(self, v: u64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) - }) - } - - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "COLLECT_LEFT" => Ok(PartitionMode::CollectLeft), - "PARTITIONED" => Ok(PartitionMode::Partitioned), - "AUTO" => Ok(PartitionMode::Auto), - _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), - } - } - } - deserializer.deserialize_any(GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalAggregateExprNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PartitionStats { +impl serde::Serialize for PhysicalAliasNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -14009,60 +16857,37 @@ impl serde::Serialize for PartitionStats { { use serde::ser::SerializeStruct; let mut len = 0; - if self.num_rows != 0 { - len += 1; - } - if self.num_batches != 0 { - len += 1; - } - if self.num_bytes != 0 { + if self.expr.is_some() { len += 1; } - if !self.column_stats.is_empty() { + if !self.alias.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PartitionStats", len)?; - if self.num_rows != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("numRows", ToString::to_string(&self.num_rows).as_str())?; - } - if self.num_batches != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("numBatches", ToString::to_string(&self.num_batches).as_str())?; - } - if self.num_bytes != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("numBytes", ToString::to_string(&self.num_bytes).as_str())?; + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalAliasNode", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; } - if !self.column_stats.is_empty() { - struct_ser.serialize_field("columnStats", &self.column_stats)?; + if !self.alias.is_empty() { + struct_ser.serialize_field("alias", &self.alias)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PartitionStats { +impl<'de> serde::Deserialize<'de> for PhysicalAliasNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "num_rows", - "numRows", - "num_batches", - "numBatches", - "num_bytes", - "numBytes", - "column_stats", - "columnStats", + "expr", + "alias", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - NumRows, - NumBatches, - NumBytes, - ColumnStats, + Expr, + Alias, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -14084,10 +16909,8 @@ impl<'de> serde::Deserialize<'de> for PartitionStats { E: serde::de::Error, { match value { - "numRows" | "num_rows" => Ok(GeneratedField::NumRows), - "numBatches" | "num_batches" => Ok(GeneratedField::NumBatches), - "numBytes" | "num_bytes" => Ok(GeneratedField::NumBytes), - "columnStats" | "column_stats" => Ok(GeneratedField::ColumnStats), + "expr" => Ok(GeneratedField::Expr), + "alias" => Ok(GeneratedField::Alias), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -14097,66 +16920,44 @@ impl<'de> serde::Deserialize<'de> for PartitionStats { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PartitionStats; + type Value = PhysicalAliasNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PartitionStats") + formatter.write_str("struct datafusion.PhysicalAliasNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut num_rows__ = None; - let mut num_batches__ = None; - let mut num_bytes__ = None; - let mut column_stats__ = None; + let mut expr__ = None; + let mut alias__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::NumRows => { - if num_rows__.is_some() { - return Err(serde::de::Error::duplicate_field("numRows")); - } - num_rows__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - GeneratedField::NumBatches => { - if num_batches__.is_some() { - return Err(serde::de::Error::duplicate_field("numBatches")); - } - num_batches__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - GeneratedField::NumBytes => { - if num_bytes__.is_some() { - return Err(serde::de::Error::duplicate_field("numBytes")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - num_bytes__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + expr__ = map_.next_value()?; } - GeneratedField::ColumnStats => { - if column_stats__.is_some() { - return Err(serde::de::Error::duplicate_field("columnStats")); + GeneratedField::Alias => { + if alias__.is_some() { + return Err(serde::de::Error::duplicate_field("alias")); } - column_stats__ = Some(map_.next_value()?); + alias__ = Some(map_.next_value()?); } - } - } - Ok(PartitionStats { - num_rows: num_rows__.unwrap_or_default(), - num_batches: num_batches__.unwrap_or_default(), - num_bytes: num_bytes__.unwrap_or_default(), - column_stats: column_stats__.unwrap_or_default(), + } + } + Ok(PhysicalAliasNode { + expr: expr__, + alias: alias__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.PartitionStats", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalAliasNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PartitionedFile { +impl serde::Serialize for PhysicalBinaryExprNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -14164,65 +16965,45 @@ impl serde::Serialize for PartitionedFile { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.path.is_empty() { - len += 1; - } - if self.size != 0 { - len += 1; - } - if self.last_modified_ns != 0 { + if self.l.is_some() { len += 1; } - if !self.partition_values.is_empty() { + if self.r.is_some() { len += 1; } - if self.range.is_some() { + if !self.op.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PartitionedFile", len)?; - if !self.path.is_empty() { - struct_ser.serialize_field("path", &self.path)?; - } - if self.size != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("size", ToString::to_string(&self.size).as_str())?; - } - if self.last_modified_ns != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("lastModifiedNs", ToString::to_string(&self.last_modified_ns).as_str())?; + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalBinaryExprNode", len)?; + if let Some(v) = self.l.as_ref() { + struct_ser.serialize_field("l", v)?; } - if !self.partition_values.is_empty() { - struct_ser.serialize_field("partitionValues", &self.partition_values)?; + if let Some(v) = self.r.as_ref() { + struct_ser.serialize_field("r", v)?; } - if let Some(v) = self.range.as_ref() { - struct_ser.serialize_field("range", v)?; + if !self.op.is_empty() { + struct_ser.serialize_field("op", &self.op)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PartitionedFile { +impl<'de> serde::Deserialize<'de> for PhysicalBinaryExprNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "path", - "size", - "last_modified_ns", - "lastModifiedNs", - "partition_values", - "partitionValues", - "range", + "l", + "r", + "op", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Path, - Size, - LastModifiedNs, - PartitionValues, - Range, + L, + R, + Op, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -14244,11 +17025,9 @@ impl<'de> serde::Deserialize<'de> for PartitionedFile { E: serde::de::Error, { match value { - "path" => Ok(GeneratedField::Path), - "size" => Ok(GeneratedField::Size), - "lastModifiedNs" | "last_modified_ns" => Ok(GeneratedField::LastModifiedNs), - "partitionValues" | "partition_values" => Ok(GeneratedField::PartitionValues), - "range" => Ok(GeneratedField::Range), + "l" => Ok(GeneratedField::L), + "r" => Ok(GeneratedField::R), + "op" => Ok(GeneratedField::Op), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -14258,72 +17037,52 @@ impl<'de> serde::Deserialize<'de> for PartitionedFile { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PartitionedFile; + type Value = PhysicalBinaryExprNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PartitionedFile") + formatter.write_str("struct datafusion.PhysicalBinaryExprNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut path__ = None; - let mut size__ = None; - let mut last_modified_ns__ = None; - let mut partition_values__ = None; - let mut range__ = None; + let mut l__ = None; + let mut r__ = None; + let mut op__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Path => { - if path__.is_some() { - return Err(serde::de::Error::duplicate_field("path")); - } - path__ = Some(map_.next_value()?); - } - GeneratedField::Size => { - if size__.is_some() { - return Err(serde::de::Error::duplicate_field("size")); - } - size__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - GeneratedField::LastModifiedNs => { - if last_modified_ns__.is_some() { - return Err(serde::de::Error::duplicate_field("lastModifiedNs")); + GeneratedField::L => { + if l__.is_some() { + return Err(serde::de::Error::duplicate_field("l")); } - last_modified_ns__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + l__ = map_.next_value()?; } - GeneratedField::PartitionValues => { - if partition_values__.is_some() { - return Err(serde::de::Error::duplicate_field("partitionValues")); + GeneratedField::R => { + if r__.is_some() { + return Err(serde::de::Error::duplicate_field("r")); } - partition_values__ = Some(map_.next_value()?); + r__ = map_.next_value()?; } - GeneratedField::Range => { - if range__.is_some() { - return Err(serde::de::Error::duplicate_field("range")); + GeneratedField::Op => { + if op__.is_some() { + return Err(serde::de::Error::duplicate_field("op")); } - range__ = map_.next_value()?; + op__ = Some(map_.next_value()?); } } } - Ok(PartitionedFile { - path: path__.unwrap_or_default(), - size: size__.unwrap_or_default(), - last_modified_ns: last_modified_ns__.unwrap_or_default(), - partition_values: partition_values__.unwrap_or_default(), - range: range__, + Ok(PhysicalBinaryExprNode { + l: l__, + r: r__, + op: op__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.PartitionedFile", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalBinaryExprNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalAggregateExprNode { +impl serde::Serialize for PhysicalCaseNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -14331,44 +17090,29 @@ impl serde::Serialize for PhysicalAggregateExprNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.expr.is_empty() { - len += 1; - } - if !self.ordering_req.is_empty() { + if self.expr.is_some() { len += 1; } - if self.distinct { + if !self.when_then_expr.is_empty() { len += 1; } - if self.aggregate_function.is_some() { + if self.else_expr.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalAggregateExprNode", len)?; - if !self.expr.is_empty() { - struct_ser.serialize_field("expr", &self.expr)?; - } - if !self.ordering_req.is_empty() { - struct_ser.serialize_field("orderingReq", &self.ordering_req)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalCaseNode", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; } - if self.distinct { - struct_ser.serialize_field("distinct", &self.distinct)?; + if !self.when_then_expr.is_empty() { + struct_ser.serialize_field("whenThenExpr", &self.when_then_expr)?; } - if let Some(v) = self.aggregate_function.as_ref() { - match v { - physical_aggregate_expr_node::AggregateFunction::AggrFunction(v) => { - let v = AggregateFunction::try_from(*v) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; - struct_ser.serialize_field("aggrFunction", &v)?; - } - physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(v) => { - struct_ser.serialize_field("userDefinedAggrFunction", v)?; - } - } + if let Some(v) = self.else_expr.as_ref() { + struct_ser.serialize_field("elseExpr", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { +impl<'de> serde::Deserialize<'de> for PhysicalCaseNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where @@ -14376,22 +17120,17 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { { const FIELDS: &[&str] = &[ "expr", - "ordering_req", - "orderingReq", - "distinct", - "aggr_function", - "aggrFunction", - "user_defined_aggr_function", - "userDefinedAggrFunction", + "when_then_expr", + "whenThenExpr", + "else_expr", + "elseExpr", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Expr, - OrderingReq, - Distinct, - AggrFunction, - UserDefinedAggrFunction, + WhenThenExpr, + ElseExpr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -14414,10 +17153,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { { match value { "expr" => Ok(GeneratedField::Expr), - "orderingReq" | "ordering_req" => Ok(GeneratedField::OrderingReq), - "distinct" => Ok(GeneratedField::Distinct), - "aggrFunction" | "aggr_function" => Ok(GeneratedField::AggrFunction), - "userDefinedAggrFunction" | "user_defined_aggr_function" => Ok(GeneratedField::UserDefinedAggrFunction), + "whenThenExpr" | "when_then_expr" => Ok(GeneratedField::WhenThenExpr), + "elseExpr" | "else_expr" => Ok(GeneratedField::ElseExpr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -14427,66 +17164,52 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalAggregateExprNode; + type Value = PhysicalCaseNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalAggregateExprNode") + formatter.write_str("struct datafusion.PhysicalCaseNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; - let mut ordering_req__ = None; - let mut distinct__ = None; - let mut aggregate_function__ = None; + let mut when_then_expr__ = None; + let mut else_expr__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = Some(map_.next_value()?); - } - GeneratedField::OrderingReq => { - if ordering_req__.is_some() { - return Err(serde::de::Error::duplicate_field("orderingReq")); - } - ordering_req__ = Some(map_.next_value()?); - } - GeneratedField::Distinct => { - if distinct__.is_some() { - return Err(serde::de::Error::duplicate_field("distinct")); - } - distinct__ = Some(map_.next_value()?); + expr__ = map_.next_value()?; } - GeneratedField::AggrFunction => { - if aggregate_function__.is_some() { - return Err(serde::de::Error::duplicate_field("aggrFunction")); + GeneratedField::WhenThenExpr => { + if when_then_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("whenThenExpr")); } - aggregate_function__ = map_.next_value::<::std::option::Option>()?.map(|x| physical_aggregate_expr_node::AggregateFunction::AggrFunction(x as i32)); + when_then_expr__ = Some(map_.next_value()?); } - GeneratedField::UserDefinedAggrFunction => { - if aggregate_function__.is_some() { - return Err(serde::de::Error::duplicate_field("userDefinedAggrFunction")); + GeneratedField::ElseExpr => { + if else_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("elseExpr")); } - aggregate_function__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction); + else_expr__ = map_.next_value()?; } } } - Ok(PhysicalAggregateExprNode { - expr: expr__.unwrap_or_default(), - ordering_req: ordering_req__.unwrap_or_default(), - distinct: distinct__.unwrap_or_default(), - aggregate_function: aggregate_function__, + Ok(PhysicalCaseNode { + expr: expr__, + when_then_expr: when_then_expr__.unwrap_or_default(), + else_expr: else_expr__, }) } } - deserializer.deserialize_struct("datafusion.PhysicalAggregateExprNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalCaseNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalAliasNode { +impl serde::Serialize for PhysicalCastNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -14497,20 +17220,20 @@ impl serde::Serialize for PhysicalAliasNode { if self.expr.is_some() { len += 1; } - if !self.alias.is_empty() { + if self.arrow_type.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalAliasNode", len)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalCastNode", len)?; if let Some(v) = self.expr.as_ref() { struct_ser.serialize_field("expr", v)?; } - if !self.alias.is_empty() { - struct_ser.serialize_field("alias", &self.alias)?; + if let Some(v) = self.arrow_type.as_ref() { + struct_ser.serialize_field("arrowType", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalAliasNode { +impl<'de> serde::Deserialize<'de> for PhysicalCastNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where @@ -14518,13 +17241,14 @@ impl<'de> serde::Deserialize<'de> for PhysicalAliasNode { { const FIELDS: &[&str] = &[ "expr", - "alias", + "arrow_type", + "arrowType", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Expr, - Alias, + ArrowType, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -14547,7 +17271,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalAliasNode { { match value { "expr" => Ok(GeneratedField::Expr), - "alias" => Ok(GeneratedField::Alias), + "arrowType" | "arrow_type" => Ok(GeneratedField::ArrowType), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -14557,18 +17281,18 @@ impl<'de> serde::Deserialize<'de> for PhysicalAliasNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalAliasNode; + type Value = PhysicalCastNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalAliasNode") + formatter.write_str("struct datafusion.PhysicalCastNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; - let mut alias__ = None; + let mut arrow_type__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { @@ -14577,24 +17301,24 @@ impl<'de> serde::Deserialize<'de> for PhysicalAliasNode { } expr__ = map_.next_value()?; } - GeneratedField::Alias => { - if alias__.is_some() { - return Err(serde::de::Error::duplicate_field("alias")); + GeneratedField::ArrowType => { + if arrow_type__.is_some() { + return Err(serde::de::Error::duplicate_field("arrowType")); } - alias__ = Some(map_.next_value()?); + arrow_type__ = map_.next_value()?; } } } - Ok(PhysicalAliasNode { + Ok(PhysicalCastNode { expr: expr__, - alias: alias__.unwrap_or_default(), + arrow_type: arrow_type__, }) } } - deserializer.deserialize_struct("datafusion.PhysicalAliasNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalCastNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalBinaryExprNode { +impl serde::Serialize for PhysicalColumn { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -14602,45 +17326,37 @@ impl serde::Serialize for PhysicalBinaryExprNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.l.is_some() { - len += 1; - } - if self.r.is_some() { + if !self.name.is_empty() { len += 1; } - if !self.op.is_empty() { + if self.index != 0 { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalBinaryExprNode", len)?; - if let Some(v) = self.l.as_ref() { - struct_ser.serialize_field("l", v)?; - } - if let Some(v) = self.r.as_ref() { - struct_ser.serialize_field("r", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalColumn", len)?; + if !self.name.is_empty() { + struct_ser.serialize_field("name", &self.name)?; } - if !self.op.is_empty() { - struct_ser.serialize_field("op", &self.op)?; + if self.index != 0 { + struct_ser.serialize_field("index", &self.index)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalBinaryExprNode { +impl<'de> serde::Deserialize<'de> for PhysicalColumn { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "l", - "r", - "op", + "name", + "index", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - L, - R, - Op, + Name, + Index, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -14662,9 +17378,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalBinaryExprNode { E: serde::de::Error, { match value { - "l" => Ok(GeneratedField::L), - "r" => Ok(GeneratedField::R), - "op" => Ok(GeneratedField::Op), + "name" => Ok(GeneratedField::Name), + "index" => Ok(GeneratedField::Index), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -14674,52 +17389,46 @@ impl<'de> serde::Deserialize<'de> for PhysicalBinaryExprNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalBinaryExprNode; + type Value = PhysicalColumn; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalBinaryExprNode") + formatter.write_str("struct datafusion.PhysicalColumn") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut l__ = None; - let mut r__ = None; - let mut op__ = None; + let mut name__ = None; + let mut index__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::L => { - if l__.is_some() { - return Err(serde::de::Error::duplicate_field("l")); - } - l__ = map_.next_value()?; - } - GeneratedField::R => { - if r__.is_some() { - return Err(serde::de::Error::duplicate_field("r")); + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); } - r__ = map_.next_value()?; + name__ = Some(map_.next_value()?); } - GeneratedField::Op => { - if op__.is_some() { - return Err(serde::de::Error::duplicate_field("op")); + GeneratedField::Index => { + if index__.is_some() { + return Err(serde::de::Error::duplicate_field("index")); } - op__ = Some(map_.next_value()?); + index__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; } } } - Ok(PhysicalBinaryExprNode { - l: l__, - r: r__, - op: op__.unwrap_or_default(), + Ok(PhysicalColumn { + name: name__.unwrap_or_default(), + index: index__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.PhysicalBinaryExprNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalColumn", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalCaseNode { +impl serde::Serialize for PhysicalDateTimeIntervalExprNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -14727,47 +17436,45 @@ impl serde::Serialize for PhysicalCaseNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { + if self.l.is_some() { len += 1; } - if !self.when_then_expr.is_empty() { + if self.r.is_some() { len += 1; } - if self.else_expr.is_some() { + if !self.op.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalCaseNode", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalDateTimeIntervalExprNode", len)?; + if let Some(v) = self.l.as_ref() { + struct_ser.serialize_field("l", v)?; } - if !self.when_then_expr.is_empty() { - struct_ser.serialize_field("whenThenExpr", &self.when_then_expr)?; + if let Some(v) = self.r.as_ref() { + struct_ser.serialize_field("r", v)?; } - if let Some(v) = self.else_expr.as_ref() { - struct_ser.serialize_field("elseExpr", v)?; + if !self.op.is_empty() { + struct_ser.serialize_field("op", &self.op)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalCaseNode { +impl<'de> serde::Deserialize<'de> for PhysicalDateTimeIntervalExprNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", - "when_then_expr", - "whenThenExpr", - "else_expr", - "elseExpr", + "l", + "r", + "op", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, - WhenThenExpr, - ElseExpr, + L, + R, + Op, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -14789,9 +17496,9 @@ impl<'de> serde::Deserialize<'de> for PhysicalCaseNode { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), - "whenThenExpr" | "when_then_expr" => Ok(GeneratedField::WhenThenExpr), - "elseExpr" | "else_expr" => Ok(GeneratedField::ElseExpr), + "l" => Ok(GeneratedField::L), + "r" => Ok(GeneratedField::R), + "op" => Ok(GeneratedField::Op), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -14801,91 +17508,184 @@ impl<'de> serde::Deserialize<'de> for PhysicalCaseNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalCaseNode; + type Value = PhysicalDateTimeIntervalExprNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalCaseNode") + formatter.write_str("struct datafusion.PhysicalDateTimeIntervalExprNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; - let mut when_then_expr__ = None; - let mut else_expr__ = None; + let mut l__ = None; + let mut r__ = None; + let mut op__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::L => { + if l__.is_some() { + return Err(serde::de::Error::duplicate_field("l")); } - expr__ = map_.next_value()?; + l__ = map_.next_value()?; } - GeneratedField::WhenThenExpr => { - if when_then_expr__.is_some() { - return Err(serde::de::Error::duplicate_field("whenThenExpr")); + GeneratedField::R => { + if r__.is_some() { + return Err(serde::de::Error::duplicate_field("r")); } - when_then_expr__ = Some(map_.next_value()?); + r__ = map_.next_value()?; } - GeneratedField::ElseExpr => { - if else_expr__.is_some() { - return Err(serde::de::Error::duplicate_field("elseExpr")); + GeneratedField::Op => { + if op__.is_some() { + return Err(serde::de::Error::duplicate_field("op")); } - else_expr__ = map_.next_value()?; + op__ = Some(map_.next_value()?); } } } - Ok(PhysicalCaseNode { - expr: expr__, - when_then_expr: when_then_expr__.unwrap_or_default(), - else_expr: else_expr__, - }) + Ok(PhysicalDateTimeIntervalExprNode { + l: l__, + r: r__, + op: op__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.PhysicalDateTimeIntervalExprNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for PhysicalExprNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.expr_type.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalExprNode", len)?; + if let Some(v) = self.expr_type.as_ref() { + match v { + physical_expr_node::ExprType::Column(v) => { + struct_ser.serialize_field("column", v)?; + } + physical_expr_node::ExprType::Literal(v) => { + struct_ser.serialize_field("literal", v)?; + } + physical_expr_node::ExprType::BinaryExpr(v) => { + struct_ser.serialize_field("binaryExpr", v)?; + } + physical_expr_node::ExprType::AggregateExpr(v) => { + struct_ser.serialize_field("aggregateExpr", v)?; + } + physical_expr_node::ExprType::IsNullExpr(v) => { + struct_ser.serialize_field("isNullExpr", v)?; + } + physical_expr_node::ExprType::IsNotNullExpr(v) => { + struct_ser.serialize_field("isNotNullExpr", v)?; + } + physical_expr_node::ExprType::NotExpr(v) => { + struct_ser.serialize_field("notExpr", v)?; + } + physical_expr_node::ExprType::Case(v) => { + struct_ser.serialize_field("case", v)?; + } + physical_expr_node::ExprType::Cast(v) => { + struct_ser.serialize_field("cast", v)?; + } + physical_expr_node::ExprType::Sort(v) => { + struct_ser.serialize_field("sort", v)?; + } + physical_expr_node::ExprType::Negative(v) => { + struct_ser.serialize_field("negative", v)?; + } + physical_expr_node::ExprType::InList(v) => { + struct_ser.serialize_field("inList", v)?; + } + physical_expr_node::ExprType::ScalarFunction(v) => { + struct_ser.serialize_field("scalarFunction", v)?; + } + physical_expr_node::ExprType::TryCast(v) => { + struct_ser.serialize_field("tryCast", v)?; + } + physical_expr_node::ExprType::WindowExpr(v) => { + struct_ser.serialize_field("windowExpr", v)?; + } + physical_expr_node::ExprType::ScalarUdf(v) => { + struct_ser.serialize_field("scalarUdf", v)?; + } + physical_expr_node::ExprType::LikeExpr(v) => { + struct_ser.serialize_field("likeExpr", v)?; + } + physical_expr_node::ExprType::GetIndexedFieldExpr(v) => { + struct_ser.serialize_field("getIndexedFieldExpr", v)?; + } } } - deserializer.deserialize_struct("datafusion.PhysicalCaseNode", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for PhysicalCastNode { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.expr.is_some() { - len += 1; - } - if self.arrow_type.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalCastNode", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; - } - if let Some(v) = self.arrow_type.as_ref() { - struct_ser.serialize_field("arrowType", v)?; - } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalCastNode { +impl<'de> serde::Deserialize<'de> for PhysicalExprNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", - "arrow_type", - "arrowType", + "column", + "literal", + "binary_expr", + "binaryExpr", + "aggregate_expr", + "aggregateExpr", + "is_null_expr", + "isNullExpr", + "is_not_null_expr", + "isNotNullExpr", + "not_expr", + "notExpr", + "case_", + "case", + "cast", + "sort", + "negative", + "in_list", + "inList", + "scalar_function", + "scalarFunction", + "try_cast", + "tryCast", + "window_expr", + "windowExpr", + "scalar_udf", + "scalarUdf", + "like_expr", + "likeExpr", + "get_indexed_field_expr", + "getIndexedFieldExpr", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, - ArrowType, + Column, + Literal, + BinaryExpr, + AggregateExpr, + IsNullExpr, + IsNotNullExpr, + NotExpr, + Case, + Cast, + Sort, + Negative, + InList, + ScalarFunction, + TryCast, + WindowExpr, + ScalarUdf, + LikeExpr, + GetIndexedFieldExpr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -14907,8 +17707,24 @@ impl<'de> serde::Deserialize<'de> for PhysicalCastNode { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), - "arrowType" | "arrow_type" => Ok(GeneratedField::ArrowType), + "column" => Ok(GeneratedField::Column), + "literal" => Ok(GeneratedField::Literal), + "binaryExpr" | "binary_expr" => Ok(GeneratedField::BinaryExpr), + "aggregateExpr" | "aggregate_expr" => Ok(GeneratedField::AggregateExpr), + "isNullExpr" | "is_null_expr" => Ok(GeneratedField::IsNullExpr), + "isNotNullExpr" | "is_not_null_expr" => Ok(GeneratedField::IsNotNullExpr), + "notExpr" | "not_expr" => Ok(GeneratedField::NotExpr), + "case" | "case_" => Ok(GeneratedField::Case), + "cast" => Ok(GeneratedField::Cast), + "sort" => Ok(GeneratedField::Sort), + "negative" => Ok(GeneratedField::Negative), + "inList" | "in_list" => Ok(GeneratedField::InList), + "scalarFunction" | "scalar_function" => Ok(GeneratedField::ScalarFunction), + "tryCast" | "try_cast" => Ok(GeneratedField::TryCast), + "windowExpr" | "window_expr" => Ok(GeneratedField::WindowExpr), + "scalarUdf" | "scalar_udf" => Ok(GeneratedField::ScalarUdf), + "likeExpr" | "like_expr" => Ok(GeneratedField::LikeExpr), + "getIndexedFieldExpr" | "get_indexed_field_expr" => Ok(GeneratedField::GetIndexedFieldExpr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -14918,44 +17734,156 @@ impl<'de> serde::Deserialize<'de> for PhysicalCastNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalCastNode; + type Value = PhysicalExprNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalCastNode") + formatter.write_str("struct datafusion.PhysicalExprNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; - let mut arrow_type__ = None; + let mut expr_type__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::Column => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("column")); } - expr__ = map_.next_value()?; + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Column) +; + } + GeneratedField::Literal => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("literal")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Literal) +; + } + GeneratedField::BinaryExpr => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("binaryExpr")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::BinaryExpr) +; + } + GeneratedField::AggregateExpr => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("aggregateExpr")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::AggregateExpr) +; + } + GeneratedField::IsNullExpr => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("isNullExpr")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::IsNullExpr) +; + } + GeneratedField::IsNotNullExpr => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("isNotNullExpr")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::IsNotNullExpr) +; + } + GeneratedField::NotExpr => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("notExpr")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::NotExpr) +; + } + GeneratedField::Case => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("case")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Case) +; + } + GeneratedField::Cast => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("cast")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Cast) +; + } + GeneratedField::Sort => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("sort")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Sort) +; + } + GeneratedField::Negative => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("negative")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Negative) +; + } + GeneratedField::InList => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("inList")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::InList) +; + } + GeneratedField::ScalarFunction => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("scalarFunction")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::ScalarFunction) +; + } + GeneratedField::TryCast => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("tryCast")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::TryCast) +; + } + GeneratedField::WindowExpr => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("windowExpr")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::WindowExpr) +; + } + GeneratedField::ScalarUdf => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("scalarUdf")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::ScalarUdf) +; + } + GeneratedField::LikeExpr => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("likeExpr")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::LikeExpr) +; } - GeneratedField::ArrowType => { - if arrow_type__.is_some() { - return Err(serde::de::Error::duplicate_field("arrowType")); + GeneratedField::GetIndexedFieldExpr => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("getIndexedFieldExpr")); } - arrow_type__ = map_.next_value()?; + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::GetIndexedFieldExpr) +; } } } - Ok(PhysicalCastNode { - expr: expr__, - arrow_type: arrow_type__, + Ok(PhysicalExprNode { + expr_type: expr_type__, }) } } - deserializer.deserialize_struct("datafusion.PhysicalCastNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalExprNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalColumn { +impl serde::Serialize for PhysicalExtensionNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -14963,37 +17891,38 @@ impl serde::Serialize for PhysicalColumn { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.name.is_empty() { + if !self.node.is_empty() { len += 1; } - if self.index != 0 { + if !self.inputs.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalColumn", len)?; - if !self.name.is_empty() { - struct_ser.serialize_field("name", &self.name)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalExtensionNode", len)?; + if !self.node.is_empty() { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("node", pbjson::private::base64::encode(&self.node).as_str())?; } - if self.index != 0 { - struct_ser.serialize_field("index", &self.index)?; + if !self.inputs.is_empty() { + struct_ser.serialize_field("inputs", &self.inputs)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalColumn { +impl<'de> serde::Deserialize<'de> for PhysicalExtensionNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "name", - "index", + "node", + "inputs", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Name, - Index, + Node, + Inputs, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -15015,8 +17944,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalColumn { E: serde::de::Error, { match value { - "name" => Ok(GeneratedField::Name), - "index" => Ok(GeneratedField::Index), + "node" => Ok(GeneratedField::Node), + "inputs" => Ok(GeneratedField::Inputs), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -15026,46 +17955,46 @@ impl<'de> serde::Deserialize<'de> for PhysicalColumn { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalColumn; + type Value = PhysicalExtensionNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalColumn") + formatter.write_str("struct datafusion.PhysicalExtensionNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut name__ = None; - let mut index__ = None; + let mut node__ = None; + let mut inputs__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Name => { - if name__.is_some() { - return Err(serde::de::Error::duplicate_field("name")); + GeneratedField::Node => { + if node__.is_some() { + return Err(serde::de::Error::duplicate_field("node")); } - name__ = Some(map_.next_value()?); + node__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; } - GeneratedField::Index => { - if index__.is_some() { - return Err(serde::de::Error::duplicate_field("index")); + GeneratedField::Inputs => { + if inputs__.is_some() { + return Err(serde::de::Error::duplicate_field("inputs")); } - index__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + inputs__ = Some(map_.next_value()?); } } } - Ok(PhysicalColumn { - name: name__.unwrap_or_default(), - index: index__.unwrap_or_default(), + Ok(PhysicalExtensionNode { + node: node__.unwrap_or_default(), + inputs: inputs__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.PhysicalColumn", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalExtensionNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalDateTimeIntervalExprNode { +impl serde::Serialize for PhysicalGetIndexedFieldExprNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -15073,45 +18002,54 @@ impl serde::Serialize for PhysicalDateTimeIntervalExprNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.l.is_some() { - len += 1; - } - if self.r.is_some() { + if self.arg.is_some() { len += 1; } - if !self.op.is_empty() { + if self.field.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalDateTimeIntervalExprNode", len)?; - if let Some(v) = self.l.as_ref() { - struct_ser.serialize_field("l", v)?; - } - if let Some(v) = self.r.as_ref() { - struct_ser.serialize_field("r", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalGetIndexedFieldExprNode", len)?; + if let Some(v) = self.arg.as_ref() { + struct_ser.serialize_field("arg", v)?; } - if !self.op.is_empty() { - struct_ser.serialize_field("op", &self.op)?; + if let Some(v) = self.field.as_ref() { + match v { + physical_get_indexed_field_expr_node::Field::NamedStructFieldExpr(v) => { + struct_ser.serialize_field("namedStructFieldExpr", v)?; + } + physical_get_indexed_field_expr_node::Field::ListIndexExpr(v) => { + struct_ser.serialize_field("listIndexExpr", v)?; + } + physical_get_indexed_field_expr_node::Field::ListRangeExpr(v) => { + struct_ser.serialize_field("listRangeExpr", v)?; + } + } } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalDateTimeIntervalExprNode { +impl<'de> serde::Deserialize<'de> for PhysicalGetIndexedFieldExprNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "l", - "r", - "op", + "arg", + "named_struct_field_expr", + "namedStructFieldExpr", + "list_index_expr", + "listIndexExpr", + "list_range_expr", + "listRangeExpr", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - L, - R, - Op, + Arg, + NamedStructFieldExpr, + ListIndexExpr, + ListRangeExpr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -15133,9 +18071,10 @@ impl<'de> serde::Deserialize<'de> for PhysicalDateTimeIntervalExprNode { E: serde::de::Error, { match value { - "l" => Ok(GeneratedField::L), - "r" => Ok(GeneratedField::R), - "op" => Ok(GeneratedField::Op), + "arg" => Ok(GeneratedField::Arg), + "namedStructFieldExpr" | "named_struct_field_expr" => Ok(GeneratedField::NamedStructFieldExpr), + "listIndexExpr" | "list_index_expr" => Ok(GeneratedField::ListIndexExpr), + "listRangeExpr" | "list_range_expr" => Ok(GeneratedField::ListRangeExpr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -15145,184 +18084,100 @@ impl<'de> serde::Deserialize<'de> for PhysicalDateTimeIntervalExprNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalDateTimeIntervalExprNode; + type Value = PhysicalGetIndexedFieldExprNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalDateTimeIntervalExprNode") + formatter.write_str("struct datafusion.PhysicalGetIndexedFieldExprNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut l__ = None; - let mut r__ = None; - let mut op__ = None; + let mut arg__ = None; + let mut field__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::L => { - if l__.is_some() { - return Err(serde::de::Error::duplicate_field("l")); - } - l__ = map_.next_value()?; - } - GeneratedField::R => { - if r__.is_some() { - return Err(serde::de::Error::duplicate_field("r")); + GeneratedField::Arg => { + if arg__.is_some() { + return Err(serde::de::Error::duplicate_field("arg")); } - r__ = map_.next_value()?; + arg__ = map_.next_value()?; } - GeneratedField::Op => { - if op__.is_some() { - return Err(serde::de::Error::duplicate_field("op")); + GeneratedField::NamedStructFieldExpr => { + if field__.is_some() { + return Err(serde::de::Error::duplicate_field("namedStructFieldExpr")); } - op__ = Some(map_.next_value()?); - } - } - } - Ok(PhysicalDateTimeIntervalExprNode { - l: l__, - r: r__, - op: op__.unwrap_or_default(), - }) - } - } - deserializer.deserialize_struct("datafusion.PhysicalDateTimeIntervalExprNode", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for PhysicalExprNode { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.expr_type.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalExprNode", len)?; - if let Some(v) = self.expr_type.as_ref() { - match v { - physical_expr_node::ExprType::Column(v) => { - struct_ser.serialize_field("column", v)?; - } - physical_expr_node::ExprType::Literal(v) => { - struct_ser.serialize_field("literal", v)?; - } - physical_expr_node::ExprType::BinaryExpr(v) => { - struct_ser.serialize_field("binaryExpr", v)?; - } - physical_expr_node::ExprType::AggregateExpr(v) => { - struct_ser.serialize_field("aggregateExpr", v)?; - } - physical_expr_node::ExprType::IsNullExpr(v) => { - struct_ser.serialize_field("isNullExpr", v)?; - } - physical_expr_node::ExprType::IsNotNullExpr(v) => { - struct_ser.serialize_field("isNotNullExpr", v)?; - } - physical_expr_node::ExprType::NotExpr(v) => { - struct_ser.serialize_field("notExpr", v)?; - } - physical_expr_node::ExprType::Case(v) => { - struct_ser.serialize_field("case", v)?; - } - physical_expr_node::ExprType::Cast(v) => { - struct_ser.serialize_field("cast", v)?; - } - physical_expr_node::ExprType::Sort(v) => { - struct_ser.serialize_field("sort", v)?; - } - physical_expr_node::ExprType::Negative(v) => { - struct_ser.serialize_field("negative", v)?; - } - physical_expr_node::ExprType::InList(v) => { - struct_ser.serialize_field("inList", v)?; - } - physical_expr_node::ExprType::ScalarFunction(v) => { - struct_ser.serialize_field("scalarFunction", v)?; - } - physical_expr_node::ExprType::TryCast(v) => { - struct_ser.serialize_field("tryCast", v)?; - } - physical_expr_node::ExprType::WindowExpr(v) => { - struct_ser.serialize_field("windowExpr", v)?; - } - physical_expr_node::ExprType::ScalarUdf(v) => { - struct_ser.serialize_field("scalarUdf", v)?; - } - physical_expr_node::ExprType::LikeExpr(v) => { - struct_ser.serialize_field("likeExpr", v)?; - } - physical_expr_node::ExprType::GetIndexedFieldExpr(v) => { - struct_ser.serialize_field("getIndexedFieldExpr", v)?; + field__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_get_indexed_field_expr_node::Field::NamedStructFieldExpr) +; + } + GeneratedField::ListIndexExpr => { + if field__.is_some() { + return Err(serde::de::Error::duplicate_field("listIndexExpr")); + } + field__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_get_indexed_field_expr_node::Field::ListIndexExpr) +; + } + GeneratedField::ListRangeExpr => { + if field__.is_some() { + return Err(serde::de::Error::duplicate_field("listRangeExpr")); + } + field__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_get_indexed_field_expr_node::Field::ListRangeExpr) +; + } + } } + Ok(PhysicalGetIndexedFieldExprNode { + arg: arg__, + field: field__, + }) } } + deserializer.deserialize_struct("datafusion.PhysicalGetIndexedFieldExprNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for PhysicalHashRepartition { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.hash_expr.is_empty() { + len += 1; + } + if self.partition_count != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalHashRepartition", len)?; + if !self.hash_expr.is_empty() { + struct_ser.serialize_field("hashExpr", &self.hash_expr)?; + } + if self.partition_count != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("partitionCount", ToString::to_string(&self.partition_count).as_str())?; + } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalExprNode { +impl<'de> serde::Deserialize<'de> for PhysicalHashRepartition { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "column", - "literal", - "binary_expr", - "binaryExpr", - "aggregate_expr", - "aggregateExpr", - "is_null_expr", - "isNullExpr", - "is_not_null_expr", - "isNotNullExpr", - "not_expr", - "notExpr", - "case_", - "case", - "cast", - "sort", - "negative", - "in_list", - "inList", - "scalar_function", - "scalarFunction", - "try_cast", - "tryCast", - "window_expr", - "windowExpr", - "scalar_udf", - "scalarUdf", - "like_expr", - "likeExpr", - "get_indexed_field_expr", - "getIndexedFieldExpr", + "hash_expr", + "hashExpr", + "partition_count", + "partitionCount", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Column, - Literal, - BinaryExpr, - AggregateExpr, - IsNullExpr, - IsNotNullExpr, - NotExpr, - Case, - Cast, - Sort, - Negative, - InList, - ScalarFunction, - TryCast, - WindowExpr, - ScalarUdf, - LikeExpr, - GetIndexedFieldExpr, + HashExpr, + PartitionCount, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -15344,24 +18199,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { E: serde::de::Error, { match value { - "column" => Ok(GeneratedField::Column), - "literal" => Ok(GeneratedField::Literal), - "binaryExpr" | "binary_expr" => Ok(GeneratedField::BinaryExpr), - "aggregateExpr" | "aggregate_expr" => Ok(GeneratedField::AggregateExpr), - "isNullExpr" | "is_null_expr" => Ok(GeneratedField::IsNullExpr), - "isNotNullExpr" | "is_not_null_expr" => Ok(GeneratedField::IsNotNullExpr), - "notExpr" | "not_expr" => Ok(GeneratedField::NotExpr), - "case" | "case_" => Ok(GeneratedField::Case), - "cast" => Ok(GeneratedField::Cast), - "sort" => Ok(GeneratedField::Sort), - "negative" => Ok(GeneratedField::Negative), - "inList" | "in_list" => Ok(GeneratedField::InList), - "scalarFunction" | "scalar_function" => Ok(GeneratedField::ScalarFunction), - "tryCast" | "try_cast" => Ok(GeneratedField::TryCast), - "windowExpr" | "window_expr" => Ok(GeneratedField::WindowExpr), - "scalarUdf" | "scalar_udf" => Ok(GeneratedField::ScalarUdf), - "likeExpr" | "like_expr" => Ok(GeneratedField::LikeExpr), - "getIndexedFieldExpr" | "get_indexed_field_expr" => Ok(GeneratedField::GetIndexedFieldExpr), + "hashExpr" | "hash_expr" => Ok(GeneratedField::HashExpr), + "partitionCount" | "partition_count" => Ok(GeneratedField::PartitionCount), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -15371,156 +18210,46 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalExprNode; + type Value = PhysicalHashRepartition; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalExprNode") + formatter.write_str("struct datafusion.PhysicalHashRepartition") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr_type__ = None; + let mut hash_expr__ = None; + let mut partition_count__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Column => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("column")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Column) -; - } - GeneratedField::Literal => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("literal")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Literal) -; - } - GeneratedField::BinaryExpr => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("binaryExpr")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::BinaryExpr) -; - } - GeneratedField::AggregateExpr => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("aggregateExpr")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::AggregateExpr) -; - } - GeneratedField::IsNullExpr => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("isNullExpr")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::IsNullExpr) -; - } - GeneratedField::IsNotNullExpr => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("isNotNullExpr")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::IsNotNullExpr) -; - } - GeneratedField::NotExpr => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("notExpr")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::NotExpr) -; - } - GeneratedField::Case => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("case")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Case) -; - } - GeneratedField::Cast => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("cast")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Cast) -; - } - GeneratedField::Sort => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("sort")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Sort) -; - } - GeneratedField::Negative => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("negative")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Negative) -; - } - GeneratedField::InList => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("inList")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::InList) -; - } - GeneratedField::ScalarFunction => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("scalarFunction")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::ScalarFunction) -; - } - GeneratedField::TryCast => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("tryCast")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::TryCast) -; - } - GeneratedField::WindowExpr => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("windowExpr")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::WindowExpr) -; - } - GeneratedField::ScalarUdf => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("scalarUdf")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::ScalarUdf) -; - } - GeneratedField::LikeExpr => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("likeExpr")); + GeneratedField::HashExpr => { + if hash_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("hashExpr")); } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::LikeExpr) -; + hash_expr__ = Some(map_.next_value()?); } - GeneratedField::GetIndexedFieldExpr => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("getIndexedFieldExpr")); + GeneratedField::PartitionCount => { + if partition_count__.is_some() { + return Err(serde::de::Error::duplicate_field("partitionCount")); } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::GetIndexedFieldExpr) -; + partition_count__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; } } } - Ok(PhysicalExprNode { - expr_type: expr_type__, + Ok(PhysicalHashRepartition { + hash_expr: hash_expr__.unwrap_or_default(), + partition_count: partition_count__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.PhysicalExprNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalHashRepartition", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalExtensionNode { +impl serde::Serialize for PhysicalInListNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -15528,38 +18257,45 @@ impl serde::Serialize for PhysicalExtensionNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.node.is_empty() { + if self.expr.is_some() { len += 1; } - if !self.inputs.is_empty() { + if !self.list.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalExtensionNode", len)?; - if !self.node.is_empty() { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("node", pbjson::private::base64::encode(&self.node).as_str())?; + if self.negated { + len += 1; } - if !self.inputs.is_empty() { - struct_ser.serialize_field("inputs", &self.inputs)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalInListNode", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; + } + if !self.list.is_empty() { + struct_ser.serialize_field("list", &self.list)?; + } + if self.negated { + struct_ser.serialize_field("negated", &self.negated)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalExtensionNode { +impl<'de> serde::Deserialize<'de> for PhysicalInListNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "node", - "inputs", + "expr", + "list", + "negated", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Node, - Inputs, + Expr, + List, + Negated, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -15581,8 +18317,9 @@ impl<'de> serde::Deserialize<'de> for PhysicalExtensionNode { E: serde::de::Error, { match value { - "node" => Ok(GeneratedField::Node), - "inputs" => Ok(GeneratedField::Inputs), + "expr" => Ok(GeneratedField::Expr), + "list" => Ok(GeneratedField::List), + "negated" => Ok(GeneratedField::Negated), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -15592,46 +18329,52 @@ impl<'de> serde::Deserialize<'de> for PhysicalExtensionNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalExtensionNode; + type Value = PhysicalInListNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalExtensionNode") + formatter.write_str("struct datafusion.PhysicalInListNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut node__ = None; - let mut inputs__ = None; + let mut expr__ = None; + let mut list__ = None; + let mut negated__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Node => { - if node__.is_some() { - return Err(serde::de::Error::duplicate_field("node")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - node__ = - Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) - ; + expr__ = map_.next_value()?; } - GeneratedField::Inputs => { - if inputs__.is_some() { - return Err(serde::de::Error::duplicate_field("inputs")); + GeneratedField::List => { + if list__.is_some() { + return Err(serde::de::Error::duplicate_field("list")); } - inputs__ = Some(map_.next_value()?); + list__ = Some(map_.next_value()?); + } + GeneratedField::Negated => { + if negated__.is_some() { + return Err(serde::de::Error::duplicate_field("negated")); + } + negated__ = Some(map_.next_value()?); } } } - Ok(PhysicalExtensionNode { - node: node__.unwrap_or_default(), - inputs: inputs__.unwrap_or_default(), + Ok(PhysicalInListNode { + expr: expr__, + list: list__.unwrap_or_default(), + negated: negated__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.PhysicalExtensionNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalInListNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalGetIndexedFieldExprNode { +impl serde::Serialize for PhysicalIsNotNull { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -15639,54 +18382,29 @@ impl serde::Serialize for PhysicalGetIndexedFieldExprNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.arg.is_some() { - len += 1; - } - if self.field.is_some() { + if self.expr.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalGetIndexedFieldExprNode", len)?; - if let Some(v) = self.arg.as_ref() { - struct_ser.serialize_field("arg", v)?; - } - if let Some(v) = self.field.as_ref() { - match v { - physical_get_indexed_field_expr_node::Field::NamedStructFieldExpr(v) => { - struct_ser.serialize_field("namedStructFieldExpr", v)?; - } - physical_get_indexed_field_expr_node::Field::ListIndexExpr(v) => { - struct_ser.serialize_field("listIndexExpr", v)?; - } - physical_get_indexed_field_expr_node::Field::ListRangeExpr(v) => { - struct_ser.serialize_field("listRangeExpr", v)?; - } - } + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalIsNotNull", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalGetIndexedFieldExprNode { +impl<'de> serde::Deserialize<'de> for PhysicalIsNotNull { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "arg", - "named_struct_field_expr", - "namedStructFieldExpr", - "list_index_expr", - "listIndexExpr", - "list_range_expr", - "listRangeExpr", + "expr", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Arg, - NamedStructFieldExpr, - ListIndexExpr, - ListRangeExpr, + Expr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -15708,10 +18426,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalGetIndexedFieldExprNode { E: serde::de::Error, { match value { - "arg" => Ok(GeneratedField::Arg), - "namedStructFieldExpr" | "named_struct_field_expr" => Ok(GeneratedField::NamedStructFieldExpr), - "listIndexExpr" | "list_index_expr" => Ok(GeneratedField::ListIndexExpr), - "listRangeExpr" | "list_range_expr" => Ok(GeneratedField::ListRangeExpr), + "expr" => Ok(GeneratedField::Expr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -15721,59 +18436,36 @@ impl<'de> serde::Deserialize<'de> for PhysicalGetIndexedFieldExprNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalGetIndexedFieldExprNode; + type Value = PhysicalIsNotNull; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalGetIndexedFieldExprNode") + formatter.write_str("struct datafusion.PhysicalIsNotNull") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut arg__ = None; - let mut field__ = None; + let mut expr__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Arg => { - if arg__.is_some() { - return Err(serde::de::Error::duplicate_field("arg")); - } - arg__ = map_.next_value()?; - } - GeneratedField::NamedStructFieldExpr => { - if field__.is_some() { - return Err(serde::de::Error::duplicate_field("namedStructFieldExpr")); - } - field__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_get_indexed_field_expr_node::Field::NamedStructFieldExpr) -; - } - GeneratedField::ListIndexExpr => { - if field__.is_some() { - return Err(serde::de::Error::duplicate_field("listIndexExpr")); - } - field__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_get_indexed_field_expr_node::Field::ListIndexExpr) -; - } - GeneratedField::ListRangeExpr => { - if field__.is_some() { - return Err(serde::de::Error::duplicate_field("listRangeExpr")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - field__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_get_indexed_field_expr_node::Field::ListRangeExpr) -; + expr__ = map_.next_value()?; } } } - Ok(PhysicalGetIndexedFieldExprNode { - arg: arg__, - field: field__, + Ok(PhysicalIsNotNull { + expr: expr__, }) } } - deserializer.deserialize_struct("datafusion.PhysicalGetIndexedFieldExprNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalIsNotNull", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalHashRepartition { +impl serde::Serialize for PhysicalIsNull { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -15781,40 +18473,29 @@ impl serde::Serialize for PhysicalHashRepartition { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.hash_expr.is_empty() { - len += 1; - } - if self.partition_count != 0 { + if self.expr.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalHashRepartition", len)?; - if !self.hash_expr.is_empty() { - struct_ser.serialize_field("hashExpr", &self.hash_expr)?; - } - if self.partition_count != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("partitionCount", ToString::to_string(&self.partition_count).as_str())?; + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalIsNull", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalHashRepartition { +impl<'de> serde::Deserialize<'de> for PhysicalIsNull { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "hash_expr", - "hashExpr", - "partition_count", - "partitionCount", + "expr", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - HashExpr, - PartitionCount, + Expr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -15836,8 +18517,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalHashRepartition { E: serde::de::Error, { match value { - "hashExpr" | "hash_expr" => Ok(GeneratedField::HashExpr), - "partitionCount" | "partition_count" => Ok(GeneratedField::PartitionCount), + "expr" => Ok(GeneratedField::Expr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -15847,46 +18527,36 @@ impl<'de> serde::Deserialize<'de> for PhysicalHashRepartition { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalHashRepartition; + type Value = PhysicalIsNull; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalHashRepartition") + formatter.write_str("struct datafusion.PhysicalIsNull") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut hash_expr__ = None; - let mut partition_count__ = None; + let mut expr__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::HashExpr => { - if hash_expr__.is_some() { - return Err(serde::de::Error::duplicate_field("hashExpr")); - } - hash_expr__ = Some(map_.next_value()?); - } - GeneratedField::PartitionCount => { - if partition_count__.is_some() { - return Err(serde::de::Error::duplicate_field("partitionCount")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - partition_count__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + expr__ = map_.next_value()?; } } } - Ok(PhysicalHashRepartition { - hash_expr: hash_expr__.unwrap_or_default(), - partition_count: partition_count__.unwrap_or_default(), + Ok(PhysicalIsNull { + expr: expr__, }) } } - deserializer.deserialize_struct("datafusion.PhysicalHashRepartition", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalIsNull", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalInListNode { +impl serde::Serialize for PhysicalLikeExprNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -15894,45 +18564,54 @@ impl serde::Serialize for PhysicalInListNode { { use serde::ser::SerializeStruct; let mut len = 0; + if self.negated { + len += 1; + } + if self.case_insensitive { + len += 1; + } if self.expr.is_some() { len += 1; } - if !self.list.is_empty() { + if self.pattern.is_some() { len += 1; } + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalLikeExprNode", len)?; if self.negated { - len += 1; + struct_ser.serialize_field("negated", &self.negated)?; + } + if self.case_insensitive { + struct_ser.serialize_field("caseInsensitive", &self.case_insensitive)?; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalInListNode", len)?; if let Some(v) = self.expr.as_ref() { struct_ser.serialize_field("expr", v)?; } - if !self.list.is_empty() { - struct_ser.serialize_field("list", &self.list)?; - } - if self.negated { - struct_ser.serialize_field("negated", &self.negated)?; + if let Some(v) = self.pattern.as_ref() { + struct_ser.serialize_field("pattern", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalInListNode { +impl<'de> serde::Deserialize<'de> for PhysicalLikeExprNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", - "list", "negated", + "case_insensitive", + "caseInsensitive", + "expr", + "pattern", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, - List, Negated, + CaseInsensitive, + Expr, + Pattern, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -15954,9 +18633,10 @@ impl<'de> serde::Deserialize<'de> for PhysicalInListNode { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), - "list" => Ok(GeneratedField::List), "negated" => Ok(GeneratedField::Negated), + "caseInsensitive" | "case_insensitive" => Ok(GeneratedField::CaseInsensitive), + "expr" => Ok(GeneratedField::Expr), + "pattern" => Ok(GeneratedField::Pattern), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -15966,52 +18646,60 @@ impl<'de> serde::Deserialize<'de> for PhysicalInListNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalInListNode; + type Value = PhysicalLikeExprNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalInListNode") + formatter.write_str("struct datafusion.PhysicalLikeExprNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; - let mut list__ = None; let mut negated__ = None; + let mut case_insensitive__ = None; + let mut expr__ = None; + let mut pattern__ = None; while let Some(k) = map_.next_key()? { match k { + GeneratedField::Negated => { + if negated__.is_some() { + return Err(serde::de::Error::duplicate_field("negated")); + } + negated__ = Some(map_.next_value()?); + } + GeneratedField::CaseInsensitive => { + if case_insensitive__.is_some() { + return Err(serde::de::Error::duplicate_field("caseInsensitive")); + } + case_insensitive__ = Some(map_.next_value()?); + } GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } expr__ = map_.next_value()?; } - GeneratedField::List => { - if list__.is_some() { - return Err(serde::de::Error::duplicate_field("list")); - } - list__ = Some(map_.next_value()?); - } - GeneratedField::Negated => { - if negated__.is_some() { - return Err(serde::de::Error::duplicate_field("negated")); + GeneratedField::Pattern => { + if pattern__.is_some() { + return Err(serde::de::Error::duplicate_field("pattern")); } - negated__ = Some(map_.next_value()?); + pattern__ = map_.next_value()?; } } } - Ok(PhysicalInListNode { - expr: expr__, - list: list__.unwrap_or_default(), + Ok(PhysicalLikeExprNode { negated: negated__.unwrap_or_default(), + case_insensitive: case_insensitive__.unwrap_or_default(), + expr: expr__, + pattern: pattern__, }) } } - deserializer.deserialize_struct("datafusion.PhysicalInListNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalLikeExprNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalIsNotNull { +impl serde::Serialize for PhysicalNegativeNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -16022,14 +18710,14 @@ impl serde::Serialize for PhysicalIsNotNull { if self.expr.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalIsNotNull", len)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalNegativeNode", len)?; if let Some(v) = self.expr.as_ref() { struct_ser.serialize_field("expr", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalIsNotNull { +impl<'de> serde::Deserialize<'de> for PhysicalNegativeNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where @@ -16073,13 +18761,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalIsNotNull { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalIsNotNull; + type Value = PhysicalNegativeNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalIsNotNull") + formatter.write_str("struct datafusion.PhysicalNegativeNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { @@ -16094,15 +18782,15 @@ impl<'de> serde::Deserialize<'de> for PhysicalIsNotNull { } } } - Ok(PhysicalIsNotNull { + Ok(PhysicalNegativeNode { expr: expr__, }) } } - deserializer.deserialize_struct("datafusion.PhysicalIsNotNull", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalNegativeNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalIsNull { +impl serde::Serialize for PhysicalNot { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -16113,14 +18801,14 @@ impl serde::Serialize for PhysicalIsNull { if self.expr.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalIsNull", len)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalNot", len)?; if let Some(v) = self.expr.as_ref() { struct_ser.serialize_field("expr", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalIsNull { +impl<'de> serde::Deserialize<'de> for PhysicalNot { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where @@ -16164,13 +18852,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalIsNull { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalIsNull; + type Value = PhysicalNot; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalIsNull") + formatter.write_str("struct datafusion.PhysicalNot") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { @@ -16185,15 +18873,15 @@ impl<'de> serde::Deserialize<'de> for PhysicalIsNull { } } } - Ok(PhysicalIsNull { + Ok(PhysicalNot { expr: expr__, }) } } - deserializer.deserialize_struct("datafusion.PhysicalIsNull", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalNot", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalLikeExprNode { +impl serde::Serialize for PhysicalPlanNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -16201,54 +18889,183 @@ impl serde::Serialize for PhysicalLikeExprNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.negated { - len += 1; - } - if self.case_insensitive { - len += 1; - } - if self.expr.is_some() { - len += 1; - } - if self.pattern.is_some() { + if self.physical_plan_type.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalLikeExprNode", len)?; - if self.negated { - struct_ser.serialize_field("negated", &self.negated)?; - } - if self.case_insensitive { - struct_ser.serialize_field("caseInsensitive", &self.case_insensitive)?; - } - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; - } - if let Some(v) = self.pattern.as_ref() { - struct_ser.serialize_field("pattern", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalPlanNode", len)?; + if let Some(v) = self.physical_plan_type.as_ref() { + match v { + physical_plan_node::PhysicalPlanType::ParquetScan(v) => { + struct_ser.serialize_field("parquetScan", v)?; + } + physical_plan_node::PhysicalPlanType::CsvScan(v) => { + struct_ser.serialize_field("csvScan", v)?; + } + physical_plan_node::PhysicalPlanType::Empty(v) => { + struct_ser.serialize_field("empty", v)?; + } + physical_plan_node::PhysicalPlanType::Projection(v) => { + struct_ser.serialize_field("projection", v)?; + } + physical_plan_node::PhysicalPlanType::GlobalLimit(v) => { + struct_ser.serialize_field("globalLimit", v)?; + } + physical_plan_node::PhysicalPlanType::LocalLimit(v) => { + struct_ser.serialize_field("localLimit", v)?; + } + physical_plan_node::PhysicalPlanType::Aggregate(v) => { + struct_ser.serialize_field("aggregate", v)?; + } + physical_plan_node::PhysicalPlanType::HashJoin(v) => { + struct_ser.serialize_field("hashJoin", v)?; + } + physical_plan_node::PhysicalPlanType::Sort(v) => { + struct_ser.serialize_field("sort", v)?; + } + physical_plan_node::PhysicalPlanType::CoalesceBatches(v) => { + struct_ser.serialize_field("coalesceBatches", v)?; + } + physical_plan_node::PhysicalPlanType::Filter(v) => { + struct_ser.serialize_field("filter", v)?; + } + physical_plan_node::PhysicalPlanType::Merge(v) => { + struct_ser.serialize_field("merge", v)?; + } + physical_plan_node::PhysicalPlanType::Repartition(v) => { + struct_ser.serialize_field("repartition", v)?; + } + physical_plan_node::PhysicalPlanType::Window(v) => { + struct_ser.serialize_field("window", v)?; + } + physical_plan_node::PhysicalPlanType::CrossJoin(v) => { + struct_ser.serialize_field("crossJoin", v)?; + } + physical_plan_node::PhysicalPlanType::AvroScan(v) => { + struct_ser.serialize_field("avroScan", v)?; + } + physical_plan_node::PhysicalPlanType::Extension(v) => { + struct_ser.serialize_field("extension", v)?; + } + physical_plan_node::PhysicalPlanType::Union(v) => { + struct_ser.serialize_field("union", v)?; + } + physical_plan_node::PhysicalPlanType::Explain(v) => { + struct_ser.serialize_field("explain", v)?; + } + physical_plan_node::PhysicalPlanType::SortPreservingMerge(v) => { + struct_ser.serialize_field("sortPreservingMerge", v)?; + } + physical_plan_node::PhysicalPlanType::NestedLoopJoin(v) => { + struct_ser.serialize_field("nestedLoopJoin", v)?; + } + physical_plan_node::PhysicalPlanType::Analyze(v) => { + struct_ser.serialize_field("analyze", v)?; + } + physical_plan_node::PhysicalPlanType::JsonSink(v) => { + struct_ser.serialize_field("jsonSink", v)?; + } + physical_plan_node::PhysicalPlanType::SymmetricHashJoin(v) => { + struct_ser.serialize_field("symmetricHashJoin", v)?; + } + physical_plan_node::PhysicalPlanType::Interleave(v) => { + struct_ser.serialize_field("interleave", v)?; + } + physical_plan_node::PhysicalPlanType::PlaceholderRow(v) => { + struct_ser.serialize_field("placeholderRow", v)?; + } + physical_plan_node::PhysicalPlanType::CsvSink(v) => { + struct_ser.serialize_field("csvSink", v)?; + } + physical_plan_node::PhysicalPlanType::ParquetSink(v) => { + struct_ser.serialize_field("parquetSink", v)?; + } + } } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalLikeExprNode { +impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "negated", - "case_insensitive", - "caseInsensitive", - "expr", - "pattern", + "parquet_scan", + "parquetScan", + "csv_scan", + "csvScan", + "empty", + "projection", + "global_limit", + "globalLimit", + "local_limit", + "localLimit", + "aggregate", + "hash_join", + "hashJoin", + "sort", + "coalesce_batches", + "coalesceBatches", + "filter", + "merge", + "repartition", + "window", + "cross_join", + "crossJoin", + "avro_scan", + "avroScan", + "extension", + "union", + "explain", + "sort_preserving_merge", + "sortPreservingMerge", + "nested_loop_join", + "nestedLoopJoin", + "analyze", + "json_sink", + "jsonSink", + "symmetric_hash_join", + "symmetricHashJoin", + "interleave", + "placeholder_row", + "placeholderRow", + "csv_sink", + "csvSink", + "parquet_sink", + "parquetSink", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Negated, - CaseInsensitive, - Expr, - Pattern, + ParquetScan, + CsvScan, + Empty, + Projection, + GlobalLimit, + LocalLimit, + Aggregate, + HashJoin, + Sort, + CoalesceBatches, + Filter, + Merge, + Repartition, + Window, + CrossJoin, + AvroScan, + Extension, + Union, + Explain, + SortPreservingMerge, + NestedLoopJoin, + Analyze, + JsonSink, + SymmetricHashJoin, + Interleave, + PlaceholderRow, + CsvSink, + ParquetSink, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -16270,10 +19087,34 @@ impl<'de> serde::Deserialize<'de> for PhysicalLikeExprNode { E: serde::de::Error, { match value { - "negated" => Ok(GeneratedField::Negated), - "caseInsensitive" | "case_insensitive" => Ok(GeneratedField::CaseInsensitive), - "expr" => Ok(GeneratedField::Expr), - "pattern" => Ok(GeneratedField::Pattern), + "parquetScan" | "parquet_scan" => Ok(GeneratedField::ParquetScan), + "csvScan" | "csv_scan" => Ok(GeneratedField::CsvScan), + "empty" => Ok(GeneratedField::Empty), + "projection" => Ok(GeneratedField::Projection), + "globalLimit" | "global_limit" => Ok(GeneratedField::GlobalLimit), + "localLimit" | "local_limit" => Ok(GeneratedField::LocalLimit), + "aggregate" => Ok(GeneratedField::Aggregate), + "hashJoin" | "hash_join" => Ok(GeneratedField::HashJoin), + "sort" => Ok(GeneratedField::Sort), + "coalesceBatches" | "coalesce_batches" => Ok(GeneratedField::CoalesceBatches), + "filter" => Ok(GeneratedField::Filter), + "merge" => Ok(GeneratedField::Merge), + "repartition" => Ok(GeneratedField::Repartition), + "window" => Ok(GeneratedField::Window), + "crossJoin" | "cross_join" => Ok(GeneratedField::CrossJoin), + "avroScan" | "avro_scan" => Ok(GeneratedField::AvroScan), + "extension" => Ok(GeneratedField::Extension), + "union" => Ok(GeneratedField::Union), + "explain" => Ok(GeneratedField::Explain), + "sortPreservingMerge" | "sort_preserving_merge" => Ok(GeneratedField::SortPreservingMerge), + "nestedLoopJoin" | "nested_loop_join" => Ok(GeneratedField::NestedLoopJoin), + "analyze" => Ok(GeneratedField::Analyze), + "jsonSink" | "json_sink" => Ok(GeneratedField::JsonSink), + "symmetricHashJoin" | "symmetric_hash_join" => Ok(GeneratedField::SymmetricHashJoin), + "interleave" => Ok(GeneratedField::Interleave), + "placeholderRow" | "placeholder_row" => Ok(GeneratedField::PlaceholderRow), + "csvSink" | "csv_sink" => Ok(GeneratedField::CsvSink), + "parquetSink" | "parquet_sink" => Ok(GeneratedField::ParquetSink), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -16283,60 +19124,226 @@ impl<'de> serde::Deserialize<'de> for PhysicalLikeExprNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalLikeExprNode; + type Value = PhysicalPlanNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalLikeExprNode") + formatter.write_str("struct datafusion.PhysicalPlanNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut negated__ = None; - let mut case_insensitive__ = None; - let mut expr__ = None; - let mut pattern__ = None; + let mut physical_plan_type__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Negated => { - if negated__.is_some() { - return Err(serde::de::Error::duplicate_field("negated")); + GeneratedField::ParquetScan => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("parquetScan")); } - negated__ = Some(map_.next_value()?); + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::ParquetScan) +; } - GeneratedField::CaseInsensitive => { - if case_insensitive__.is_some() { - return Err(serde::de::Error::duplicate_field("caseInsensitive")); + GeneratedField::CsvScan => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("csvScan")); } - case_insensitive__ = Some(map_.next_value()?); + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::CsvScan) +; } - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::Empty => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("empty")); } - expr__ = map_.next_value()?; + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Empty) +; } - GeneratedField::Pattern => { - if pattern__.is_some() { - return Err(serde::de::Error::duplicate_field("pattern")); + GeneratedField::Projection => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("projection")); } - pattern__ = map_.next_value()?; + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Projection) +; + } + GeneratedField::GlobalLimit => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("globalLimit")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::GlobalLimit) +; + } + GeneratedField::LocalLimit => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("localLimit")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::LocalLimit) +; + } + GeneratedField::Aggregate => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("aggregate")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Aggregate) +; + } + GeneratedField::HashJoin => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("hashJoin")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::HashJoin) +; + } + GeneratedField::Sort => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("sort")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Sort) +; + } + GeneratedField::CoalesceBatches => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("coalesceBatches")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::CoalesceBatches) +; + } + GeneratedField::Filter => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("filter")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Filter) +; + } + GeneratedField::Merge => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("merge")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Merge) +; + } + GeneratedField::Repartition => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("repartition")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Repartition) +; + } + GeneratedField::Window => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("window")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Window) +; + } + GeneratedField::CrossJoin => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("crossJoin")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::CrossJoin) +; + } + GeneratedField::AvroScan => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("avroScan")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::AvroScan) +; + } + GeneratedField::Extension => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("extension")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Extension) +; + } + GeneratedField::Union => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("union")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Union) +; + } + GeneratedField::Explain => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("explain")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Explain) +; + } + GeneratedField::SortPreservingMerge => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("sortPreservingMerge")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::SortPreservingMerge) +; + } + GeneratedField::NestedLoopJoin => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("nestedLoopJoin")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::NestedLoopJoin) +; + } + GeneratedField::Analyze => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("analyze")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Analyze) +; + } + GeneratedField::JsonSink => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("jsonSink")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::JsonSink) +; + } + GeneratedField::SymmetricHashJoin => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("symmetricHashJoin")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::SymmetricHashJoin) +; + } + GeneratedField::Interleave => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("interleave")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Interleave) +; + } + GeneratedField::PlaceholderRow => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("placeholderRow")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::PlaceholderRow) +; + } + GeneratedField::CsvSink => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("csvSink")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::CsvSink) +; + } + GeneratedField::ParquetSink => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("parquetSink")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::ParquetSink) +; } } } - Ok(PhysicalLikeExprNode { - negated: negated__.unwrap_or_default(), - case_insensitive: case_insensitive__.unwrap_or_default(), - expr: expr__, - pattern: pattern__, + Ok(PhysicalPlanNode { + physical_plan_type: physical_plan_type__, }) } } - deserializer.deserialize_struct("datafusion.PhysicalLikeExprNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalPlanNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalNegativeNode { +impl serde::Serialize for PhysicalScalarFunctionNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -16344,29 +19351,56 @@ impl serde::Serialize for PhysicalNegativeNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { + if !self.name.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalNegativeNode", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + if self.fun != 0 { + len += 1; + } + if !self.args.is_empty() { + len += 1; + } + if self.return_type.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalScalarFunctionNode", len)?; + if !self.name.is_empty() { + struct_ser.serialize_field("name", &self.name)?; + } + if self.fun != 0 { + let v = ScalarFunction::try_from(self.fun) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.fun)))?; + struct_ser.serialize_field("fun", &v)?; + } + if !self.args.is_empty() { + struct_ser.serialize_field("args", &self.args)?; + } + if let Some(v) = self.return_type.as_ref() { + struct_ser.serialize_field("returnType", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalNegativeNode { +impl<'de> serde::Deserialize<'de> for PhysicalScalarFunctionNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", + "name", + "fun", + "args", + "return_type", + "returnType", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, + Name, + Fun, + Args, + ReturnType, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -16388,7 +19422,10 @@ impl<'de> serde::Deserialize<'de> for PhysicalNegativeNode { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), + "name" => Ok(GeneratedField::Name), + "fun" => Ok(GeneratedField::Fun), + "args" => Ok(GeneratedField::Args), + "returnType" | "return_type" => Ok(GeneratedField::ReturnType), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -16398,36 +19435,60 @@ impl<'de> serde::Deserialize<'de> for PhysicalNegativeNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalNegativeNode; + type Value = PhysicalScalarFunctionNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalNegativeNode") + formatter.write_str("struct datafusion.PhysicalScalarFunctionNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; + let mut name__ = None; + let mut fun__ = None; + let mut args__ = None; + let mut return_type__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); } - expr__ = map_.next_value()?; + name__ = Some(map_.next_value()?); + } + GeneratedField::Fun => { + if fun__.is_some() { + return Err(serde::de::Error::duplicate_field("fun")); + } + fun__ = Some(map_.next_value::()? as i32); + } + GeneratedField::Args => { + if args__.is_some() { + return Err(serde::de::Error::duplicate_field("args")); + } + args__ = Some(map_.next_value()?); + } + GeneratedField::ReturnType => { + if return_type__.is_some() { + return Err(serde::de::Error::duplicate_field("returnType")); + } + return_type__ = map_.next_value()?; } } } - Ok(PhysicalNegativeNode { - expr: expr__, + Ok(PhysicalScalarFunctionNode { + name: name__.unwrap_or_default(), + fun: fun__.unwrap_or_default(), + args: args__.unwrap_or_default(), + return_type: return_type__, }) } } - deserializer.deserialize_struct("datafusion.PhysicalNegativeNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalScalarFunctionNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalNot { +impl serde::Serialize for PhysicalScalarUdfNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -16435,29 +19496,46 @@ impl serde::Serialize for PhysicalNot { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { + if !self.name.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalNot", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + if !self.args.is_empty() { + len += 1; + } + if self.return_type.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalScalarUdfNode", len)?; + if !self.name.is_empty() { + struct_ser.serialize_field("name", &self.name)?; + } + if !self.args.is_empty() { + struct_ser.serialize_field("args", &self.args)?; + } + if let Some(v) = self.return_type.as_ref() { + struct_ser.serialize_field("returnType", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalNot { +impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", + "name", + "args", + "return_type", + "returnType", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, + Name, + Args, + ReturnType, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -16479,7 +19557,9 @@ impl<'de> serde::Deserialize<'de> for PhysicalNot { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), + "name" => Ok(GeneratedField::Name), + "args" => Ok(GeneratedField::Args), + "returnType" | "return_type" => Ok(GeneratedField::ReturnType), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -16489,36 +19569,52 @@ impl<'de> serde::Deserialize<'de> for PhysicalNot { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalNot; + type Value = PhysicalScalarUdfNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalNot") + formatter.write_str("struct datafusion.PhysicalScalarUdfNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; + let mut name__ = None; + let mut args__ = None; + let mut return_type__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); } - expr__ = map_.next_value()?; + name__ = Some(map_.next_value()?); + } + GeneratedField::Args => { + if args__.is_some() { + return Err(serde::de::Error::duplicate_field("args")); + } + args__ = Some(map_.next_value()?); + } + GeneratedField::ReturnType => { + if return_type__.is_some() { + return Err(serde::de::Error::duplicate_field("returnType")); + } + return_type__ = map_.next_value()?; } } } - Ok(PhysicalNot { - expr: expr__, + Ok(PhysicalScalarUdfNode { + name: name__.unwrap_or_default(), + args: args__.unwrap_or_default(), + return_type: return_type__, }) } } - deserializer.deserialize_struct("datafusion.PhysicalNot", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalScalarUdfNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalPlanNode { +impl serde::Serialize for PhysicalSortExprNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -16526,148 +19622,46 @@ impl serde::Serialize for PhysicalPlanNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.physical_plan_type.is_some() { + if self.expr.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalPlanNode", len)?; - if let Some(v) = self.physical_plan_type.as_ref() { - match v { - physical_plan_node::PhysicalPlanType::ParquetScan(v) => { - struct_ser.serialize_field("parquetScan", v)?; - } - physical_plan_node::PhysicalPlanType::CsvScan(v) => { - struct_ser.serialize_field("csvScan", v)?; - } - physical_plan_node::PhysicalPlanType::Empty(v) => { - struct_ser.serialize_field("empty", v)?; - } - physical_plan_node::PhysicalPlanType::Projection(v) => { - struct_ser.serialize_field("projection", v)?; - } - physical_plan_node::PhysicalPlanType::GlobalLimit(v) => { - struct_ser.serialize_field("globalLimit", v)?; - } - physical_plan_node::PhysicalPlanType::LocalLimit(v) => { - struct_ser.serialize_field("localLimit", v)?; - } - physical_plan_node::PhysicalPlanType::Aggregate(v) => { - struct_ser.serialize_field("aggregate", v)?; - } - physical_plan_node::PhysicalPlanType::HashJoin(v) => { - struct_ser.serialize_field("hashJoin", v)?; - } - physical_plan_node::PhysicalPlanType::Sort(v) => { - struct_ser.serialize_field("sort", v)?; - } - physical_plan_node::PhysicalPlanType::CoalesceBatches(v) => { - struct_ser.serialize_field("coalesceBatches", v)?; - } - physical_plan_node::PhysicalPlanType::Filter(v) => { - struct_ser.serialize_field("filter", v)?; - } - physical_plan_node::PhysicalPlanType::Merge(v) => { - struct_ser.serialize_field("merge", v)?; - } - physical_plan_node::PhysicalPlanType::Repartition(v) => { - struct_ser.serialize_field("repartition", v)?; - } - physical_plan_node::PhysicalPlanType::Window(v) => { - struct_ser.serialize_field("window", v)?; - } - physical_plan_node::PhysicalPlanType::CrossJoin(v) => { - struct_ser.serialize_field("crossJoin", v)?; - } - physical_plan_node::PhysicalPlanType::AvroScan(v) => { - struct_ser.serialize_field("avroScan", v)?; - } - physical_plan_node::PhysicalPlanType::Extension(v) => { - struct_ser.serialize_field("extension", v)?; - } - physical_plan_node::PhysicalPlanType::Union(v) => { - struct_ser.serialize_field("union", v)?; - } - physical_plan_node::PhysicalPlanType::Explain(v) => { - struct_ser.serialize_field("explain", v)?; - } - physical_plan_node::PhysicalPlanType::SortPreservingMerge(v) => { - struct_ser.serialize_field("sortPreservingMerge", v)?; - } - physical_plan_node::PhysicalPlanType::NestedLoopJoin(v) => { - struct_ser.serialize_field("nestedLoopJoin", v)?; - } - physical_plan_node::PhysicalPlanType::Analyze(v) => { - struct_ser.serialize_field("analyze", v)?; - } - } + if self.asc { + len += 1; + } + if self.nulls_first { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalSortExprNode", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; + } + if self.asc { + struct_ser.serialize_field("asc", &self.asc)?; + } + if self.nulls_first { + struct_ser.serialize_field("nullsFirst", &self.nulls_first)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { +impl<'de> serde::Deserialize<'de> for PhysicalSortExprNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "parquet_scan", - "parquetScan", - "csv_scan", - "csvScan", - "empty", - "projection", - "global_limit", - "globalLimit", - "local_limit", - "localLimit", - "aggregate", - "hash_join", - "hashJoin", - "sort", - "coalesce_batches", - "coalesceBatches", - "filter", - "merge", - "repartition", - "window", - "cross_join", - "crossJoin", - "avro_scan", - "avroScan", - "extension", - "union", - "explain", - "sort_preserving_merge", - "sortPreservingMerge", - "nested_loop_join", - "nestedLoopJoin", - "analyze", + "expr", + "asc", + "nulls_first", + "nullsFirst", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - ParquetScan, - CsvScan, - Empty, - Projection, - GlobalLimit, - LocalLimit, - Aggregate, - HashJoin, - Sort, - CoalesceBatches, - Filter, - Merge, - Repartition, - Window, - CrossJoin, - AvroScan, - Extension, - Union, - Explain, - SortPreservingMerge, - NestedLoopJoin, - Analyze, + Expr, + Asc, + NullsFirst, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -16689,28 +19683,9 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { E: serde::de::Error, { match value { - "parquetScan" | "parquet_scan" => Ok(GeneratedField::ParquetScan), - "csvScan" | "csv_scan" => Ok(GeneratedField::CsvScan), - "empty" => Ok(GeneratedField::Empty), - "projection" => Ok(GeneratedField::Projection), - "globalLimit" | "global_limit" => Ok(GeneratedField::GlobalLimit), - "localLimit" | "local_limit" => Ok(GeneratedField::LocalLimit), - "aggregate" => Ok(GeneratedField::Aggregate), - "hashJoin" | "hash_join" => Ok(GeneratedField::HashJoin), - "sort" => Ok(GeneratedField::Sort), - "coalesceBatches" | "coalesce_batches" => Ok(GeneratedField::CoalesceBatches), - "filter" => Ok(GeneratedField::Filter), - "merge" => Ok(GeneratedField::Merge), - "repartition" => Ok(GeneratedField::Repartition), - "window" => Ok(GeneratedField::Window), - "crossJoin" | "cross_join" => Ok(GeneratedField::CrossJoin), - "avroScan" | "avro_scan" => Ok(GeneratedField::AvroScan), - "extension" => Ok(GeneratedField::Extension), - "union" => Ok(GeneratedField::Union), - "explain" => Ok(GeneratedField::Explain), - "sortPreservingMerge" | "sort_preserving_merge" => Ok(GeneratedField::SortPreservingMerge), - "nestedLoopJoin" | "nested_loop_join" => Ok(GeneratedField::NestedLoopJoin), - "analyze" => Ok(GeneratedField::Analyze), + "expr" => Ok(GeneratedField::Expr), + "asc" => Ok(GeneratedField::Asc), + "nullsFirst" | "nulls_first" => Ok(GeneratedField::NullsFirst), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -16720,184 +19695,52 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalPlanNode; + type Value = PhysicalSortExprNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalPlanNode") + formatter.write_str("struct datafusion.PhysicalSortExprNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut physical_plan_type__ = None; + let mut expr__ = None; + let mut asc__ = None; + let mut nulls_first__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::ParquetScan => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("parquetScan")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::ParquetScan) -; - } - GeneratedField::CsvScan => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("csvScan")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::CsvScan) -; - } - GeneratedField::Empty => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("empty")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Empty) -; - } - GeneratedField::Projection => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("projection")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Projection) -; - } - GeneratedField::GlobalLimit => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("globalLimit")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::GlobalLimit) -; - } - GeneratedField::LocalLimit => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("localLimit")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::LocalLimit) -; - } - GeneratedField::Aggregate => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("aggregate")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Aggregate) -; - } - GeneratedField::HashJoin => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("hashJoin")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::HashJoin) -; - } - GeneratedField::Sort => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("sort")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Sort) -; - } - GeneratedField::CoalesceBatches => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("coalesceBatches")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::CoalesceBatches) -; - } - GeneratedField::Filter => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("filter")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Filter) -; - } - GeneratedField::Merge => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("merge")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Merge) -; - } - GeneratedField::Repartition => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("repartition")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Repartition) -; - } - GeneratedField::Window => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("window")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Window) -; - } - GeneratedField::CrossJoin => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("crossJoin")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::CrossJoin) -; - } - GeneratedField::AvroScan => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("avroScan")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::AvroScan) -; - } - GeneratedField::Extension => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("extension")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Extension) -; - } - GeneratedField::Union => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("union")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Union) -; - } - GeneratedField::Explain => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("explain")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Explain) -; - } - GeneratedField::SortPreservingMerge => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("sortPreservingMerge")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::SortPreservingMerge) -; + expr__ = map_.next_value()?; } - GeneratedField::NestedLoopJoin => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("nestedLoopJoin")); + GeneratedField::Asc => { + if asc__.is_some() { + return Err(serde::de::Error::duplicate_field("asc")); } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::NestedLoopJoin) -; + asc__ = Some(map_.next_value()?); } - GeneratedField::Analyze => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("analyze")); + GeneratedField::NullsFirst => { + if nulls_first__.is_some() { + return Err(serde::de::Error::duplicate_field("nullsFirst")); } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Analyze) -; + nulls_first__ = Some(map_.next_value()?); } } } - Ok(PhysicalPlanNode { - physical_plan_type: physical_plan_type__, + Ok(PhysicalSortExprNode { + expr: expr__, + asc: asc__.unwrap_or_default(), + nulls_first: nulls_first__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.PhysicalPlanNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalSortExprNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalScalarFunctionNode { +impl serde::Serialize for PhysicalSortExprNodeCollection { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -16905,56 +19748,30 @@ impl serde::Serialize for PhysicalScalarFunctionNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.name.is_empty() { - len += 1; - } - if self.fun != 0 { - len += 1; - } - if !self.args.is_empty() { - len += 1; - } - if self.return_type.is_some() { + if !self.physical_sort_expr_nodes.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalScalarFunctionNode", len)?; - if !self.name.is_empty() { - struct_ser.serialize_field("name", &self.name)?; - } - if self.fun != 0 { - let v = ScalarFunction::try_from(self.fun) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.fun)))?; - struct_ser.serialize_field("fun", &v)?; - } - if !self.args.is_empty() { - struct_ser.serialize_field("args", &self.args)?; - } - if let Some(v) = self.return_type.as_ref() { - struct_ser.serialize_field("returnType", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalSortExprNodeCollection", len)?; + if !self.physical_sort_expr_nodes.is_empty() { + struct_ser.serialize_field("physicalSortExprNodes", &self.physical_sort_expr_nodes)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalScalarFunctionNode { +impl<'de> serde::Deserialize<'de> for PhysicalSortExprNodeCollection { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "name", - "fun", - "args", - "return_type", - "returnType", + "physical_sort_expr_nodes", + "physicalSortExprNodes", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Name, - Fun, - Args, - ReturnType, + PhysicalSortExprNodes, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -16976,10 +19793,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarFunctionNode { E: serde::de::Error, { match value { - "name" => Ok(GeneratedField::Name), - "fun" => Ok(GeneratedField::Fun), - "args" => Ok(GeneratedField::Args), - "returnType" | "return_type" => Ok(GeneratedField::ReturnType), + "physicalSortExprNodes" | "physical_sort_expr_nodes" => Ok(GeneratedField::PhysicalSortExprNodes), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -16989,60 +19803,36 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarFunctionNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalScalarFunctionNode; + type Value = PhysicalSortExprNodeCollection; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalScalarFunctionNode") + formatter.write_str("struct datafusion.PhysicalSortExprNodeCollection") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut name__ = None; - let mut fun__ = None; - let mut args__ = None; - let mut return_type__ = None; + let mut physical_sort_expr_nodes__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Name => { - if name__.is_some() { - return Err(serde::de::Error::duplicate_field("name")); - } - name__ = Some(map_.next_value()?); - } - GeneratedField::Fun => { - if fun__.is_some() { - return Err(serde::de::Error::duplicate_field("fun")); - } - fun__ = Some(map_.next_value::()? as i32); - } - GeneratedField::Args => { - if args__.is_some() { - return Err(serde::de::Error::duplicate_field("args")); - } - args__ = Some(map_.next_value()?); - } - GeneratedField::ReturnType => { - if return_type__.is_some() { - return Err(serde::de::Error::duplicate_field("returnType")); + GeneratedField::PhysicalSortExprNodes => { + if physical_sort_expr_nodes__.is_some() { + return Err(serde::de::Error::duplicate_field("physicalSortExprNodes")); } - return_type__ = map_.next_value()?; + physical_sort_expr_nodes__ = Some(map_.next_value()?); } } } - Ok(PhysicalScalarFunctionNode { - name: name__.unwrap_or_default(), - fun: fun__.unwrap_or_default(), - args: args__.unwrap_or_default(), - return_type: return_type__, + Ok(PhysicalSortExprNodeCollection { + physical_sort_expr_nodes: physical_sort_expr_nodes__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.PhysicalScalarFunctionNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalSortExprNodeCollection", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalScalarUdfNode { +impl serde::Serialize for PhysicalTryCastNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -17050,46 +19840,38 @@ impl serde::Serialize for PhysicalScalarUdfNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.name.is_empty() { - len += 1; - } - if !self.args.is_empty() { + if self.expr.is_some() { len += 1; } - if self.return_type.is_some() { + if self.arrow_type.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalScalarUdfNode", len)?; - if !self.name.is_empty() { - struct_ser.serialize_field("name", &self.name)?; - } - if !self.args.is_empty() { - struct_ser.serialize_field("args", &self.args)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalTryCastNode", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; } - if let Some(v) = self.return_type.as_ref() { - struct_ser.serialize_field("returnType", v)?; + if let Some(v) = self.arrow_type.as_ref() { + struct_ser.serialize_field("arrowType", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { +impl<'de> serde::Deserialize<'de> for PhysicalTryCastNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "name", - "args", - "return_type", - "returnType", + "expr", + "arrow_type", + "arrowType", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Name, - Args, - ReturnType, + Expr, + ArrowType, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -17111,9 +19893,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { E: serde::de::Error, { match value { - "name" => Ok(GeneratedField::Name), - "args" => Ok(GeneratedField::Args), - "returnType" | "return_type" => Ok(GeneratedField::ReturnType), + "expr" => Ok(GeneratedField::Expr), + "arrowType" | "arrow_type" => Ok(GeneratedField::ArrowType), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -17123,52 +19904,44 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalScalarUdfNode; + type Value = PhysicalTryCastNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalScalarUdfNode") + formatter.write_str("struct datafusion.PhysicalTryCastNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut name__ = None; - let mut args__ = None; - let mut return_type__ = None; + let mut expr__ = None; + let mut arrow_type__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Name => { - if name__.is_some() { - return Err(serde::de::Error::duplicate_field("name")); - } - name__ = Some(map_.next_value()?); - } - GeneratedField::Args => { - if args__.is_some() { - return Err(serde::de::Error::duplicate_field("args")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - args__ = Some(map_.next_value()?); + expr__ = map_.next_value()?; } - GeneratedField::ReturnType => { - if return_type__.is_some() { - return Err(serde::de::Error::duplicate_field("returnType")); + GeneratedField::ArrowType => { + if arrow_type__.is_some() { + return Err(serde::de::Error::duplicate_field("arrowType")); } - return_type__ = map_.next_value()?; + arrow_type__ = map_.next_value()?; } } } - Ok(PhysicalScalarUdfNode { - name: name__.unwrap_or_default(), - args: args__.unwrap_or_default(), - return_type: return_type__, + Ok(PhysicalTryCastNode { + expr: expr__, + arrow_type: arrow_type__, }) } } - deserializer.deserialize_struct("datafusion.PhysicalScalarUdfNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalTryCastNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalSortExprNode { +impl serde::Serialize for PhysicalWhenThen { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -17176,46 +19949,39 @@ impl serde::Serialize for PhysicalSortExprNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { - len += 1; - } - if self.asc { + if self.when_expr.is_some() { len += 1; } - if self.nulls_first { + if self.then_expr.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalSortExprNode", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; - } - if self.asc { - struct_ser.serialize_field("asc", &self.asc)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalWhenThen", len)?; + if let Some(v) = self.when_expr.as_ref() { + struct_ser.serialize_field("whenExpr", v)?; } - if self.nulls_first { - struct_ser.serialize_field("nullsFirst", &self.nulls_first)?; + if let Some(v) = self.then_expr.as_ref() { + struct_ser.serialize_field("thenExpr", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalSortExprNode { +impl<'de> serde::Deserialize<'de> for PhysicalWhenThen { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", - "asc", - "nulls_first", - "nullsFirst", + "when_expr", + "whenExpr", + "then_expr", + "thenExpr", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, - Asc, - NullsFirst, + WhenExpr, + ThenExpr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -17237,9 +20003,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalSortExprNode { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), - "asc" => Ok(GeneratedField::Asc), - "nullsFirst" | "nulls_first" => Ok(GeneratedField::NullsFirst), + "whenExpr" | "when_expr" => Ok(GeneratedField::WhenExpr), + "thenExpr" | "then_expr" => Ok(GeneratedField::ThenExpr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -17249,52 +20014,44 @@ impl<'de> serde::Deserialize<'de> for PhysicalSortExprNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalSortExprNode; + type Value = PhysicalWhenThen; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalSortExprNode") + formatter.write_str("struct datafusion.PhysicalWhenThen") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where - V: serde::de::MapAccess<'de>, - { - let mut expr__ = None; - let mut asc__ = None; - let mut nulls_first__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); - } - expr__ = map_.next_value()?; - } - GeneratedField::Asc => { - if asc__.is_some() { - return Err(serde::de::Error::duplicate_field("asc")); + V: serde::de::MapAccess<'de>, + { + let mut when_expr__ = None; + let mut then_expr__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::WhenExpr => { + if when_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("whenExpr")); } - asc__ = Some(map_.next_value()?); + when_expr__ = map_.next_value()?; } - GeneratedField::NullsFirst => { - if nulls_first__.is_some() { - return Err(serde::de::Error::duplicate_field("nullsFirst")); + GeneratedField::ThenExpr => { + if then_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("thenExpr")); } - nulls_first__ = Some(map_.next_value()?); + then_expr__ = map_.next_value()?; } } } - Ok(PhysicalSortExprNode { - expr: expr__, - asc: asc__.unwrap_or_default(), - nulls_first: nulls_first__.unwrap_or_default(), + Ok(PhysicalWhenThen { + when_expr: when_expr__, + then_expr: then_expr__, }) } } - deserializer.deserialize_struct("datafusion.PhysicalSortExprNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalWhenThen", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalSortExprNodeCollection { +impl serde::Serialize for PhysicalWindowExprNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -17302,30 +20059,87 @@ impl serde::Serialize for PhysicalSortExprNodeCollection { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.physical_sort_expr_nodes.is_empty() { + if !self.args.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalSortExprNodeCollection", len)?; - if !self.physical_sort_expr_nodes.is_empty() { - struct_ser.serialize_field("physicalSortExprNodes", &self.physical_sort_expr_nodes)?; + if !self.partition_by.is_empty() { + len += 1; + } + if !self.order_by.is_empty() { + len += 1; + } + if self.window_frame.is_some() { + len += 1; + } + if !self.name.is_empty() { + len += 1; + } + if self.window_function.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalWindowExprNode", len)?; + if !self.args.is_empty() { + struct_ser.serialize_field("args", &self.args)?; + } + if !self.partition_by.is_empty() { + struct_ser.serialize_field("partitionBy", &self.partition_by)?; + } + if !self.order_by.is_empty() { + struct_ser.serialize_field("orderBy", &self.order_by)?; + } + if let Some(v) = self.window_frame.as_ref() { + struct_ser.serialize_field("windowFrame", v)?; + } + if !self.name.is_empty() { + struct_ser.serialize_field("name", &self.name)?; + } + if let Some(v) = self.window_function.as_ref() { + match v { + physical_window_expr_node::WindowFunction::AggrFunction(v) => { + let v = AggregateFunction::try_from(*v) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; + struct_ser.serialize_field("aggrFunction", &v)?; + } + physical_window_expr_node::WindowFunction::BuiltInFunction(v) => { + let v = BuiltInWindowFunction::try_from(*v) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; + struct_ser.serialize_field("builtInFunction", &v)?; + } + } } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalSortExprNodeCollection { +impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "physical_sort_expr_nodes", - "physicalSortExprNodes", + "args", + "partition_by", + "partitionBy", + "order_by", + "orderBy", + "window_frame", + "windowFrame", + "name", + "aggr_function", + "aggrFunction", + "built_in_function", + "builtInFunction", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - PhysicalSortExprNodes, + Args, + PartitionBy, + OrderBy, + WindowFrame, + Name, + AggrFunction, + BuiltInFunction, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -17347,7 +20161,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalSortExprNodeCollection { E: serde::de::Error, { match value { - "physicalSortExprNodes" | "physical_sort_expr_nodes" => Ok(GeneratedField::PhysicalSortExprNodes), + "args" => Ok(GeneratedField::Args), + "partitionBy" | "partition_by" => Ok(GeneratedField::PartitionBy), + "orderBy" | "order_by" => Ok(GeneratedField::OrderBy), + "windowFrame" | "window_frame" => Ok(GeneratedField::WindowFrame), + "name" => Ok(GeneratedField::Name), + "aggrFunction" | "aggr_function" => Ok(GeneratedField::AggrFunction), + "builtInFunction" | "built_in_function" => Ok(GeneratedField::BuiltInFunction), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -17357,36 +20177,82 @@ impl<'de> serde::Deserialize<'de> for PhysicalSortExprNodeCollection { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalSortExprNodeCollection; + type Value = PhysicalWindowExprNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalSortExprNodeCollection") + formatter.write_str("struct datafusion.PhysicalWindowExprNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut physical_sort_expr_nodes__ = None; + let mut args__ = None; + let mut partition_by__ = None; + let mut order_by__ = None; + let mut window_frame__ = None; + let mut name__ = None; + let mut window_function__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::PhysicalSortExprNodes => { - if physical_sort_expr_nodes__.is_some() { - return Err(serde::de::Error::duplicate_field("physicalSortExprNodes")); + GeneratedField::Args => { + if args__.is_some() { + return Err(serde::de::Error::duplicate_field("args")); } - physical_sort_expr_nodes__ = Some(map_.next_value()?); + args__ = Some(map_.next_value()?); + } + GeneratedField::PartitionBy => { + if partition_by__.is_some() { + return Err(serde::de::Error::duplicate_field("partitionBy")); + } + partition_by__ = Some(map_.next_value()?); + } + GeneratedField::OrderBy => { + if order_by__.is_some() { + return Err(serde::de::Error::duplicate_field("orderBy")); + } + order_by__ = Some(map_.next_value()?); + } + GeneratedField::WindowFrame => { + if window_frame__.is_some() { + return Err(serde::de::Error::duplicate_field("windowFrame")); + } + window_frame__ = map_.next_value()?; + } + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); + } + name__ = Some(map_.next_value()?); + } + GeneratedField::AggrFunction => { + if window_function__.is_some() { + return Err(serde::de::Error::duplicate_field("aggrFunction")); + } + window_function__ = map_.next_value::<::std::option::Option>()?.map(|x| physical_window_expr_node::WindowFunction::AggrFunction(x as i32)); + } + GeneratedField::BuiltInFunction => { + if window_function__.is_some() { + return Err(serde::de::Error::duplicate_field("builtInFunction")); + } + window_function__ = map_.next_value::<::std::option::Option>()?.map(|x| physical_window_expr_node::WindowFunction::BuiltInFunction(x as i32)); } } } - Ok(PhysicalSortExprNodeCollection { - physical_sort_expr_nodes: physical_sort_expr_nodes__.unwrap_or_default(), + Ok(PhysicalWindowExprNode { + args: args__.unwrap_or_default(), + partition_by: partition_by__.unwrap_or_default(), + order_by: order_by__.unwrap_or_default(), + window_frame: window_frame__, + name: name__.unwrap_or_default(), + window_function: window_function__, }) } } - deserializer.deserialize_struct("datafusion.PhysicalSortExprNodeCollection", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalWindowExprNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalTryCastNode { +impl serde::Serialize for PlaceholderNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -17394,38 +20260,38 @@ impl serde::Serialize for PhysicalTryCastNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { + if !self.id.is_empty() { len += 1; } - if self.arrow_type.is_some() { + if self.data_type.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalTryCastNode", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PlaceholderNode", len)?; + if !self.id.is_empty() { + struct_ser.serialize_field("id", &self.id)?; } - if let Some(v) = self.arrow_type.as_ref() { - struct_ser.serialize_field("arrowType", v)?; + if let Some(v) = self.data_type.as_ref() { + struct_ser.serialize_field("dataType", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalTryCastNode { +impl<'de> serde::Deserialize<'de> for PlaceholderNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", - "arrow_type", - "arrowType", + "id", + "data_type", + "dataType", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, - ArrowType, + Id, + DataType, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -17447,8 +20313,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalTryCastNode { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), - "arrowType" | "arrow_type" => Ok(GeneratedField::ArrowType), + "id" => Ok(GeneratedField::Id), + "dataType" | "data_type" => Ok(GeneratedField::DataType), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -17458,84 +20324,74 @@ impl<'de> serde::Deserialize<'de> for PhysicalTryCastNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalTryCastNode; + type Value = PlaceholderNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalTryCastNode") + formatter.write_str("struct datafusion.PlaceholderNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; - let mut arrow_type__ = None; + let mut id__ = None; + let mut data_type__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::Id => { + if id__.is_some() { + return Err(serde::de::Error::duplicate_field("id")); } - expr__ = map_.next_value()?; + id__ = Some(map_.next_value()?); } - GeneratedField::ArrowType => { - if arrow_type__.is_some() { - return Err(serde::de::Error::duplicate_field("arrowType")); + GeneratedField::DataType => { + if data_type__.is_some() { + return Err(serde::de::Error::duplicate_field("dataType")); } - arrow_type__ = map_.next_value()?; + data_type__ = map_.next_value()?; } } } - Ok(PhysicalTryCastNode { - expr: expr__, - arrow_type: arrow_type__, + Ok(PlaceholderNode { + id: id__.unwrap_or_default(), + data_type: data_type__, }) } } - deserializer.deserialize_struct("datafusion.PhysicalTryCastNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PlaceholderNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalWhenThen { +impl serde::Serialize for PlaceholderRowExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where S: serde::Serializer, { use serde::ser::SerializeStruct; - let mut len = 0; - if self.when_expr.is_some() { - len += 1; - } - if self.then_expr.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalWhenThen", len)?; - if let Some(v) = self.when_expr.as_ref() { - struct_ser.serialize_field("whenExpr", v)?; + let mut len = 0; + if self.schema.is_some() { + len += 1; } - if let Some(v) = self.then_expr.as_ref() { - struct_ser.serialize_field("thenExpr", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PlaceholderRowExecNode", len)?; + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalWhenThen { +impl<'de> serde::Deserialize<'de> for PlaceholderRowExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "when_expr", - "whenExpr", - "then_expr", - "thenExpr", + "schema", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - WhenExpr, - ThenExpr, + Schema, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -17557,8 +20413,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalWhenThen { E: serde::de::Error, { match value { - "whenExpr" | "when_expr" => Ok(GeneratedField::WhenExpr), - "thenExpr" | "then_expr" => Ok(GeneratedField::ThenExpr), + "schema" => Ok(GeneratedField::Schema), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -17568,44 +20423,36 @@ impl<'de> serde::Deserialize<'de> for PhysicalWhenThen { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalWhenThen; + type Value = PlaceholderRowExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalWhenThen") + formatter.write_str("struct datafusion.PlaceholderRowExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut when_expr__ = None; - let mut then_expr__ = None; + let mut schema__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::WhenExpr => { - if when_expr__.is_some() { - return Err(serde::de::Error::duplicate_field("whenExpr")); - } - when_expr__ = map_.next_value()?; - } - GeneratedField::ThenExpr => { - if then_expr__.is_some() { - return Err(serde::de::Error::duplicate_field("thenExpr")); + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); } - then_expr__ = map_.next_value()?; + schema__ = map_.next_value()?; } } } - Ok(PhysicalWhenThen { - when_expr: when_expr__, - then_expr: then_expr__, + Ok(PlaceholderRowExecNode { + schema: schema__, }) } } - deserializer.deserialize_struct("datafusion.PhysicalWhenThen", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PlaceholderRowExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalWindowExprNode { +impl serde::Serialize for PlanType { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -17613,87 +20460,78 @@ impl serde::Serialize for PhysicalWindowExprNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.args.is_empty() { - len += 1; - } - if !self.partition_by.is_empty() { - len += 1; - } - if !self.order_by.is_empty() { - len += 1; - } - if self.window_frame.is_some() { - len += 1; - } - if !self.name.is_empty() { - len += 1; - } - if self.window_function.is_some() { + if self.plan_type_enum.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalWindowExprNode", len)?; - if !self.args.is_empty() { - struct_ser.serialize_field("args", &self.args)?; - } - if !self.partition_by.is_empty() { - struct_ser.serialize_field("partitionBy", &self.partition_by)?; - } - if !self.order_by.is_empty() { - struct_ser.serialize_field("orderBy", &self.order_by)?; - } - if let Some(v) = self.window_frame.as_ref() { - struct_ser.serialize_field("windowFrame", v)?; - } - if !self.name.is_empty() { - struct_ser.serialize_field("name", &self.name)?; - } - if let Some(v) = self.window_function.as_ref() { + let mut struct_ser = serializer.serialize_struct("datafusion.PlanType", len)?; + if let Some(v) = self.plan_type_enum.as_ref() { match v { - physical_window_expr_node::WindowFunction::AggrFunction(v) => { - let v = AggregateFunction::try_from(*v) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; - struct_ser.serialize_field("aggrFunction", &v)?; + plan_type::PlanTypeEnum::InitialLogicalPlan(v) => { + struct_ser.serialize_field("InitialLogicalPlan", v)?; } - physical_window_expr_node::WindowFunction::BuiltInFunction(v) => { - let v = BuiltInWindowFunction::try_from(*v) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; - struct_ser.serialize_field("builtInFunction", &v)?; + plan_type::PlanTypeEnum::AnalyzedLogicalPlan(v) => { + struct_ser.serialize_field("AnalyzedLogicalPlan", v)?; + } + plan_type::PlanTypeEnum::FinalAnalyzedLogicalPlan(v) => { + struct_ser.serialize_field("FinalAnalyzedLogicalPlan", v)?; + } + plan_type::PlanTypeEnum::OptimizedLogicalPlan(v) => { + struct_ser.serialize_field("OptimizedLogicalPlan", v)?; + } + plan_type::PlanTypeEnum::FinalLogicalPlan(v) => { + struct_ser.serialize_field("FinalLogicalPlan", v)?; + } + plan_type::PlanTypeEnum::InitialPhysicalPlan(v) => { + struct_ser.serialize_field("InitialPhysicalPlan", v)?; + } + plan_type::PlanTypeEnum::InitialPhysicalPlanWithStats(v) => { + struct_ser.serialize_field("InitialPhysicalPlanWithStats", v)?; + } + plan_type::PlanTypeEnum::OptimizedPhysicalPlan(v) => { + struct_ser.serialize_field("OptimizedPhysicalPlan", v)?; + } + plan_type::PlanTypeEnum::FinalPhysicalPlan(v) => { + struct_ser.serialize_field("FinalPhysicalPlan", v)?; + } + plan_type::PlanTypeEnum::FinalPhysicalPlanWithStats(v) => { + struct_ser.serialize_field("FinalPhysicalPlanWithStats", v)?; } } } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { +impl<'de> serde::Deserialize<'de> for PlanType { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "args", - "partition_by", - "partitionBy", - "order_by", - "orderBy", - "window_frame", - "windowFrame", - "name", - "aggr_function", - "aggrFunction", - "built_in_function", - "builtInFunction", + "InitialLogicalPlan", + "AnalyzedLogicalPlan", + "FinalAnalyzedLogicalPlan", + "OptimizedLogicalPlan", + "FinalLogicalPlan", + "InitialPhysicalPlan", + "InitialPhysicalPlanWithStats", + "OptimizedPhysicalPlan", + "FinalPhysicalPlan", + "FinalPhysicalPlanWithStats", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Args, - PartitionBy, - OrderBy, - WindowFrame, - Name, - AggrFunction, - BuiltInFunction, + InitialLogicalPlan, + AnalyzedLogicalPlan, + FinalAnalyzedLogicalPlan, + OptimizedLogicalPlan, + FinalLogicalPlan, + InitialPhysicalPlan, + InitialPhysicalPlanWithStats, + OptimizedPhysicalPlan, + FinalPhysicalPlan, + FinalPhysicalPlanWithStats, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -17715,13 +20553,16 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { E: serde::de::Error, { match value { - "args" => Ok(GeneratedField::Args), - "partitionBy" | "partition_by" => Ok(GeneratedField::PartitionBy), - "orderBy" | "order_by" => Ok(GeneratedField::OrderBy), - "windowFrame" | "window_frame" => Ok(GeneratedField::WindowFrame), - "name" => Ok(GeneratedField::Name), - "aggrFunction" | "aggr_function" => Ok(GeneratedField::AggrFunction), - "builtInFunction" | "built_in_function" => Ok(GeneratedField::BuiltInFunction), + "InitialLogicalPlan" => Ok(GeneratedField::InitialLogicalPlan), + "AnalyzedLogicalPlan" => Ok(GeneratedField::AnalyzedLogicalPlan), + "FinalAnalyzedLogicalPlan" => Ok(GeneratedField::FinalAnalyzedLogicalPlan), + "OptimizedLogicalPlan" => Ok(GeneratedField::OptimizedLogicalPlan), + "FinalLogicalPlan" => Ok(GeneratedField::FinalLogicalPlan), + "InitialPhysicalPlan" => Ok(GeneratedField::InitialPhysicalPlan), + "InitialPhysicalPlanWithStats" => Ok(GeneratedField::InitialPhysicalPlanWithStats), + "OptimizedPhysicalPlan" => Ok(GeneratedField::OptimizedPhysicalPlan), + "FinalPhysicalPlan" => Ok(GeneratedField::FinalPhysicalPlan), + "FinalPhysicalPlanWithStats" => Ok(GeneratedField::FinalPhysicalPlanWithStats), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -17731,82 +20572,100 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalWindowExprNode; + type Value = PlanType; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalWindowExprNode") + formatter.write_str("struct datafusion.PlanType") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut args__ = None; - let mut partition_by__ = None; - let mut order_by__ = None; - let mut window_frame__ = None; - let mut name__ = None; - let mut window_function__ = None; + let mut plan_type_enum__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Args => { - if args__.is_some() { - return Err(serde::de::Error::duplicate_field("args")); + GeneratedField::InitialLogicalPlan => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("InitialLogicalPlan")); } - args__ = Some(map_.next_value()?); + plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::InitialLogicalPlan) +; } - GeneratedField::PartitionBy => { - if partition_by__.is_some() { - return Err(serde::de::Error::duplicate_field("partitionBy")); + GeneratedField::AnalyzedLogicalPlan => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("AnalyzedLogicalPlan")); } - partition_by__ = Some(map_.next_value()?); + plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::AnalyzedLogicalPlan) +; } - GeneratedField::OrderBy => { - if order_by__.is_some() { - return Err(serde::de::Error::duplicate_field("orderBy")); + GeneratedField::FinalAnalyzedLogicalPlan => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("FinalAnalyzedLogicalPlan")); } - order_by__ = Some(map_.next_value()?); + plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::FinalAnalyzedLogicalPlan) +; } - GeneratedField::WindowFrame => { - if window_frame__.is_some() { - return Err(serde::de::Error::duplicate_field("windowFrame")); + GeneratedField::OptimizedLogicalPlan => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("OptimizedLogicalPlan")); } - window_frame__ = map_.next_value()?; + plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::OptimizedLogicalPlan) +; } - GeneratedField::Name => { - if name__.is_some() { - return Err(serde::de::Error::duplicate_field("name")); + GeneratedField::FinalLogicalPlan => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("FinalLogicalPlan")); } - name__ = Some(map_.next_value()?); + plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::FinalLogicalPlan) +; } - GeneratedField::AggrFunction => { - if window_function__.is_some() { - return Err(serde::de::Error::duplicate_field("aggrFunction")); + GeneratedField::InitialPhysicalPlan => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("InitialPhysicalPlan")); } - window_function__ = map_.next_value::<::std::option::Option>()?.map(|x| physical_window_expr_node::WindowFunction::AggrFunction(x as i32)); + plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::InitialPhysicalPlan) +; } - GeneratedField::BuiltInFunction => { - if window_function__.is_some() { - return Err(serde::de::Error::duplicate_field("builtInFunction")); + GeneratedField::InitialPhysicalPlanWithStats => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("InitialPhysicalPlanWithStats")); } - window_function__ = map_.next_value::<::std::option::Option>()?.map(|x| physical_window_expr_node::WindowFunction::BuiltInFunction(x as i32)); + plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::InitialPhysicalPlanWithStats) +; + } + GeneratedField::OptimizedPhysicalPlan => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("OptimizedPhysicalPlan")); + } + plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::OptimizedPhysicalPlan) +; + } + GeneratedField::FinalPhysicalPlan => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("FinalPhysicalPlan")); + } + plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::FinalPhysicalPlan) +; + } + GeneratedField::FinalPhysicalPlanWithStats => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("FinalPhysicalPlanWithStats")); + } + plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::FinalPhysicalPlanWithStats) +; } } } - Ok(PhysicalWindowExprNode { - args: args__.unwrap_or_default(), - partition_by: partition_by__.unwrap_or_default(), - order_by: order_by__.unwrap_or_default(), - window_frame: window_frame__, - name: name__.unwrap_or_default(), - window_function: window_function__, + Ok(PlanType { + plan_type_enum: plan_type_enum__, }) } } - deserializer.deserialize_struct("datafusion.PhysicalWindowExprNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PlanType", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PlaceholderNode { +impl serde::Serialize for Precision { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -17814,38 +20673,40 @@ impl serde::Serialize for PlaceholderNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.id.is_empty() { + if self.precision_info != 0 { len += 1; } - if self.data_type.is_some() { + if self.val.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PlaceholderNode", len)?; - if !self.id.is_empty() { - struct_ser.serialize_field("id", &self.id)?; + let mut struct_ser = serializer.serialize_struct("datafusion.Precision", len)?; + if self.precision_info != 0 { + let v = PrecisionInfo::try_from(self.precision_info) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.precision_info)))?; + struct_ser.serialize_field("precisionInfo", &v)?; } - if let Some(v) = self.data_type.as_ref() { - struct_ser.serialize_field("dataType", v)?; + if let Some(v) = self.val.as_ref() { + struct_ser.serialize_field("val", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PlaceholderNode { +impl<'de> serde::Deserialize<'de> for Precision { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "id", - "data_type", - "dataType", + "precision_info", + "precisionInfo", + "val", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Id, - DataType, + PrecisionInfo, + Val, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -17867,8 +20728,8 @@ impl<'de> serde::Deserialize<'de> for PlaceholderNode { E: serde::de::Error, { match value { - "id" => Ok(GeneratedField::Id), - "dataType" | "data_type" => Ok(GeneratedField::DataType), + "precisionInfo" | "precision_info" => Ok(GeneratedField::PrecisionInfo), + "val" => Ok(GeneratedField::Val), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -17878,44 +20739,118 @@ impl<'de> serde::Deserialize<'de> for PlaceholderNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PlaceholderNode; + type Value = Precision; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PlaceholderNode") + formatter.write_str("struct datafusion.Precision") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut id__ = None; - let mut data_type__ = None; + let mut precision_info__ = None; + let mut val__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Id => { - if id__.is_some() { - return Err(serde::de::Error::duplicate_field("id")); + GeneratedField::PrecisionInfo => { + if precision_info__.is_some() { + return Err(serde::de::Error::duplicate_field("precisionInfo")); } - id__ = Some(map_.next_value()?); + precision_info__ = Some(map_.next_value::()? as i32); } - GeneratedField::DataType => { - if data_type__.is_some() { - return Err(serde::de::Error::duplicate_field("dataType")); + GeneratedField::Val => { + if val__.is_some() { + return Err(serde::de::Error::duplicate_field("val")); } - data_type__ = map_.next_value()?; + val__ = map_.next_value()?; } } } - Ok(PlaceholderNode { - id: id__.unwrap_or_default(), - data_type: data_type__, + Ok(Precision { + precision_info: precision_info__.unwrap_or_default(), + val: val__, }) } } - deserializer.deserialize_struct("datafusion.PlaceholderNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.Precision", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PlanType { +impl serde::Serialize for PrecisionInfo { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::Exact => "EXACT", + Self::Inexact => "INEXACT", + Self::Absent => "ABSENT", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for PrecisionInfo { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "EXACT", + "INEXACT", + "ABSENT", + ]; + + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = PrecisionInfo; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "EXACT" => Ok(PrecisionInfo::Exact), + "INEXACT" => Ok(PrecisionInfo::Inexact), + "ABSENT" => Ok(PrecisionInfo::Absent), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} +impl serde::Serialize for PrepareNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -17923,68 +20858,46 @@ impl serde::Serialize for PlanType { { use serde::ser::SerializeStruct; let mut len = 0; - if self.plan_type_enum.is_some() { + if !self.name.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PlanType", len)?; - if let Some(v) = self.plan_type_enum.as_ref() { - match v { - plan_type::PlanTypeEnum::InitialLogicalPlan(v) => { - struct_ser.serialize_field("InitialLogicalPlan", v)?; - } - plan_type::PlanTypeEnum::AnalyzedLogicalPlan(v) => { - struct_ser.serialize_field("AnalyzedLogicalPlan", v)?; - } - plan_type::PlanTypeEnum::FinalAnalyzedLogicalPlan(v) => { - struct_ser.serialize_field("FinalAnalyzedLogicalPlan", v)?; - } - plan_type::PlanTypeEnum::OptimizedLogicalPlan(v) => { - struct_ser.serialize_field("OptimizedLogicalPlan", v)?; - } - plan_type::PlanTypeEnum::FinalLogicalPlan(v) => { - struct_ser.serialize_field("FinalLogicalPlan", v)?; - } - plan_type::PlanTypeEnum::InitialPhysicalPlan(v) => { - struct_ser.serialize_field("InitialPhysicalPlan", v)?; - } - plan_type::PlanTypeEnum::OptimizedPhysicalPlan(v) => { - struct_ser.serialize_field("OptimizedPhysicalPlan", v)?; - } - plan_type::PlanTypeEnum::FinalPhysicalPlan(v) => { - struct_ser.serialize_field("FinalPhysicalPlan", v)?; - } - } + if !self.data_types.is_empty() { + len += 1; + } + if self.input.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PrepareNode", len)?; + if !self.name.is_empty() { + struct_ser.serialize_field("name", &self.name)?; + } + if !self.data_types.is_empty() { + struct_ser.serialize_field("dataTypes", &self.data_types)?; + } + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PlanType { +impl<'de> serde::Deserialize<'de> for PrepareNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "InitialLogicalPlan", - "AnalyzedLogicalPlan", - "FinalAnalyzedLogicalPlan", - "OptimizedLogicalPlan", - "FinalLogicalPlan", - "InitialPhysicalPlan", - "OptimizedPhysicalPlan", - "FinalPhysicalPlan", + "name", + "data_types", + "dataTypes", + "input", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - InitialLogicalPlan, - AnalyzedLogicalPlan, - FinalAnalyzedLogicalPlan, - OptimizedLogicalPlan, - FinalLogicalPlan, - InitialPhysicalPlan, - OptimizedPhysicalPlan, - FinalPhysicalPlan, + Name, + DataTypes, + Input, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -18006,14 +20919,9 @@ impl<'de> serde::Deserialize<'de> for PlanType { E: serde::de::Error, { match value { - "InitialLogicalPlan" => Ok(GeneratedField::InitialLogicalPlan), - "AnalyzedLogicalPlan" => Ok(GeneratedField::AnalyzedLogicalPlan), - "FinalAnalyzedLogicalPlan" => Ok(GeneratedField::FinalAnalyzedLogicalPlan), - "OptimizedLogicalPlan" => Ok(GeneratedField::OptimizedLogicalPlan), - "FinalLogicalPlan" => Ok(GeneratedField::FinalLogicalPlan), - "InitialPhysicalPlan" => Ok(GeneratedField::InitialPhysicalPlan), - "OptimizedPhysicalPlan" => Ok(GeneratedField::OptimizedPhysicalPlan), - "FinalPhysicalPlan" => Ok(GeneratedField::FinalPhysicalPlan), + "name" => Ok(GeneratedField::Name), + "dataTypes" | "data_types" => Ok(GeneratedField::DataTypes), + "input" => Ok(GeneratedField::Input), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -18023,86 +20931,52 @@ impl<'de> serde::Deserialize<'de> for PlanType { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PlanType; + type Value = PrepareNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PlanType") + formatter.write_str("struct datafusion.PrepareNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut plan_type_enum__ = None; + let mut name__ = None; + let mut data_types__ = None; + let mut input__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::InitialLogicalPlan => { - if plan_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("InitialLogicalPlan")); - } - plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::InitialLogicalPlan) -; - } - GeneratedField::AnalyzedLogicalPlan => { - if plan_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("AnalyzedLogicalPlan")); - } - plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::AnalyzedLogicalPlan) -; - } - GeneratedField::FinalAnalyzedLogicalPlan => { - if plan_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("FinalAnalyzedLogicalPlan")); - } - plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::FinalAnalyzedLogicalPlan) -; - } - GeneratedField::OptimizedLogicalPlan => { - if plan_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("OptimizedLogicalPlan")); - } - plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::OptimizedLogicalPlan) -; - } - GeneratedField::FinalLogicalPlan => { - if plan_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("FinalLogicalPlan")); - } - plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::FinalLogicalPlan) -; - } - GeneratedField::InitialPhysicalPlan => { - if plan_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("InitialPhysicalPlan")); + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); } - plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::InitialPhysicalPlan) -; + name__ = Some(map_.next_value()?); } - GeneratedField::OptimizedPhysicalPlan => { - if plan_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("OptimizedPhysicalPlan")); + GeneratedField::DataTypes => { + if data_types__.is_some() { + return Err(serde::de::Error::duplicate_field("dataTypes")); } - plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::OptimizedPhysicalPlan) -; + data_types__ = Some(map_.next_value()?); } - GeneratedField::FinalPhysicalPlan => { - if plan_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("FinalPhysicalPlan")); + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); } - plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::FinalPhysicalPlan) -; + input__ = map_.next_value()?; } } } - Ok(PlanType { - plan_type_enum: plan_type_enum__, + Ok(PrepareNode { + name: name__.unwrap_or_default(), + data_types: data_types__.unwrap_or_default(), + input: input__, }) } } - deserializer.deserialize_struct("datafusion.PlanType", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PrepareNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PrepareNode { +impl serde::Serialize for PrimaryKeyConstraint { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -18110,46 +20984,29 @@ impl serde::Serialize for PrepareNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.name.is_empty() { + if !self.indices.is_empty() { len += 1; } - if !self.data_types.is_empty() { - len += 1; - } - if self.input.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.PrepareNode", len)?; - if !self.name.is_empty() { - struct_ser.serialize_field("name", &self.name)?; - } - if !self.data_types.is_empty() { - struct_ser.serialize_field("dataTypes", &self.data_types)?; - } - if let Some(v) = self.input.as_ref() { - struct_ser.serialize_field("input", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PrimaryKeyConstraint", len)?; + if !self.indices.is_empty() { + struct_ser.serialize_field("indices", &self.indices.iter().map(ToString::to_string).collect::>())?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PrepareNode { +impl<'de> serde::Deserialize<'de> for PrimaryKeyConstraint { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "name", - "data_types", - "dataTypes", - "input", + "indices", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Name, - DataTypes, - Input, + Indices, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -18171,9 +21028,7 @@ impl<'de> serde::Deserialize<'de> for PrepareNode { E: serde::de::Error, { match value { - "name" => Ok(GeneratedField::Name), - "dataTypes" | "data_types" => Ok(GeneratedField::DataTypes), - "input" => Ok(GeneratedField::Input), + "indices" => Ok(GeneratedField::Indices), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -18183,49 +21038,36 @@ impl<'de> serde::Deserialize<'de> for PrepareNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PrepareNode; + type Value = PrimaryKeyConstraint; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PrepareNode") + formatter.write_str("struct datafusion.PrimaryKeyConstraint") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut name__ = None; - let mut data_types__ = None; - let mut input__ = None; + let mut indices__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Name => { - if name__.is_some() { - return Err(serde::de::Error::duplicate_field("name")); - } - name__ = Some(map_.next_value()?); - } - GeneratedField::DataTypes => { - if data_types__.is_some() { - return Err(serde::de::Error::duplicate_field("dataTypes")); - } - data_types__ = Some(map_.next_value()?); - } - GeneratedField::Input => { - if input__.is_some() { - return Err(serde::de::Error::duplicate_field("input")); + GeneratedField::Indices => { + if indices__.is_some() { + return Err(serde::de::Error::duplicate_field("indices")); } - input__ = map_.next_value()?; + indices__ = + Some(map_.next_value::>>()? + .into_iter().map(|x| x.0).collect()) + ; } } } - Ok(PrepareNode { - name: name__.unwrap_or_default(), - data_types: data_types__.unwrap_or_default(), - input: input__, + Ok(PrimaryKeyConstraint { + indices: indices__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.PrepareNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PrimaryKeyConstraint", FIELDS, GeneratedVisitor) } } impl serde::Serialize for ProjectionColumns { @@ -18841,7 +21683,206 @@ impl<'de> serde::Deserialize<'de> for RepartitionNode { deserializer.deserialize_struct("datafusion.RepartitionNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for RollupNode { +impl serde::Serialize for RollupNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.expr.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.RollupNode", len)?; + if !self.expr.is_empty() { + struct_ser.serialize_field("expr", &self.expr)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for RollupNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "expr", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Expr, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "expr" => Ok(GeneratedField::Expr), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = RollupNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.RollupNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut expr__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); + } + expr__ = Some(map_.next_value()?); + } + } + } + Ok(RollupNode { + expr: expr__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.RollupNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for SqlOption { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.key.is_empty() { + len += 1; + } + if !self.value.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.SQLOption", len)?; + if !self.key.is_empty() { + struct_ser.serialize_field("key", &self.key)?; + } + if !self.value.is_empty() { + struct_ser.serialize_field("value", &self.value)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for SqlOption { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "key", + "value", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Key, + Value, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "key" => Ok(GeneratedField::Key), + "value" => Ok(GeneratedField::Value), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = SqlOption; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.SQLOption") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut key__ = None; + let mut value__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Key => { + if key__.is_some() { + return Err(serde::de::Error::duplicate_field("key")); + } + key__ = Some(map_.next_value()?); + } + GeneratedField::Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("value")); + } + value__ = Some(map_.next_value()?); + } + } + } + Ok(SqlOption { + key: key__.unwrap_or_default(), + value: value__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.SQLOption", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for SqlOptions { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -18849,29 +21890,29 @@ impl serde::Serialize for RollupNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.expr.is_empty() { + if !self.option.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.RollupNode", len)?; - if !self.expr.is_empty() { - struct_ser.serialize_field("expr", &self.expr)?; + let mut struct_ser = serializer.serialize_struct("datafusion.SQLOptions", len)?; + if !self.option.is_empty() { + struct_ser.serialize_field("option", &self.option)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for RollupNode { +impl<'de> serde::Deserialize<'de> for SqlOptions { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", + "option", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, + Option, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -18893,7 +21934,7 @@ impl<'de> serde::Deserialize<'de> for RollupNode { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), + "option" => Ok(GeneratedField::Option), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -18903,33 +21944,33 @@ impl<'de> serde::Deserialize<'de> for RollupNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = RollupNode; + type Value = SqlOptions; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.RollupNode") + formatter.write_str("struct datafusion.SQLOptions") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; + let mut option__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::Option => { + if option__.is_some() { + return Err(serde::de::Error::duplicate_field("option")); } - expr__ = Some(map_.next_value()?); + option__ = Some(map_.next_value()?); } } } - Ok(RollupNode { - expr: expr__.unwrap_or_default(), + Ok(SqlOptions { + option: option__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.RollupNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.SQLOptions", FIELDS, GeneratedVisitor) } } impl serde::Serialize for ScalarDictionaryValue { @@ -19279,6 +22320,18 @@ impl serde::Serialize for ScalarFunction { Self::ArrayEmpty => "ArrayEmpty", Self::ArrayPopBack => "ArrayPopBack", Self::StringToArray => "StringToArray", + Self::ToTimestampNanos => "ToTimestampNanos", + Self::ArrayIntersect => "ArrayIntersect", + Self::ArrayUnion => "ArrayUnion", + Self::OverLay => "OverLay", + Self::Range => "Range", + Self::ArrayExcept => "ArrayExcept", + Self::ArrayPopFront => "ArrayPopFront", + Self::Levenshtein => "Levenshtein", + Self::SubstrIndex => "SubstrIndex", + Self::FindInSet => "FindInSet", + Self::ArraySort => "ArraySort", + Self::ArrayDistinct => "ArrayDistinct", }; serializer.serialize_str(variant) } @@ -19408,6 +22461,18 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayEmpty", "ArrayPopBack", "StringToArray", + "ToTimestampNanos", + "ArrayIntersect", + "ArrayUnion", + "OverLay", + "Range", + "ArrayExcept", + "ArrayPopFront", + "Levenshtein", + "SubstrIndex", + "FindInSet", + "ArraySort", + "ArrayDistinct", ]; struct GeneratedVisitor; @@ -19566,6 +22631,18 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayEmpty" => Ok(ScalarFunction::ArrayEmpty), "ArrayPopBack" => Ok(ScalarFunction::ArrayPopBack), "StringToArray" => Ok(ScalarFunction::StringToArray), + "ToTimestampNanos" => Ok(ScalarFunction::ToTimestampNanos), + "ArrayIntersect" => Ok(ScalarFunction::ArrayIntersect), + "ArrayUnion" => Ok(ScalarFunction::ArrayUnion), + "OverLay" => Ok(ScalarFunction::OverLay), + "Range" => Ok(ScalarFunction::Range), + "ArrayExcept" => Ok(ScalarFunction::ArrayExcept), + "ArrayPopFront" => Ok(ScalarFunction::ArrayPopFront), + "Levenshtein" => Ok(ScalarFunction::Levenshtein), + "SubstrIndex" => Ok(ScalarFunction::SubstrIndex), + "FindInSet" => Ok(ScalarFunction::FindInSet), + "ArraySort" => Ok(ScalarFunction::ArraySort), + "ArrayDistinct" => Ok(ScalarFunction::ArrayDistinct), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } @@ -19691,24 +22768,26 @@ impl serde::Serialize for ScalarListValue { { use serde::ser::SerializeStruct; let mut len = 0; - if self.is_null { + if !self.ipc_message.is_empty() { len += 1; } - if self.field.is_some() { + if !self.arrow_data.is_empty() { len += 1; } - if !self.values.is_empty() { + if self.schema.is_some() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.ScalarListValue", len)?; - if self.is_null { - struct_ser.serialize_field("isNull", &self.is_null)?; + if !self.ipc_message.is_empty() { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("ipcMessage", pbjson::private::base64::encode(&self.ipc_message).as_str())?; } - if let Some(v) = self.field.as_ref() { - struct_ser.serialize_field("field", v)?; + if !self.arrow_data.is_empty() { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("arrowData", pbjson::private::base64::encode(&self.arrow_data).as_str())?; } - if !self.values.is_empty() { - struct_ser.serialize_field("values", &self.values)?; + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; } struct_ser.end() } @@ -19720,17 +22799,18 @@ impl<'de> serde::Deserialize<'de> for ScalarListValue { D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "is_null", - "isNull", - "field", - "values", + "ipc_message", + "ipcMessage", + "arrow_data", + "arrowData", + "schema", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - IsNull, - Field, - Values, + IpcMessage, + ArrowData, + Schema, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -19752,9 +22832,9 @@ impl<'de> serde::Deserialize<'de> for ScalarListValue { E: serde::de::Error, { match value { - "isNull" | "is_null" => Ok(GeneratedField::IsNull), - "field" => Ok(GeneratedField::Field), - "values" => Ok(GeneratedField::Values), + "ipcMessage" | "ipc_message" => Ok(GeneratedField::IpcMessage), + "arrowData" | "arrow_data" => Ok(GeneratedField::ArrowData), + "schema" => Ok(GeneratedField::Schema), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -19774,35 +22854,39 @@ impl<'de> serde::Deserialize<'de> for ScalarListValue { where V: serde::de::MapAccess<'de>, { - let mut is_null__ = None; - let mut field__ = None; - let mut values__ = None; + let mut ipc_message__ = None; + let mut arrow_data__ = None; + let mut schema__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::IsNull => { - if is_null__.is_some() { - return Err(serde::de::Error::duplicate_field("isNull")); + GeneratedField::IpcMessage => { + if ipc_message__.is_some() { + return Err(serde::de::Error::duplicate_field("ipcMessage")); } - is_null__ = Some(map_.next_value()?); + ipc_message__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; } - GeneratedField::Field => { - if field__.is_some() { - return Err(serde::de::Error::duplicate_field("field")); + GeneratedField::ArrowData => { + if arrow_data__.is_some() { + return Err(serde::de::Error::duplicate_field("arrowData")); } - field__ = map_.next_value()?; + arrow_data__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; } - GeneratedField::Values => { - if values__.is_some() { - return Err(serde::de::Error::duplicate_field("values")); + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); } - values__ = Some(map_.next_value()?); + schema__ = map_.next_value()?; } } } Ok(ScalarListValue { - is_null: is_null__.unwrap_or_default(), - field: field__, - values: values__.unwrap_or_default(), + ipc_message: ipc_message__.unwrap_or_default(), + arrow_data: arrow_data__.unwrap_or_default(), + schema: schema__, }) } } @@ -20358,9 +23442,15 @@ impl serde::Serialize for ScalarValue { scalar_value::Value::Time32Value(v) => { struct_ser.serialize_field("time32Value", v)?; } + scalar_value::Value::LargeListValue(v) => { + struct_ser.serialize_field("largeListValue", v)?; + } scalar_value::Value::ListValue(v) => { struct_ser.serialize_field("listValue", v)?; } + scalar_value::Value::FixedSizeListValue(v) => { + struct_ser.serialize_field("fixedSizeListValue", v)?; + } scalar_value::Value::Decimal128Value(v) => { struct_ser.serialize_field("decimal128Value", v)?; } @@ -20464,8 +23554,12 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "date32Value", "time32_value", "time32Value", + "large_list_value", + "largeListValue", "list_value", "listValue", + "fixed_size_list_value", + "fixedSizeListValue", "decimal128_value", "decimal128Value", "decimal256_value", @@ -20520,7 +23614,9 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { Float64Value, Date32Value, Time32Value, + LargeListValue, ListValue, + FixedSizeListValue, Decimal128Value, Decimal256Value, Date64Value, @@ -20575,7 +23671,9 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "float64Value" | "float64_value" => Ok(GeneratedField::Float64Value), "date32Value" | "date_32_value" => Ok(GeneratedField::Date32Value), "time32Value" | "time32_value" => Ok(GeneratedField::Time32Value), + "largeListValue" | "large_list_value" => Ok(GeneratedField::LargeListValue), "listValue" | "list_value" => Ok(GeneratedField::ListValue), + "fixedSizeListValue" | "fixed_size_list_value" => Ok(GeneratedField::FixedSizeListValue), "decimal128Value" | "decimal128_value" => Ok(GeneratedField::Decimal128Value), "decimal256Value" | "decimal256_value" => Ok(GeneratedField::Decimal256Value), "date64Value" | "date_64_value" => Ok(GeneratedField::Date64Value), @@ -20711,6 +23809,13 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { return Err(serde::de::Error::duplicate_field("time32Value")); } value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::Time32Value) +; + } + GeneratedField::LargeListValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("largeListValue")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::LargeListValue) ; } GeneratedField::ListValue => { @@ -20718,6 +23823,13 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { return Err(serde::de::Error::duplicate_field("listValue")); } value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::ListValue) +; + } + GeneratedField::FixedSizeListValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("fixedSizeListValue")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::FixedSizeListValue) ; } GeneratedField::Decimal128Value => { @@ -21921,33 +25033,25 @@ impl serde::Serialize for Statistics { { use serde::ser::SerializeStruct; let mut len = 0; - if self.num_rows != 0 { + if self.num_rows.is_some() { len += 1; } - if self.total_byte_size != 0 { + if self.total_byte_size.is_some() { len += 1; } if !self.column_stats.is_empty() { len += 1; } - if self.is_exact { - len += 1; - } let mut struct_ser = serializer.serialize_struct("datafusion.Statistics", len)?; - if self.num_rows != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("numRows", ToString::to_string(&self.num_rows).as_str())?; + if let Some(v) = self.num_rows.as_ref() { + struct_ser.serialize_field("numRows", v)?; } - if self.total_byte_size != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("totalByteSize", ToString::to_string(&self.total_byte_size).as_str())?; + if let Some(v) = self.total_byte_size.as_ref() { + struct_ser.serialize_field("totalByteSize", v)?; } if !self.column_stats.is_empty() { struct_ser.serialize_field("columnStats", &self.column_stats)?; } - if self.is_exact { - struct_ser.serialize_field("isExact", &self.is_exact)?; - } struct_ser.end() } } @@ -21964,8 +25068,6 @@ impl<'de> serde::Deserialize<'de> for Statistics { "totalByteSize", "column_stats", "columnStats", - "is_exact", - "isExact", ]; #[allow(clippy::enum_variant_names)] @@ -21973,7 +25075,6 @@ impl<'de> serde::Deserialize<'de> for Statistics { NumRows, TotalByteSize, ColumnStats, - IsExact, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -21998,7 +25099,6 @@ impl<'de> serde::Deserialize<'de> for Statistics { "numRows" | "num_rows" => Ok(GeneratedField::NumRows), "totalByteSize" | "total_byte_size" => Ok(GeneratedField::TotalByteSize), "columnStats" | "column_stats" => Ok(GeneratedField::ColumnStats), - "isExact" | "is_exact" => Ok(GeneratedField::IsExact), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -22021,24 +25121,19 @@ impl<'de> serde::Deserialize<'de> for Statistics { let mut num_rows__ = None; let mut total_byte_size__ = None; let mut column_stats__ = None; - let mut is_exact__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::NumRows => { if num_rows__.is_some() { return Err(serde::de::Error::duplicate_field("numRows")); } - num_rows__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + num_rows__ = map_.next_value()?; } GeneratedField::TotalByteSize => { if total_byte_size__.is_some() { return Err(serde::de::Error::duplicate_field("totalByteSize")); } - total_byte_size__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + total_byte_size__ = map_.next_value()?; } GeneratedField::ColumnStats => { if column_stats__.is_some() { @@ -22046,25 +25141,89 @@ impl<'de> serde::Deserialize<'de> for Statistics { } column_stats__ = Some(map_.next_value()?); } - GeneratedField::IsExact => { - if is_exact__.is_some() { - return Err(serde::de::Error::duplicate_field("isExact")); - } - is_exact__ = Some(map_.next_value()?); - } } } Ok(Statistics { - num_rows: num_rows__.unwrap_or_default(), - total_byte_size: total_byte_size__.unwrap_or_default(), + num_rows: num_rows__, + total_byte_size: total_byte_size__, column_stats: column_stats__.unwrap_or_default(), - is_exact: is_exact__.unwrap_or_default(), }) } } deserializer.deserialize_struct("datafusion.Statistics", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for StreamPartitionMode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::SinglePartition => "SINGLE_PARTITION", + Self::PartitionedExec => "PARTITIONED_EXEC", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for StreamPartitionMode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "SINGLE_PARTITION", + "PARTITIONED_EXEC", + ]; + + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = StreamPartitionMode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "SINGLE_PARTITION" => Ok(StreamPartitionMode::SinglePartition), + "PARTITIONED_EXEC" => Ok(StreamPartitionMode::PartitionedExec), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} impl serde::Serialize for StringifiedPlan { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -22344,38 +25503,146 @@ impl<'de> serde::Deserialize<'de> for StructValue { formatter.write_str("struct datafusion.StructValue") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut field_values__ = None; + let mut fields__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::FieldValues => { + if field_values__.is_some() { + return Err(serde::de::Error::duplicate_field("fieldValues")); + } + field_values__ = Some(map_.next_value()?); + } + GeneratedField::Fields => { + if fields__.is_some() { + return Err(serde::de::Error::duplicate_field("fields")); + } + fields__ = Some(map_.next_value()?); + } + } + } + Ok(StructValue { + field_values: field_values__.unwrap_or_default(), + fields: fields__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.StructValue", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for SubqueryAliasNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.input.is_some() { + len += 1; + } + if self.alias.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.SubqueryAliasNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if let Some(v) = self.alias.as_ref() { + struct_ser.serialize_field("alias", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for SubqueryAliasNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "input", + "alias", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Input, + Alias, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "input" => Ok(GeneratedField::Input), + "alias" => Ok(GeneratedField::Alias), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = SubqueryAliasNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.SubqueryAliasNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut field_values__ = None; - let mut fields__ = None; + let mut input__ = None; + let mut alias__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::FieldValues => { - if field_values__.is_some() { - return Err(serde::de::Error::duplicate_field("fieldValues")); + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); } - field_values__ = Some(map_.next_value()?); + input__ = map_.next_value()?; } - GeneratedField::Fields => { - if fields__.is_some() { - return Err(serde::de::Error::duplicate_field("fields")); + GeneratedField::Alias => { + if alias__.is_some() { + return Err(serde::de::Error::duplicate_field("alias")); } - fields__ = Some(map_.next_value()?); + alias__ = map_.next_value()?; } } } - Ok(StructValue { - field_values: field_values__.unwrap_or_default(), - fields: fields__.unwrap_or_default(), + Ok(SubqueryAliasNode { + input: input__, + alias: alias__, }) } } - deserializer.deserialize_struct("datafusion.StructValue", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.SubqueryAliasNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for SubqueryAliasNode { +impl serde::Serialize for SymmetricHashJoinExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -22383,37 +25650,84 @@ impl serde::Serialize for SubqueryAliasNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.input.is_some() { + if self.left.is_some() { len += 1; } - if self.alias.is_some() { + if self.right.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.SubqueryAliasNode", len)?; - if let Some(v) = self.input.as_ref() { - struct_ser.serialize_field("input", v)?; + if !self.on.is_empty() { + len += 1; } - if let Some(v) = self.alias.as_ref() { - struct_ser.serialize_field("alias", v)?; + if self.join_type != 0 { + len += 1; + } + if self.partition_mode != 0 { + len += 1; + } + if self.null_equals_null { + len += 1; + } + if self.filter.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.SymmetricHashJoinExecNode", len)?; + if let Some(v) = self.left.as_ref() { + struct_ser.serialize_field("left", v)?; + } + if let Some(v) = self.right.as_ref() { + struct_ser.serialize_field("right", v)?; + } + if !self.on.is_empty() { + struct_ser.serialize_field("on", &self.on)?; + } + if self.join_type != 0 { + let v = JoinType::try_from(self.join_type) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.join_type)))?; + struct_ser.serialize_field("joinType", &v)?; + } + if self.partition_mode != 0 { + let v = StreamPartitionMode::try_from(self.partition_mode) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.partition_mode)))?; + struct_ser.serialize_field("partitionMode", &v)?; + } + if self.null_equals_null { + struct_ser.serialize_field("nullEqualsNull", &self.null_equals_null)?; + } + if let Some(v) = self.filter.as_ref() { + struct_ser.serialize_field("filter", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for SubqueryAliasNode { +impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "input", - "alias", + "left", + "right", + "on", + "join_type", + "joinType", + "partition_mode", + "partitionMode", + "null_equals_null", + "nullEqualsNull", + "filter", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Input, - Alias, + Left, + Right, + On, + JoinType, + PartitionMode, + NullEqualsNull, + Filter, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -22435,8 +25749,13 @@ impl<'de> serde::Deserialize<'de> for SubqueryAliasNode { E: serde::de::Error, { match value { - "input" => Ok(GeneratedField::Input), - "alias" => Ok(GeneratedField::Alias), + "left" => Ok(GeneratedField::Left), + "right" => Ok(GeneratedField::Right), + "on" => Ok(GeneratedField::On), + "joinType" | "join_type" => Ok(GeneratedField::JoinType), + "partitionMode" | "partition_mode" => Ok(GeneratedField::PartitionMode), + "nullEqualsNull" | "null_equals_null" => Ok(GeneratedField::NullEqualsNull), + "filter" => Ok(GeneratedField::Filter), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -22446,41 +25765,81 @@ impl<'de> serde::Deserialize<'de> for SubqueryAliasNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = SubqueryAliasNode; + type Value = SymmetricHashJoinExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.SubqueryAliasNode") + formatter.write_str("struct datafusion.SymmetricHashJoinExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut input__ = None; - let mut alias__ = None; + let mut left__ = None; + let mut right__ = None; + let mut on__ = None; + let mut join_type__ = None; + let mut partition_mode__ = None; + let mut null_equals_null__ = None; + let mut filter__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Input => { - if input__.is_some() { - return Err(serde::de::Error::duplicate_field("input")); + GeneratedField::Left => { + if left__.is_some() { + return Err(serde::de::Error::duplicate_field("left")); } - input__ = map_.next_value()?; + left__ = map_.next_value()?; } - GeneratedField::Alias => { - if alias__.is_some() { - return Err(serde::de::Error::duplicate_field("alias")); + GeneratedField::Right => { + if right__.is_some() { + return Err(serde::de::Error::duplicate_field("right")); } - alias__ = map_.next_value()?; + right__ = map_.next_value()?; + } + GeneratedField::On => { + if on__.is_some() { + return Err(serde::de::Error::duplicate_field("on")); + } + on__ = Some(map_.next_value()?); + } + GeneratedField::JoinType => { + if join_type__.is_some() { + return Err(serde::de::Error::duplicate_field("joinType")); + } + join_type__ = Some(map_.next_value::()? as i32); + } + GeneratedField::PartitionMode => { + if partition_mode__.is_some() { + return Err(serde::de::Error::duplicate_field("partitionMode")); + } + partition_mode__ = Some(map_.next_value::()? as i32); + } + GeneratedField::NullEqualsNull => { + if null_equals_null__.is_some() { + return Err(serde::de::Error::duplicate_field("nullEqualsNull")); + } + null_equals_null__ = Some(map_.next_value()?); + } + GeneratedField::Filter => { + if filter__.is_some() { + return Err(serde::de::Error::duplicate_field("filter")); + } + filter__ = map_.next_value()?; } } } - Ok(SubqueryAliasNode { - input: input__, - alias: alias__, + Ok(SymmetricHashJoinExecNode { + left: left__, + right: right__, + on: on__.unwrap_or_default(), + join_type: join_type__.unwrap_or_default(), + partition_mode: partition_mode__.unwrap_or_default(), + null_equals_null: null_equals_null__.unwrap_or_default(), + filter: filter__, }) } } - deserializer.deserialize_struct("datafusion.SubqueryAliasNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.SymmetricHashJoinExecNode", FIELDS, GeneratedVisitor) } } impl serde::Serialize for TimeUnit { @@ -23065,17 +26424,108 @@ impl<'de> serde::Deserialize<'de> for UnionMode { where E: serde::de::Error, { - match value { - "sparse" => Ok(UnionMode::Sparse), - "dense" => Ok(UnionMode::Dense), - _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + match value { + "sparse" => Ok(UnionMode::Sparse), + "dense" => Ok(UnionMode::Dense), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} +impl serde::Serialize for UnionNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.inputs.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.UnionNode", len)?; + if !self.inputs.is_empty() { + struct_ser.serialize_field("inputs", &self.inputs)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for UnionNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "inputs", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Inputs, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "inputs" => Ok(GeneratedField::Inputs), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = UnionNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.UnionNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut inputs__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Inputs => { + if inputs__.is_some() { + return Err(serde::de::Error::duplicate_field("inputs")); + } + inputs__ = Some(map_.next_value()?); + } + } } + Ok(UnionNode { + inputs: inputs__.unwrap_or_default(), + }) } } - deserializer.deserialize_any(GeneratedVisitor) + deserializer.deserialize_struct("datafusion.UnionNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for UnionNode { +impl serde::Serialize for UniqueConstraint { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -23083,29 +26533,29 @@ impl serde::Serialize for UnionNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.inputs.is_empty() { + if !self.indices.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.UnionNode", len)?; - if !self.inputs.is_empty() { - struct_ser.serialize_field("inputs", &self.inputs)?; + let mut struct_ser = serializer.serialize_struct("datafusion.UniqueConstraint", len)?; + if !self.indices.is_empty() { + struct_ser.serialize_field("indices", &self.indices.iter().map(ToString::to_string).collect::>())?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for UnionNode { +impl<'de> serde::Deserialize<'de> for UniqueConstraint { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "inputs", + "indices", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Inputs, + Indices, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -23127,7 +26577,7 @@ impl<'de> serde::Deserialize<'de> for UnionNode { E: serde::de::Error, { match value { - "inputs" => Ok(GeneratedField::Inputs), + "indices" => Ok(GeneratedField::Indices), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -23137,33 +26587,36 @@ impl<'de> serde::Deserialize<'de> for UnionNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = UnionNode; + type Value = UniqueConstraint; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.UnionNode") + formatter.write_str("struct datafusion.UniqueConstraint") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut inputs__ = None; + let mut indices__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Inputs => { - if inputs__.is_some() { - return Err(serde::de::Error::duplicate_field("inputs")); + GeneratedField::Indices => { + if indices__.is_some() { + return Err(serde::de::Error::duplicate_field("indices")); } - inputs__ = Some(map_.next_value()?); + indices__ = + Some(map_.next_value::>>()? + .into_iter().map(|x| x.0).collect()) + ; } } } - Ok(UnionNode { - inputs: inputs__.unwrap_or_default(), + Ok(UniqueConstraint { + indices: indices__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.UnionNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.UniqueConstraint", FIELDS, GeneratedVisitor) } } impl serde::Serialize for ValuesNode { @@ -23549,6 +27002,97 @@ impl<'de> serde::Deserialize<'de> for WhenThen { deserializer.deserialize_struct("datafusion.WhenThen", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for Wildcard { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.qualifier.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.Wildcard", len)?; + if !self.qualifier.is_empty() { + struct_ser.serialize_field("qualifier", &self.qualifier)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for Wildcard { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "qualifier", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Qualifier, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "qualifier" => Ok(GeneratedField::Qualifier), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = Wildcard; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.Wildcard") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut qualifier__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Qualifier => { + if qualifier__.is_some() { + return Err(serde::de::Error::duplicate_field("qualifier")); + } + qualifier__ = Some(map_.next_value()?); + } + } + } + Ok(Wildcard { + qualifier: qualifier__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.Wildcard", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for WindowAggExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -23563,13 +27107,10 @@ impl serde::Serialize for WindowAggExecNode { if !self.window_expr.is_empty() { len += 1; } - if self.input_schema.is_some() { - len += 1; - } if !self.partition_keys.is_empty() { len += 1; } - if self.partition_search_mode.is_some() { + if self.input_order_mode.is_some() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.WindowAggExecNode", len)?; @@ -23579,21 +27120,18 @@ impl serde::Serialize for WindowAggExecNode { if !self.window_expr.is_empty() { struct_ser.serialize_field("windowExpr", &self.window_expr)?; } - if let Some(v) = self.input_schema.as_ref() { - struct_ser.serialize_field("inputSchema", v)?; - } if !self.partition_keys.is_empty() { struct_ser.serialize_field("partitionKeys", &self.partition_keys)?; } - if let Some(v) = self.partition_search_mode.as_ref() { + if let Some(v) = self.input_order_mode.as_ref() { match v { - window_agg_exec_node::PartitionSearchMode::Linear(v) => { + window_agg_exec_node::InputOrderMode::Linear(v) => { struct_ser.serialize_field("linear", v)?; } - window_agg_exec_node::PartitionSearchMode::PartiallySorted(v) => { + window_agg_exec_node::InputOrderMode::PartiallySorted(v) => { struct_ser.serialize_field("partiallySorted", v)?; } - window_agg_exec_node::PartitionSearchMode::Sorted(v) => { + window_agg_exec_node::InputOrderMode::Sorted(v) => { struct_ser.serialize_field("sorted", v)?; } } @@ -23611,8 +27149,6 @@ impl<'de> serde::Deserialize<'de> for WindowAggExecNode { "input", "window_expr", "windowExpr", - "input_schema", - "inputSchema", "partition_keys", "partitionKeys", "linear", @@ -23625,7 +27161,6 @@ impl<'de> serde::Deserialize<'de> for WindowAggExecNode { enum GeneratedField { Input, WindowExpr, - InputSchema, PartitionKeys, Linear, PartiallySorted, @@ -23653,7 +27188,6 @@ impl<'de> serde::Deserialize<'de> for WindowAggExecNode { match value { "input" => Ok(GeneratedField::Input), "windowExpr" | "window_expr" => Ok(GeneratedField::WindowExpr), - "inputSchema" | "input_schema" => Ok(GeneratedField::InputSchema), "partitionKeys" | "partition_keys" => Ok(GeneratedField::PartitionKeys), "linear" => Ok(GeneratedField::Linear), "partiallySorted" | "partially_sorted" => Ok(GeneratedField::PartiallySorted), @@ -23679,9 +27213,8 @@ impl<'de> serde::Deserialize<'de> for WindowAggExecNode { { let mut input__ = None; let mut window_expr__ = None; - let mut input_schema__ = None; let mut partition_keys__ = None; - let mut partition_search_mode__ = None; + let mut input_order_mode__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Input => { @@ -23696,12 +27229,6 @@ impl<'de> serde::Deserialize<'de> for WindowAggExecNode { } window_expr__ = Some(map_.next_value()?); } - GeneratedField::InputSchema => { - if input_schema__.is_some() { - return Err(serde::de::Error::duplicate_field("inputSchema")); - } - input_schema__ = map_.next_value()?; - } GeneratedField::PartitionKeys => { if partition_keys__.is_some() { return Err(serde::de::Error::duplicate_field("partitionKeys")); @@ -23709,24 +27236,24 @@ impl<'de> serde::Deserialize<'de> for WindowAggExecNode { partition_keys__ = Some(map_.next_value()?); } GeneratedField::Linear => { - if partition_search_mode__.is_some() { + if input_order_mode__.is_some() { return Err(serde::de::Error::duplicate_field("linear")); } - partition_search_mode__ = map_.next_value::<::std::option::Option<_>>()?.map(window_agg_exec_node::PartitionSearchMode::Linear) + input_order_mode__ = map_.next_value::<::std::option::Option<_>>()?.map(window_agg_exec_node::InputOrderMode::Linear) ; } GeneratedField::PartiallySorted => { - if partition_search_mode__.is_some() { + if input_order_mode__.is_some() { return Err(serde::de::Error::duplicate_field("partiallySorted")); } - partition_search_mode__ = map_.next_value::<::std::option::Option<_>>()?.map(window_agg_exec_node::PartitionSearchMode::PartiallySorted) + input_order_mode__ = map_.next_value::<::std::option::Option<_>>()?.map(window_agg_exec_node::InputOrderMode::PartiallySorted) ; } GeneratedField::Sorted => { - if partition_search_mode__.is_some() { + if input_order_mode__.is_some() { return Err(serde::de::Error::duplicate_field("sorted")); } - partition_search_mode__ = map_.next_value::<::std::option::Option<_>>()?.map(window_agg_exec_node::PartitionSearchMode::Sorted) + input_order_mode__ = map_.next_value::<::std::option::Option<_>>()?.map(window_agg_exec_node::InputOrderMode::Sorted) ; } } @@ -23734,9 +27261,8 @@ impl<'de> serde::Deserialize<'de> for WindowAggExecNode { Ok(WindowAggExecNode { input: input__, window_expr: window_expr__.unwrap_or_default(), - input_schema: input_schema__, partition_keys: partition_keys__.unwrap_or_default(), - partition_search_mode: partition_search_mode__, + input_order_mode: input_order_mode__, }) } } @@ -24454,3 +27980,218 @@ impl<'de> serde::Deserialize<'de> for WindowNode { deserializer.deserialize_struct("datafusion.WindowNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for WriterProperties { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.data_page_size_limit != 0 { + len += 1; + } + if self.dictionary_page_size_limit != 0 { + len += 1; + } + if self.data_page_row_count_limit != 0 { + len += 1; + } + if self.write_batch_size != 0 { + len += 1; + } + if self.max_row_group_size != 0 { + len += 1; + } + if !self.writer_version.is_empty() { + len += 1; + } + if !self.created_by.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.WriterProperties", len)?; + if self.data_page_size_limit != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("dataPageSizeLimit", ToString::to_string(&self.data_page_size_limit).as_str())?; + } + if self.dictionary_page_size_limit != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("dictionaryPageSizeLimit", ToString::to_string(&self.dictionary_page_size_limit).as_str())?; + } + if self.data_page_row_count_limit != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("dataPageRowCountLimit", ToString::to_string(&self.data_page_row_count_limit).as_str())?; + } + if self.write_batch_size != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("writeBatchSize", ToString::to_string(&self.write_batch_size).as_str())?; + } + if self.max_row_group_size != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("maxRowGroupSize", ToString::to_string(&self.max_row_group_size).as_str())?; + } + if !self.writer_version.is_empty() { + struct_ser.serialize_field("writerVersion", &self.writer_version)?; + } + if !self.created_by.is_empty() { + struct_ser.serialize_field("createdBy", &self.created_by)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for WriterProperties { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "data_page_size_limit", + "dataPageSizeLimit", + "dictionary_page_size_limit", + "dictionaryPageSizeLimit", + "data_page_row_count_limit", + "dataPageRowCountLimit", + "write_batch_size", + "writeBatchSize", + "max_row_group_size", + "maxRowGroupSize", + "writer_version", + "writerVersion", + "created_by", + "createdBy", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + DataPageSizeLimit, + DictionaryPageSizeLimit, + DataPageRowCountLimit, + WriteBatchSize, + MaxRowGroupSize, + WriterVersion, + CreatedBy, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "dataPageSizeLimit" | "data_page_size_limit" => Ok(GeneratedField::DataPageSizeLimit), + "dictionaryPageSizeLimit" | "dictionary_page_size_limit" => Ok(GeneratedField::DictionaryPageSizeLimit), + "dataPageRowCountLimit" | "data_page_row_count_limit" => Ok(GeneratedField::DataPageRowCountLimit), + "writeBatchSize" | "write_batch_size" => Ok(GeneratedField::WriteBatchSize), + "maxRowGroupSize" | "max_row_group_size" => Ok(GeneratedField::MaxRowGroupSize), + "writerVersion" | "writer_version" => Ok(GeneratedField::WriterVersion), + "createdBy" | "created_by" => Ok(GeneratedField::CreatedBy), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = WriterProperties; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.WriterProperties") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut data_page_size_limit__ = None; + let mut dictionary_page_size_limit__ = None; + let mut data_page_row_count_limit__ = None; + let mut write_batch_size__ = None; + let mut max_row_group_size__ = None; + let mut writer_version__ = None; + let mut created_by__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::DataPageSizeLimit => { + if data_page_size_limit__.is_some() { + return Err(serde::de::Error::duplicate_field("dataPageSizeLimit")); + } + data_page_size_limit__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::DictionaryPageSizeLimit => { + if dictionary_page_size_limit__.is_some() { + return Err(serde::de::Error::duplicate_field("dictionaryPageSizeLimit")); + } + dictionary_page_size_limit__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::DataPageRowCountLimit => { + if data_page_row_count_limit__.is_some() { + return Err(serde::de::Error::duplicate_field("dataPageRowCountLimit")); + } + data_page_row_count_limit__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::WriteBatchSize => { + if write_batch_size__.is_some() { + return Err(serde::de::Error::duplicate_field("writeBatchSize")); + } + write_batch_size__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::MaxRowGroupSize => { + if max_row_group_size__.is_some() { + return Err(serde::de::Error::duplicate_field("maxRowGroupSize")); + } + max_row_group_size__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::WriterVersion => { + if writer_version__.is_some() { + return Err(serde::de::Error::duplicate_field("writerVersion")); + } + writer_version__ = Some(map_.next_value()?); + } + GeneratedField::CreatedBy => { + if created_by__.is_some() { + return Err(serde::de::Error::duplicate_field("createdBy")); + } + created_by__ = Some(map_.next_value()?); + } + } + } + Ok(WriterProperties { + data_page_size_limit: data_page_size_limit__.unwrap_or_default(), + dictionary_page_size_limit: dictionary_page_size_limit__.unwrap_or_default(), + data_page_row_count_limit: data_page_row_count_limit__.unwrap_or_default(), + write_batch_size: write_batch_size__.unwrap_or_default(), + max_row_group_size: max_row_group_size__.unwrap_or_default(), + writer_version: writer_version__.unwrap_or_default(), + created_by: created_by__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.WriterProperties", FIELDS, GeneratedVisitor) + } +} diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 3382fa17fe58..4ee0b70325ca 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -38,7 +38,7 @@ pub struct DfSchema { pub struct LogicalPlanNode { #[prost( oneof = "logical_plan_node::LogicalPlanType", - tags = "1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27" + tags = "1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29" )] pub logical_plan_type: ::core::option::Option, } @@ -99,6 +99,10 @@ pub mod logical_plan_node { Prepare(::prost::alloc::boxed::Box), #[prost(message, tag = "27")] DropView(super::DropViewNode), + #[prost(message, tag = "28")] + DistinctOn(::prost::alloc::boxed::Box), + #[prost(message, tag = "29")] + CopyTo(::prost::alloc::boxed::Box), } } #[allow(clippy::derive_partial_eq_without_eq)] @@ -291,6 +295,41 @@ pub struct EmptyRelationNode { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct PrimaryKeyConstraint { + #[prost(uint64, repeated, tag = "1")] + pub indices: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct UniqueConstraint { + #[prost(uint64, repeated, tag = "1")] + pub indices: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Constraint { + #[prost(oneof = "constraint::ConstraintMode", tags = "1, 2")] + pub constraint_mode: ::core::option::Option, +} +/// Nested message and enum types in `Constraint`. +pub mod constraint { + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum ConstraintMode { + #[prost(message, tag = "1")] + PrimaryKey(super::PrimaryKeyConstraint), + #[prost(message, tag = "2")] + Unique(super::UniqueConstraint), + } +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Constraints { + #[prost(message, repeated, tag = "1")] + pub constraints: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct CreateExternalTableNode { #[prost(message, optional, tag = "12")] pub name: ::core::option::Option, @@ -321,6 +360,13 @@ pub struct CreateExternalTableNode { ::prost::alloc::string::String, ::prost::alloc::string::String, >, + #[prost(message, optional, tag = "15")] + pub constraints: ::core::option::Option, + #[prost(map = "string, message", tag = "16")] + pub column_defaults: ::std::collections::HashMap< + ::prost::alloc::string::String, + LogicalExprNode, + >, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -446,6 +492,57 @@ pub struct DistinctNode { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct DistinctOnNode { + #[prost(message, repeated, tag = "1")] + pub on_expr: ::prost::alloc::vec::Vec, + #[prost(message, repeated, tag = "2")] + pub select_expr: ::prost::alloc::vec::Vec, + #[prost(message, repeated, tag = "3")] + pub sort_expr: ::prost::alloc::vec::Vec, + #[prost(message, optional, boxed, tag = "4")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CopyToNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(string, tag = "2")] + pub output_url: ::prost::alloc::string::String, + #[prost(bool, tag = "3")] + pub single_file_output: bool, + #[prost(string, tag = "6")] + pub file_type: ::prost::alloc::string::String, + #[prost(oneof = "copy_to_node::CopyOptions", tags = "4, 5")] + pub copy_options: ::core::option::Option, +} +/// Nested message and enum types in `CopyToNode`. +pub mod copy_to_node { + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum CopyOptions { + #[prost(message, tag = "4")] + SqlOptions(super::SqlOptions), + #[prost(message, tag = "5")] + WriterOptions(super::FileTypeWriterOptions), + } +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct SqlOptions { + #[prost(message, repeated, tag = "1")] + pub option: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct SqlOption { + #[prost(string, tag = "1")] + pub key: ::prost::alloc::string::String, + #[prost(string, tag = "2")] + pub value: ::prost::alloc::string::String, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct UnionNode { #[prost(message, repeated, tag = "1")] pub inputs: ::prost::alloc::vec::Vec, @@ -532,8 +629,8 @@ pub mod logical_expr_node { Negative(::prost::alloc::boxed::Box), #[prost(message, tag = "14")] InList(::prost::alloc::boxed::Box), - #[prost(bool, tag = "15")] - Wildcard(bool), + #[prost(message, tag = "15")] + Wildcard(super::Wildcard), #[prost(message, tag = "16")] ScalarFunction(super::ScalarFunctionNode), #[prost(message, tag = "17")] @@ -579,6 +676,12 @@ pub mod logical_expr_node { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct Wildcard { + #[prost(string, tag = "1")] + pub qualifier: ::prost::alloc::string::String, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct PlaceholderNode { #[prost(string, tag = "1")] pub id: ::prost::alloc::string::String, @@ -711,6 +814,8 @@ pub struct AliasNode { pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(string, tag = "2")] pub alias: ::prost::alloc::string::String, + #[prost(message, repeated, tag = "3")] + pub relation: ::prost::alloc::vec::Vec, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -963,6 +1068,10 @@ pub struct Field { ::prost::alloc::string::String, ::prost::alloc::string::String, >, + #[prost(int64, tag = "6")] + pub dict_id: i64, + #[prost(bool, tag = "7")] + pub dict_ordered: bool, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -1035,14 +1144,12 @@ pub struct Union { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ScalarListValue { - /// encode null explicitly to distinguish a list with a null value - /// from a list with no values) - #[prost(bool, tag = "3")] - pub is_null: bool, - #[prost(message, optional, tag = "1")] - pub field: ::core::option::Option, - #[prost(message, repeated, tag = "2")] - pub values: ::prost::alloc::vec::Vec, + #[prost(bytes = "vec", tag = "1")] + pub ipc_message: ::prost::alloc::vec::Vec, + #[prost(bytes = "vec", tag = "2")] + pub arrow_data: ::prost::alloc::vec::Vec, + #[prost(message, optional, tag = "3")] + pub schema: ::core::option::Option, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -1143,7 +1250,7 @@ pub struct ScalarFixedSizeBinary { pub struct ScalarValue { #[prost( oneof = "scalar_value::Value", - tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17, 20, 39, 21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 32, 34" + tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 20, 39, 21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 32, 34" )] pub value: ::core::option::Option, } @@ -1187,9 +1294,12 @@ pub mod scalar_value { Date32Value(i32), #[prost(message, tag = "15")] Time32Value(super::ScalarTime32Value), - /// WAS: ScalarType null_list_value = 18; + #[prost(message, tag = "16")] + LargeListValue(super::ScalarListValue), #[prost(message, tag = "17")] ListValue(super::ScalarListValue), + #[prost(message, tag = "18")] + FixedSizeListValue(super::ScalarListValue), #[prost(message, tag = "20")] Decimal128Value(super::Decimal128), #[prost(message, tag = "39")] @@ -1367,7 +1477,7 @@ pub struct OptimizedPhysicalPlanType { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PlanType { - #[prost(oneof = "plan_type::PlanTypeEnum", tags = "1, 7, 8, 2, 3, 4, 5, 6")] + #[prost(oneof = "plan_type::PlanTypeEnum", tags = "1, 7, 8, 2, 3, 4, 9, 5, 6, 10")] pub plan_type_enum: ::core::option::Option, } /// Nested message and enum types in `PlanType`. @@ -1387,10 +1497,14 @@ pub mod plan_type { FinalLogicalPlan(super::EmptyMessage), #[prost(message, tag = "4")] InitialPhysicalPlan(super::EmptyMessage), + #[prost(message, tag = "9")] + InitialPhysicalPlanWithStats(super::EmptyMessage), #[prost(message, tag = "5")] OptimizedPhysicalPlan(super::OptimizedPhysicalPlanType), #[prost(message, tag = "6")] FinalPhysicalPlan(super::EmptyMessage), + #[prost(message, tag = "10")] + FinalPhysicalPlanWithStats(super::EmptyMessage), } } #[allow(clippy::derive_partial_eq_without_eq)] @@ -1452,7 +1566,7 @@ pub mod owned_table_reference { pub struct PhysicalPlanNode { #[prost( oneof = "physical_plan_node::PhysicalPlanType", - tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29" )] pub physical_plan_type: ::core::option::Option, } @@ -1507,10 +1621,181 @@ pub mod physical_plan_node { NestedLoopJoin(::prost::alloc::boxed::Box), #[prost(message, tag = "23")] Analyze(::prost::alloc::boxed::Box), + #[prost(message, tag = "24")] + JsonSink(::prost::alloc::boxed::Box), + #[prost(message, tag = "25")] + SymmetricHashJoin(::prost::alloc::boxed::Box), + #[prost(message, tag = "26")] + Interleave(super::InterleaveExecNode), + #[prost(message, tag = "27")] + PlaceholderRow(super::PlaceholderRowExecNode), + #[prost(message, tag = "28")] + CsvSink(::prost::alloc::boxed::Box), + #[prost(message, tag = "29")] + ParquetSink(::prost::alloc::boxed::Box), } } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct PartitionColumn { + #[prost(string, tag = "1")] + pub name: ::prost::alloc::string::String, + #[prost(message, optional, tag = "2")] + pub arrow_type: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct FileTypeWriterOptions { + #[prost(oneof = "file_type_writer_options::FileType", tags = "1, 2, 3")] + pub file_type: ::core::option::Option, +} +/// Nested message and enum types in `FileTypeWriterOptions`. +pub mod file_type_writer_options { + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum FileType { + #[prost(message, tag = "1")] + JsonOptions(super::JsonWriterOptions), + #[prost(message, tag = "2")] + ParquetOptions(super::ParquetWriterOptions), + #[prost(message, tag = "3")] + CsvOptions(super::CsvWriterOptions), + } +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct JsonWriterOptions { + #[prost(enumeration = "CompressionTypeVariant", tag = "1")] + pub compression: i32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ParquetWriterOptions { + #[prost(message, optional, tag = "1")] + pub writer_properties: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CsvWriterOptions { + /// Compression type + #[prost(enumeration = "CompressionTypeVariant", tag = "1")] + pub compression: i32, + /// Optional column delimiter. Defaults to `b','` + #[prost(string, tag = "2")] + pub delimiter: ::prost::alloc::string::String, + /// Whether to write column names as file headers. Defaults to `true` + #[prost(bool, tag = "3")] + pub has_header: bool, + /// Optional date format for date arrays + #[prost(string, tag = "4")] + pub date_format: ::prost::alloc::string::String, + /// Optional datetime format for datetime arrays + #[prost(string, tag = "5")] + pub datetime_format: ::prost::alloc::string::String, + /// Optional timestamp format for timestamp arrays + #[prost(string, tag = "6")] + pub timestamp_format: ::prost::alloc::string::String, + /// Optional time format for time arrays + #[prost(string, tag = "7")] + pub time_format: ::prost::alloc::string::String, + /// Optional value to represent null + #[prost(string, tag = "8")] + pub null_value: ::prost::alloc::string::String, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct WriterProperties { + #[prost(uint64, tag = "1")] + pub data_page_size_limit: u64, + #[prost(uint64, tag = "2")] + pub dictionary_page_size_limit: u64, + #[prost(uint64, tag = "3")] + pub data_page_row_count_limit: u64, + #[prost(uint64, tag = "4")] + pub write_batch_size: u64, + #[prost(uint64, tag = "5")] + pub max_row_group_size: u64, + #[prost(string, tag = "6")] + pub writer_version: ::prost::alloc::string::String, + #[prost(string, tag = "7")] + pub created_by: ::prost::alloc::string::String, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct FileSinkConfig { + #[prost(string, tag = "1")] + pub object_store_url: ::prost::alloc::string::String, + #[prost(message, repeated, tag = "2")] + pub file_groups: ::prost::alloc::vec::Vec, + #[prost(string, repeated, tag = "3")] + pub table_paths: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, + #[prost(message, optional, tag = "4")] + pub output_schema: ::core::option::Option, + #[prost(message, repeated, tag = "5")] + pub table_partition_cols: ::prost::alloc::vec::Vec, + #[prost(bool, tag = "7")] + pub single_file_output: bool, + #[prost(bool, tag = "8")] + pub overwrite: bool, + #[prost(message, optional, tag = "9")] + pub file_type_writer_options: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct JsonSink { + #[prost(message, optional, tag = "1")] + pub config: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct JsonSinkExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, tag = "2")] + pub sink: ::core::option::Option, + #[prost(message, optional, tag = "3")] + pub sink_schema: ::core::option::Option, + #[prost(message, optional, tag = "4")] + pub sort_order: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CsvSink { + #[prost(message, optional, tag = "1")] + pub config: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CsvSinkExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, tag = "2")] + pub sink: ::core::option::Option, + #[prost(message, optional, tag = "3")] + pub sink_schema: ::core::option::Option, + #[prost(message, optional, tag = "4")] + pub sort_order: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ParquetSink { + #[prost(message, optional, tag = "1")] + pub config: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ParquetSinkExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, tag = "2")] + pub sink: ::core::option::Option, + #[prost(message, optional, tag = "3")] + pub sink_schema: ::core::option::Option, + #[prost(message, optional, tag = "4")] + pub sort_order: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalExtensionNode { #[prost(bytes = "vec", tag = "1")] pub node: ::prost::alloc::vec::Vec, @@ -1779,6 +2064,8 @@ pub struct FilterExecNode { pub input: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, optional, tag = "2")] pub expr: ::core::option::Option, + #[prost(uint32, tag = "3")] + pub default_filter_selectivity: u32, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -1876,6 +2163,30 @@ pub struct HashJoinExecNode { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct SymmetricHashJoinExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub left: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, boxed, tag = "2")] + pub right: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, repeated, tag = "3")] + pub on: ::prost::alloc::vec::Vec, + #[prost(enumeration = "JoinType", tag = "4")] + pub join_type: i32, + #[prost(enumeration = "StreamPartitionMode", tag = "6")] + pub partition_mode: i32, + #[prost(bool, tag = "7")] + pub null_equals_null: bool, + #[prost(message, optional, tag = "8")] + pub filter: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct InterleaveExecNode { + #[prost(message, repeated, tag = "1")] + pub inputs: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct UnionExecNode { #[prost(message, repeated, tag = "1")] pub inputs: ::prost::alloc::vec::Vec, @@ -1929,9 +2240,13 @@ pub struct JoinOn { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct EmptyExecNode { - #[prost(bool, tag = "1")] - pub produce_one_row: bool, - #[prost(message, optional, tag = "2")] + #[prost(message, optional, tag = "1")] + pub schema: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PlaceholderRowExecNode { + #[prost(message, optional, tag = "1")] pub schema: ::core::option::Option, } #[allow(clippy::derive_partial_eq_without_eq)] @@ -1946,7 +2261,7 @@ pub struct ProjectionExecNode { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] -pub struct PartiallySortedPartitionSearchMode { +pub struct PartiallySortedInputOrderMode { #[prost(uint64, repeated, tag = "6")] pub columns: ::prost::alloc::vec::Vec, } @@ -1957,26 +2272,22 @@ pub struct WindowAggExecNode { pub input: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, repeated, tag = "2")] pub window_expr: ::prost::alloc::vec::Vec, - #[prost(message, optional, tag = "4")] - pub input_schema: ::core::option::Option, #[prost(message, repeated, tag = "5")] pub partition_keys: ::prost::alloc::vec::Vec, /// Set optional to `None` for `BoundedWindowAggExec`. - #[prost(oneof = "window_agg_exec_node::PartitionSearchMode", tags = "7, 8, 9")] - pub partition_search_mode: ::core::option::Option< - window_agg_exec_node::PartitionSearchMode, - >, + #[prost(oneof = "window_agg_exec_node::InputOrderMode", tags = "7, 8, 9")] + pub input_order_mode: ::core::option::Option, } /// Nested message and enum types in `WindowAggExecNode`. pub mod window_agg_exec_node { /// Set optional to `None` for `BoundedWindowAggExec`. #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum PartitionSearchMode { + pub enum InputOrderMode { #[prost(message, tag = "7")] Linear(super::EmptyMessage), #[prost(message, tag = "8")] - PartiallySorted(super::PartiallySortedPartitionSearchMode), + PartiallySorted(super::PartiallySortedInputOrderMode), #[prost(message, tag = "9")] Sorted(super::EmptyMessage), } @@ -2017,8 +2328,6 @@ pub struct AggregateExecNode { pub groups: ::prost::alloc::vec::Vec, #[prost(message, repeated, tag = "10")] pub filter_expr: ::prost::alloc::vec::Vec, - #[prost(message, repeated, tag = "11")] - pub order_by_expr: ::prost::alloc::vec::Vec, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -2173,27 +2482,33 @@ pub struct PartitionStats { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct Precision { + #[prost(enumeration = "PrecisionInfo", tag = "1")] + pub precision_info: i32, + #[prost(message, optional, tag = "2")] + pub val: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct Statistics { - #[prost(int64, tag = "1")] - pub num_rows: i64, - #[prost(int64, tag = "2")] - pub total_byte_size: i64, + #[prost(message, optional, tag = "1")] + pub num_rows: ::core::option::Option, + #[prost(message, optional, tag = "2")] + pub total_byte_size: ::core::option::Option, #[prost(message, repeated, tag = "3")] pub column_stats: ::prost::alloc::vec::Vec, - #[prost(bool, tag = "4")] - pub is_exact: bool, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ColumnStats { #[prost(message, optional, tag = "1")] - pub min_value: ::core::option::Option, + pub min_value: ::core::option::Option, #[prost(message, optional, tag = "2")] - pub max_value: ::core::option::Option, - #[prost(uint32, tag = "3")] - pub null_count: u32, - #[prost(uint32, tag = "4")] - pub distinct_count: u32, + pub max_value: ::core::option::Option, + #[prost(message, optional, tag = "3")] + pub null_count: ::core::option::Option, + #[prost(message, optional, tag = "4")] + pub distinct_count: ::core::option::Option, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -2427,6 +2742,18 @@ pub enum ScalarFunction { ArrayEmpty = 115, ArrayPopBack = 116, StringToArray = 117, + ToTimestampNanos = 118, + ArrayIntersect = 119, + ArrayUnion = 120, + OverLay = 121, + Range = 122, + ArrayExcept = 123, + ArrayPopFront = 124, + Levenshtein = 125, + SubstrIndex = 126, + FindInSet = 127, + ArraySort = 128, + ArrayDistinct = 129, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2553,6 +2880,18 @@ impl ScalarFunction { ScalarFunction::ArrayEmpty => "ArrayEmpty", ScalarFunction::ArrayPopBack => "ArrayPopBack", ScalarFunction::StringToArray => "StringToArray", + ScalarFunction::ToTimestampNanos => "ToTimestampNanos", + ScalarFunction::ArrayIntersect => "ArrayIntersect", + ScalarFunction::ArrayUnion => "ArrayUnion", + ScalarFunction::OverLay => "OverLay", + ScalarFunction::Range => "Range", + ScalarFunction::ArrayExcept => "ArrayExcept", + ScalarFunction::ArrayPopFront => "ArrayPopFront", + ScalarFunction::Levenshtein => "Levenshtein", + ScalarFunction::SubstrIndex => "SubstrIndex", + ScalarFunction::FindInSet => "FindInSet", + ScalarFunction::ArraySort => "ArraySort", + ScalarFunction::ArrayDistinct => "ArrayDistinct", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2676,6 +3015,18 @@ impl ScalarFunction { "ArrayEmpty" => Some(Self::ArrayEmpty), "ArrayPopBack" => Some(Self::ArrayPopBack), "StringToArray" => Some(Self::StringToArray), + "ToTimestampNanos" => Some(Self::ToTimestampNanos), + "ArrayIntersect" => Some(Self::ArrayIntersect), + "ArrayUnion" => Some(Self::ArrayUnion), + "OverLay" => Some(Self::OverLay), + "Range" => Some(Self::Range), + "ArrayExcept" => Some(Self::ArrayExcept), + "ArrayPopFront" => Some(Self::ArrayPopFront), + "Levenshtein" => Some(Self::Levenshtein), + "SubstrIndex" => Some(Self::SubstrIndex), + "FindInSet" => Some(Self::FindInSet), + "ArraySort" => Some(Self::ArraySort), + "ArrayDistinct" => Some(Self::ArrayDistinct), _ => None, } } @@ -2720,6 +3071,7 @@ pub enum AggregateFunction { RegrSxx = 32, RegrSyy = 33, RegrSxy = 34, + StringAgg = 35, } impl AggregateFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2765,6 +3117,7 @@ impl AggregateFunction { AggregateFunction::RegrSxx => "REGR_SXX", AggregateFunction::RegrSyy => "REGR_SYY", AggregateFunction::RegrSxy => "REGR_SXY", + AggregateFunction::StringAgg => "STRING_AGG", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2807,6 +3160,7 @@ impl AggregateFunction { "REGR_SXX" => Some(Self::RegrSxx), "REGR_SYY" => Some(Self::RegrSyy), "REGR_SXY" => Some(Self::RegrSxy), + "STRING_AGG" => Some(Self::StringAgg), _ => None, } } @@ -3037,6 +3391,41 @@ impl UnionMode { } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] +pub enum CompressionTypeVariant { + Gzip = 0, + Bzip2 = 1, + Xz = 2, + Zstd = 3, + Uncompressed = 4, +} +impl CompressionTypeVariant { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + CompressionTypeVariant::Gzip => "GZIP", + CompressionTypeVariant::Bzip2 => "BZIP2", + CompressionTypeVariant::Xz => "XZ", + CompressionTypeVariant::Zstd => "ZSTD", + CompressionTypeVariant::Uncompressed => "UNCOMPRESSED", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "GZIP" => Some(Self::Gzip), + "BZIP2" => Some(Self::Bzip2), + "XZ" => Some(Self::Xz), + "ZSTD" => Some(Self::Zstd), + "UNCOMPRESSED" => Some(Self::Uncompressed), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] pub enum PartitionMode { CollectLeft = 0, Partitioned = 1, @@ -3066,6 +3455,32 @@ impl PartitionMode { } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] +pub enum StreamPartitionMode { + SinglePartition = 0, + PartitionedExec = 1, +} +impl StreamPartitionMode { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + StreamPartitionMode::SinglePartition => "SINGLE_PARTITION", + StreamPartitionMode::PartitionedExec => "PARTITIONED_EXEC", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "SINGLE_PARTITION" => Some(Self::SinglePartition), + "PARTITIONED_EXEC" => Some(Self::PartitionedExec), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] pub enum AggregateMode { Partial = 0, Final = 1, @@ -3125,3 +3540,32 @@ impl JoinSide { } } } +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum PrecisionInfo { + Exact = 0, + Inexact = 1, + Absent = 2, +} +impl PrecisionInfo { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + PrecisionInfo::Exact => "EXACT", + PrecisionInfo::Inexact => "INEXACT", + PrecisionInfo::Absent => "ABSENT", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "EXACT" => Some(Self::Exact), + "INEXACT" => Some(Self::Inexact), + "ABSENT" => Some(Self::Absent), + _ => None, + } + } +} diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index f8746ef4fd6c..36c5b44f00b9 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -19,39 +19,48 @@ use crate::protobuf::{ self, plan_type::PlanTypeEnum::{ AnalyzedLogicalPlan, FinalAnalyzedLogicalPlan, FinalLogicalPlan, - FinalPhysicalPlan, InitialLogicalPlan, InitialPhysicalPlan, OptimizedLogicalPlan, + FinalPhysicalPlan, FinalPhysicalPlanWithStats, InitialLogicalPlan, + InitialPhysicalPlan, InitialPhysicalPlanWithStats, OptimizedLogicalPlan, OptimizedPhysicalPlan, }, AnalyzedLogicalPlanType, CubeNode, GroupingSetNode, OptimizedLogicalPlanType, OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, }; -use arrow::datatypes::{ - i256, DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, TimeUnit, - UnionFields, UnionMode, +use arrow::{ + buffer::Buffer, + datatypes::{ + i256, DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, TimeUnit, + UnionFields, UnionMode, + }, + ipc::{reader::read_record_batch, root_as_message}, }; use datafusion::execution::registry::FunctionRegistry; use datafusion_common::{ - internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, - OwnedTableReference, Result, ScalarValue, + arrow_datafusion_err, internal_err, plan_datafusion_err, Column, Constraint, + Constraints, DFField, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, + Result, ScalarValue, }; +use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ - abs, acos, acosh, array, array_append, array_concat, array_dims, array_element, - array_has, array_has_all, array_has_any, array_length, array_ndims, array_position, - array_positions, array_prepend, array_remove, array_remove_all, array_remove_n, - array_repeat, array_replace, array_replace_all, array_replace_n, array_slice, - array_to_string, ascii, asin, asinh, atan, atan2, atanh, bit_length, btrim, - cardinality, cbrt, ceil, character_length, chr, coalesce, concat_expr, + abs, acos, acosh, array, array_append, array_concat, array_dims, array_distinct, + array_element, array_except, array_has, array_has_all, array_has_any, + array_intersect, array_length, array_ndims, array_position, array_positions, + array_prepend, array_remove, array_remove_all, array_remove_n, array_repeat, + array_replace, array_replace_all, array_replace_n, array_slice, array_sort, + array_to_string, arrow_typeof, ascii, asin, asinh, atan, atan2, atanh, bit_length, + btrim, cardinality, cbrt, ceil, character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, current_date, current_time, date_bin, date_part, - date_trunc, degrees, digest, exp, + date_trunc, decode, degrees, digest, encode, exp, expr::{self, InList, Sort, WindowFunction}, - factorial, floor, from_unixtime, gcd, isnan, iszero, lcm, left, ln, log, log10, log2, + factorial, find_in_set, flatten, floor, from_unixtime, gcd, gen_range, isnan, iszero, + lcm, left, levenshtein, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, - lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, pi, power, radians, - random, regexp_match, regexp_replace, repeat, replace, reverse, right, round, rpad, - rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, sqrt, - starts_with, strpos, substr, substring, tan, tanh, to_hex, to_timestamp_micros, - to_timestamp_millis, to_timestamp_seconds, translate, trim, trunc, upper, uuid, - window_frame::regularize, + lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, overlay, pi, power, + radians, random, regexp_match, regexp_replace, repeat, replace, reverse, right, + round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, + sqrt, starts_with, string_to_array, strpos, struct_fun, substr, substr_index, + substring, tan, tanh, to_hex, to_timestamp_micros, to_timestamp_millis, + to_timestamp_nanos, to_timestamp_seconds, translate, trim, trunc, upper, uuid, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, GroupingSet::GroupingSets, @@ -59,7 +68,7 @@ use datafusion_expr::{ WindowFrameUnits, }; use datafusion_expr::{ - array_empty, array_pop_back, + array_empty, array_pop_back, array_pop_front, expr::{Alias, Placeholder}, }; use std::sync::Arc; @@ -369,8 +378,20 @@ impl TryFrom<&protobuf::Field> for Field { type Error = Error; fn try_from(field: &protobuf::Field) -> Result { let datatype = field.arrow_type.as_deref().required("arrow_type")?; - Ok(Self::new(field.name.as_str(), datatype, field.nullable) - .with_metadata(field.metadata.clone())) + let field = if field.dict_id != 0 { + Self::new_dict( + field.name.as_str(), + datatype, + field.nullable, + field.dict_id, + field.dict_ordered, + ) + .with_metadata(field.metadata.clone()) + } else { + Self::new(field.name.as_str(), datatype, field.nullable) + .with_metadata(field.metadata.clone()) + }; + Ok(field) } } @@ -400,12 +421,14 @@ impl From<&protobuf::StringifiedPlan> for StringifiedPlan { } FinalLogicalPlan(_) => PlanType::FinalLogicalPlan, InitialPhysicalPlan(_) => PlanType::InitialPhysicalPlan, + InitialPhysicalPlanWithStats(_) => PlanType::InitialPhysicalPlanWithStats, OptimizedPhysicalPlan(OptimizedPhysicalPlanType { optimizer_name }) => { PlanType::OptimizedPhysicalPlan { optimizer_name: optimizer_name.clone(), } } FinalPhysicalPlan(_) => PlanType::FinalPhysicalPlan, + FinalPhysicalPlanWithStats(_) => PlanType::FinalPhysicalPlanWithStats, }, plan: Arc::new(stringified_plan.plan.clone()), } @@ -454,16 +477,20 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Rtrim => Self::Rtrim, ScalarFunction::ToTimestamp => Self::ToTimestamp, ScalarFunction::ArrayAppend => Self::ArrayAppend, + ScalarFunction::ArraySort => Self::ArraySort, ScalarFunction::ArrayConcat => Self::ArrayConcat, ScalarFunction::ArrayEmpty => Self::ArrayEmpty, + ScalarFunction::ArrayExcept => Self::ArrayExcept, ScalarFunction::ArrayHasAll => Self::ArrayHasAll, ScalarFunction::ArrayHasAny => Self::ArrayHasAny, ScalarFunction::ArrayHas => Self::ArrayHas, ScalarFunction::ArrayDims => Self::ArrayDims, + ScalarFunction::ArrayDistinct => Self::ArrayDistinct, ScalarFunction::ArrayElement => Self::ArrayElement, ScalarFunction::Flatten => Self::Flatten, ScalarFunction::ArrayLength => Self::ArrayLength, ScalarFunction::ArrayNdims => Self::ArrayNdims, + ScalarFunction::ArrayPopFront => Self::ArrayPopFront, ScalarFunction::ArrayPopBack => Self::ArrayPopBack, ScalarFunction::ArrayPosition => Self::ArrayPosition, ScalarFunction::ArrayPositions => Self::ArrayPositions, @@ -477,6 +504,9 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::ArrayReplaceAll => Self::ArrayReplaceAll, ScalarFunction::ArraySlice => Self::ArraySlice, ScalarFunction::ArrayToString => Self::ArrayToString, + ScalarFunction::ArrayIntersect => Self::ArrayIntersect, + ScalarFunction::ArrayUnion => Self::ArrayUnion, + ScalarFunction::Range => Self::Range, ScalarFunction::Cardinality => Self::Cardinality, ScalarFunction::Array => Self::MakeArray, ScalarFunction::NullIf => Self::NullIf, @@ -517,6 +547,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Substr => Self::Substr, ScalarFunction::ToHex => Self::ToHex, ScalarFunction::ToTimestampMicros => Self::ToTimestampMicros, + ScalarFunction::ToTimestampNanos => Self::ToTimestampNanos, ScalarFunction::ToTimestampSeconds => Self::ToTimestampSeconds, ScalarFunction::Now => Self::Now, ScalarFunction::CurrentDate => Self::CurrentDate, @@ -534,6 +565,10 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Isnan => Self::Isnan, ScalarFunction::Iszero => Self::Iszero, ScalarFunction::ArrowTypeof => Self::ArrowTypeof, + ScalarFunction::OverLay => Self::OverLay, + ScalarFunction::Levenshtein => Self::Levenshtein, + ScalarFunction::SubstrIndex => Self::SubstrIndex, + ScalarFunction::FindInSet => Self::FindInSet, } } } @@ -580,6 +615,7 @@ impl From for AggregateFunction { protobuf::AggregateFunction::Median => Self::Median, protobuf::AggregateFunction::FirstValueAgg => Self::FirstValue, protobuf::AggregateFunction::LastValueAgg => Self::LastValue, + protobuf::AggregateFunction::StringAgg => Self::StringAgg, } } } @@ -641,25 +677,56 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { Value::Float32Value(v) => Self::Float32(Some(*v)), Value::Float64Value(v) => Self::Float64(Some(*v)), Value::Date32Value(v) => Self::Date32(Some(*v)), - Value::ListValue(scalar_list) => { + // ScalarValue::List is serialized using arrow IPC format + Value::ListValue(scalar_list) + | Value::FixedSizeListValue(scalar_list) + | Value::LargeListValue(scalar_list) => { let protobuf::ScalarListValue { - is_null, - values, - field, + ipc_message, + arrow_data, + schema, } = &scalar_list; - let field: Field = field.as_ref().required("field")?; - let field = Arc::new(field); - - let values: Result, Error> = - values.iter().map(|val| val.try_into()).collect(); - let values = values?; + let schema: Schema = if let Some(schema_ref) = schema { + schema_ref.try_into()? + } else { + return Err(Error::General( + "Invalid schema while deserializing ScalarValue::List" + .to_string(), + )); + }; - validate_list_values(field.as_ref(), &values)?; + let message = root_as_message(ipc_message.as_slice()).map_err(|e| { + Error::General(format!( + "Error IPC message while deserializing ScalarValue::List: {e}" + )) + })?; + let buffer = Buffer::from(arrow_data); - let values = if *is_null { None } else { Some(values) }; + let ipc_batch = message.header_as_record_batch().ok_or_else(|| { + Error::General( + "Unexpected message type deserializing ScalarValue::List" + .to_string(), + ) + })?; - Self::List(values, field) + let record_batch = read_record_batch( + &buffer, + ipc_batch, + Arc::new(schema), + &Default::default(), + None, + &message.version(), + ) + .map_err(|e| arrow_datafusion_err!(e)) + .map_err(|e| e.context("Decoding ScalarValue::List Value"))?; + let arr = record_batch.column(0); + match value { + Value::ListValue(_) => Self::List(arr.to_owned()), + Value::LargeListValue(_) => Self::LargeList(arr.to_owned()), + Value::FixedSizeListValue(_) => Self::FixedSizeList(arr.to_owned()), + _ => unreachable!(), + } } Value::NullValue(v) => { let null_type: DataType = v.try_into()?; @@ -880,6 +947,33 @@ impl From for JoinConstraint { } } +impl From for Constraints { + fn from(constraints: protobuf::Constraints) -> Self { + Constraints::new_unverified( + constraints + .constraints + .into_iter() + .map(|item| item.into()) + .collect(), + ) + } +} + +impl From for Constraint { + fn from(value: protobuf::Constraint) -> Self { + match value.constraint_mode.unwrap() { + protobuf::constraint::ConstraintMode::PrimaryKey(elem) => { + Constraint::PrimaryKey( + elem.indices.into_iter().map(|item| item as usize).collect(), + ) + } + protobuf::constraint::ConstraintMode::Unique(elem) => Constraint::Unique( + elem.indices.into_iter().map(|item| item as usize).collect(), + ), + } + } +} + pub fn parse_i32_to_time_unit(value: &i32) -> Result { protobuf::TimeUnit::try_from(*value) .map(|t| t.into()) @@ -898,22 +992,6 @@ pub fn parse_i32_to_aggregate_function(value: &i32) -> Result Result<(), Error> { - for value in values { - let field_type = field.data_type(); - let value_type = value.data_type(); - - if field_type != &value_type { - return Err(proto_error(format!( - "Expected field type {field_type:?}, got scalar of type: {value_type:?}" - ))); - } - } - Ok(()) -} - pub fn parse_expr( proto: &protobuf::LogicalExprNode, registry: &dyn FunctionRegistry, @@ -1008,7 +1086,7 @@ pub fn parse_expr( .iter() .map(|e| parse_expr(e, registry)) .collect::, _>>()?; - let order_by = expr + let mut order_by = expr .order_by .iter() .map(|e| parse_expr(e, registry)) @@ -1018,7 +1096,8 @@ pub fn parse_expr( .as_ref() .map::, _>(|window_frame| { let window_frame = window_frame.clone().try_into()?; - regularize(window_frame, order_by.len()) + check_window_frame(&window_frame, order_by.len()) + .map(|_| window_frame) }) .transpose()? .ok_or_else(|| { @@ -1026,13 +1105,14 @@ pub fn parse_expr( "missing window frame during deserialization".to_string(), ) })?; + regularize_window_order_by(&window_frame, &mut order_by)?; match window_function { window_expr_node::WindowFunction::AggrFunction(i) => { let aggr_function = parse_i32_to_aggregate_function(i)?; Ok(Expr::WindowFunction(WindowFunction::new( - datafusion_expr::window_function::WindowFunction::AggregateFunction( + datafusion_expr::expr::WindowFunctionDefinition::AggregateFunction( aggr_function, ), vec![parse_required_expr(expr.expr.as_deref(), registry, "expr")?], @@ -1051,7 +1131,7 @@ pub fn parse_expr( .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( - datafusion_expr::window_function::WindowFunction::BuiltInWindowFunction( + datafusion_expr::expr::WindowFunctionDefinition::BuiltInWindowFunction( built_in_function, ), args, @@ -1066,7 +1146,7 @@ pub fn parse_expr( .map(|e| vec![e]) .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( - datafusion_expr::window_function::WindowFunction::AggregateUDF( + datafusion_expr::expr::WindowFunctionDefinition::AggregateUDF( udaf_function, ), args, @@ -1081,7 +1161,7 @@ pub fn parse_expr( .map(|e| vec![e]) .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( - datafusion_expr::window_function::WindowFunction::WindowUDF( + datafusion_expr::expr::WindowFunctionDefinition::WindowUDF( udwf_function, ), args, @@ -1108,6 +1188,11 @@ pub fn parse_expr( } ExprType::Alias(alias) => Ok(Expr::Alias(Alias::new( parse_required_expr(alias.expr.as_deref(), registry, "expr")?, + alias + .relation + .first() + .map(|r| OwnedTableReference::try_from(r.clone())) + .transpose()?, alias.alias.clone(), ))), ExprType::IsNullExpr(is_null) => Ok(Expr::IsNull(Box::new(parse_required_expr( @@ -1253,7 +1338,13 @@ pub fn parse_expr( .collect::, _>>()?, in_list.negated, ))), - ExprType::Wildcard(_) => Ok(Expr::Wildcard), + ExprType::Wildcard(protobuf::Wildcard { qualifier }) => Ok(Expr::Wildcard { + qualifier: if qualifier.is_empty() { + None + } else { + Some(qualifier.clone()) + }, + }), ExprType::ScalarFunction(expr) => { let scalar_function = protobuf::ScalarFunction::try_from(expr.fun) .map_err(|_| Error::unknown("ScalarFunction", expr.fun))?; @@ -1274,6 +1365,14 @@ pub fn parse_expr( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, )), + ScalarFunction::ArraySort => Ok(array_sort( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + parse_expr(&args[2], registry)?, + )), + ScalarFunction::ArrayPopFront => { + Ok(array_pop_front(parse_expr(&args[0], registry)?)) + } ScalarFunction::ArrayPopBack => { Ok(array_pop_back(parse_expr(&args[0], registry)?)) } @@ -1287,6 +1386,10 @@ pub fn parse_expr( .map(|expr| parse_expr(expr, registry)) .collect::, _>>()?, )), + ScalarFunction::ArrayExcept => Ok(array_except( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), ScalarFunction::ArrayHasAll => Ok(array_has_all( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, @@ -1299,6 +1402,10 @@ pub fn parse_expr( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, )), + ScalarFunction::ArrayIntersect => Ok(array_intersect( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), ScalarFunction::ArrayPosition => Ok(array_position( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, @@ -1350,6 +1457,12 @@ pub fn parse_expr( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, )), + ScalarFunction::Range => Ok(gen_range( + args.to_owned() + .iter() + .map(|expr| parse_expr(expr, registry)) + .collect::, _>>()?, + )), ScalarFunction::Cardinality => { Ok(cardinality(parse_expr(&args[0], registry)?)) } @@ -1360,6 +1473,9 @@ pub fn parse_expr( ScalarFunction::ArrayDims => { Ok(array_dims(parse_expr(&args[0], registry)?)) } + ScalarFunction::ArrayDistinct => { + Ok(array_distinct(parse_expr(&args[0], registry)?)) + } ScalarFunction::ArrayElement => Ok(array_element( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, @@ -1370,6 +1486,12 @@ pub fn parse_expr( ScalarFunction::ArrayNdims => { Ok(array_ndims(parse_expr(&args[0], registry)?)) } + ScalarFunction::ArrayUnion => Ok(array( + args.to_owned() + .iter() + .map(|expr| parse_expr(expr, registry)) + .collect::, _>>()?, + )), ScalarFunction::Sqrt => Ok(sqrt(parse_expr(&args[0], registry)?)), ScalarFunction::Cbrt => Ok(cbrt(parse_expr(&args[0], registry)?)), ScalarFunction::Sin => Ok(sin(parse_expr(&args[0], registry)?)), @@ -1431,6 +1553,14 @@ pub fn parse_expr( ScalarFunction::Sha384 => Ok(sha384(parse_expr(&args[0], registry)?)), ScalarFunction::Sha512 => Ok(sha512(parse_expr(&args[0], registry)?)), ScalarFunction::Md5 => Ok(md5(parse_expr(&args[0], registry)?)), + ScalarFunction::Encode => Ok(encode( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), + ScalarFunction::Decode => Ok(decode( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), ScalarFunction::NullIf => Ok(nullif( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, @@ -1546,6 +1676,10 @@ pub fn parse_expr( )) } } + ScalarFunction::Levenshtein => Ok(levenshtein( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), ScalarFunction::ToHex => Ok(to_hex(parse_expr(&args[0], registry)?)), ScalarFunction::ToTimestampMillis => { Ok(to_timestamp_millis(parse_expr(&args[0], registry)?)) @@ -1553,6 +1687,9 @@ pub fn parse_expr( ScalarFunction::ToTimestampMicros => { Ok(to_timestamp_micros(parse_expr(&args[0], registry)?)) } + ScalarFunction::ToTimestampNanos => { + Ok(to_timestamp_nanos(parse_expr(&args[0], registry)?)) + } ScalarFunction::ToTimestampSeconds => { Ok(to_timestamp_seconds(parse_expr(&args[0], registry)?)) } @@ -1593,14 +1730,41 @@ pub fn parse_expr( )), ScalarFunction::Isnan => Ok(isnan(parse_expr(&args[0], registry)?)), ScalarFunction::Iszero => Ok(iszero(parse_expr(&args[0], registry)?)), - _ => Err(proto_error( - "Protobuf deserialization error: Unsupported scalar function", + ScalarFunction::ArrowTypeof => { + Ok(arrow_typeof(parse_expr(&args[0], registry)?)) + } + ScalarFunction::ToTimestamp => { + Ok(to_timestamp_seconds(parse_expr(&args[0], registry)?)) + } + ScalarFunction::Flatten => Ok(flatten(parse_expr(&args[0], registry)?)), + ScalarFunction::StringToArray => Ok(string_to_array( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + parse_expr(&args[2], registry)?, + )), + ScalarFunction::OverLay => Ok(overlay( + args.to_owned() + .iter() + .map(|expr| parse_expr(expr, registry)) + .collect::, _>>()?, )), + ScalarFunction::SubstrIndex => Ok(substr_index( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + parse_expr(&args[2], registry)?, + )), + ScalarFunction::FindInSet => Ok(find_in_set( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), + ScalarFunction::StructFun => { + Ok(struct_fun(parse_expr(&args[0], registry)?)) + } } } ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { fun_name, args }) => { let scalar_fn = registry.udf(fun_name.as_str())?; - Ok(Expr::ScalarUDF(expr::ScalarUDF::new( + Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( scalar_fn, args.iter() .map(|expr| parse_expr(expr, registry)) @@ -1610,12 +1774,13 @@ pub fn parse_expr( ExprType::AggregateUdfExpr(pb) => { let agg_fn = registry.udaf(pb.fun_name.as_str())?; - Ok(Expr::AggregateUDF(expr::AggregateUDF::new( + Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( agg_fn, pb.args .iter() .map(|expr| parse_expr(expr, registry)) .collect::, Error>>()?, + false, parse_optional_expr(pb.filter.as_deref(), registry)?.map(Box::new), parse_vec_expr(&pb.order_by, registry)?, ))) @@ -1713,9 +1878,7 @@ fn parse_vec_expr( ) -> Result>, Error> { let res = p .iter() - .map(|elem| { - parse_expr(elem, registry).map_err(|e| DataFusionError::Plan(e.to_string())) - }) + .map(|elem| parse_expr(elem, registry).map_err(|e| plan_datafusion_err!("{}", e))) .collect::>>()?; // Convert empty vector to None. Ok((!res.is_empty()).then_some(res)) diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 099352cb589e..e8a38784481b 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -15,9 +15,18 @@ // specific language governing permissions and limitations // under the License. +use arrow::csv::WriterBuilder; +use std::collections::HashMap; +use std::fmt::Debug; +use std::str::FromStr; +use std::sync::Arc; + use crate::common::{byte_to_string, proto_error, str_to_byte}; use crate::protobuf::logical_plan_node::LogicalPlanType::CustomScan; -use crate::protobuf::{CustomTableScanNode, LogicalExprNodeCollection}; +use crate::protobuf::{ + copy_to_node, file_type_writer_options, CustomTableScanNode, + LogicalExprNodeCollection, SqlOption, +}; use crate::{ convert_required, protobuf::{ @@ -25,12 +34,13 @@ use crate::{ logical_plan_node::LogicalPlanType, LogicalExtensionNode, LogicalPlanNode, }, }; + use arrow::datatypes::{DataType, Schema, SchemaRef}; +#[cfg(feature = "parquet")] +use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::{ datasource::{ - file_format::{ - avro::AvroFormat, csv::CsvFormat, parquet::ParquetFormat, FileFormat, - }, + file_format::{avro::AvroFormat, csv::CsvFormat, FileFormat}, listing::{ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl}, view::ViewTable, TableProvider, @@ -38,40 +48,41 @@ use datafusion::{ datasource::{provider_as_source, source_as_provider}, prelude::SessionContext, }; -use datafusion_common::not_impl_err; use datafusion_common::{ - context, internal_err, parsers::CompressionTypeVariant, DataFusionError, - OwnedTableReference, Result, + context, file_options::StatementOptions, internal_err, not_impl_err, + parsers::CompressionTypeVariant, plan_datafusion_err, DataFusionError, FileType, + FileTypeWriterOptions, OwnedTableReference, Result, }; -use datafusion_expr::logical_plan::DdlStatement; -use datafusion_expr::DropView; use datafusion_expr::{ + dml, logical_plan::{ builder::project, Aggregate, CreateCatalog, CreateCatalogSchema, - CreateExternalTable, CreateView, CrossJoin, Distinct, EmptyRelation, Extension, - Join, JoinConstraint, Limit, Prepare, Projection, Repartition, Sort, - SubqueryAlias, TableScan, Values, Window, + CreateExternalTable, CreateView, CrossJoin, DdlStatement, Distinct, + EmptyRelation, Extension, Join, JoinConstraint, Limit, Prepare, Projection, + Repartition, Sort, SubqueryAlias, TableScan, Values, Window, }, - Expr, LogicalPlan, LogicalPlanBuilder, + DistinctOn, DropView, Expr, LogicalPlan, LogicalPlanBuilder, }; + +use datafusion::parquet::file::properties::{WriterProperties, WriterVersion}; +use datafusion_common::file_options::csv_writer::CsvWriterOptions; +use datafusion_common::file_options::parquet_writer::ParquetWriterOptions; +use datafusion_expr::dml::CopyOptions; use prost::bytes::BufMut; use prost::Message; -use std::fmt::Debug; -use std::str::FromStr; -use std::sync::Arc; pub mod from_proto; pub mod to_proto; impl From for DataFusionError { fn from(e: from_proto::Error) -> Self { - DataFusionError::Plan(e.to_string()) + plan_datafusion_err!("{}", e) } } impl From for DataFusionError { fn from(e: to_proto::Error) -> Self { - DataFusionError::Plan(e.to_string()) + plan_datafusion_err!("{}", e) } } @@ -251,7 +262,7 @@ impl AsLogicalPlan for LogicalPlanNode { Some(a) => match a { protobuf::projection_node::OptionalAlias::Alias(alias) => { Ok(LogicalPlan::SubqueryAlias(SubqueryAlias::try_new( - new_proj, + Arc::new(new_proj), alias.clone(), )?)) } @@ -335,6 +346,7 @@ impl AsLogicalPlan for LogicalPlanNode { "logical_plan::from_proto() Unsupported file format '{self:?}'" )) })? { + #[cfg(feature = "parquet")] &FileFormatType::Parquet(protobuf::ParquetFormat {}) => { Arc::new(ParquetFormat::default()) } @@ -362,7 +374,7 @@ impl AsLogicalPlan for LogicalPlanNode { .collect::, _>>()?; let options = ListingOptions::new(file_format) - .with_file_extension(scan.file_extension.clone()) + .with_file_extension(&scan.file_extension) .with_table_partition_cols( scan.table_partition_cols .iter() @@ -457,7 +469,7 @@ impl AsLogicalPlan for LogicalPlanNode { let input: LogicalPlan = into_logical_plan!(repartition.input, ctx, extension_codec)?; use protobuf::repartition_node::PartitionMethod; - let pb_partition_method = repartition.partition_method.clone().ok_or_else(|| { + let pb_partition_method = repartition.partition_method.as_ref().ok_or_else(|| { DataFusionError::Internal(String::from( "Protobuf deserialization error, RepartitionNode was missing required field 'partition_method'", )) @@ -472,10 +484,10 @@ impl AsLogicalPlan for LogicalPlanNode { .iter() .map(|expr| from_proto::parse_expr(expr, ctx)) .collect::, _>>()?, - partition_count as usize, + *partition_count as usize, ), PartitionMethod::RoundRobin(partition_count) => { - Partitioning::RoundRobinBatch(partition_count as usize) + Partitioning::RoundRobinBatch(*partition_count as usize) } }; @@ -493,6 +505,11 @@ impl AsLogicalPlan for LogicalPlanNode { )) })?; + let constraints = (create_extern_table.constraints.clone()).ok_or_else(|| { + DataFusionError::Internal(String::from( + "Protobuf deserialization error, CreateExternalTableNode was missing required table constraints.", + )) + })?; let definition = if !create_extern_table.definition.is_empty() { Some(create_extern_table.definition.clone()) } else { @@ -514,6 +531,13 @@ impl AsLogicalPlan for LogicalPlanNode { order_exprs.push(order_expr) } + let mut column_defaults = + HashMap::with_capacity(create_extern_table.column_defaults.len()); + for (col_name, expr) in &create_extern_table.column_defaults { + let expr = from_proto::parse_expr(expr, ctx)?; + column_defaults.insert(col_name.clone(), expr); + } + Ok(LogicalPlan::Ddl(DdlStatement::CreateExternalTable(CreateExternalTable { schema: pb_schema.try_into()?, name: from_owned_table_reference(create_extern_table.name.as_ref(), "CreateExternalTable")?, @@ -532,6 +556,8 @@ impl AsLogicalPlan for LogicalPlanNode { definition, unbounded: create_extern_table.unbounded, options: create_extern_table.options.clone(), + constraints: constraints.into(), + column_defaults, }))) } LogicalPlanType::CreateView(create_view) => { @@ -726,6 +752,33 @@ impl AsLogicalPlan for LogicalPlanNode { into_logical_plan!(distinct.input, ctx, extension_codec)?; LogicalPlanBuilder::from(input).distinct()?.build() } + LogicalPlanType::DistinctOn(distinct_on) => { + let input: LogicalPlan = + into_logical_plan!(distinct_on.input, ctx, extension_codec)?; + let on_expr = distinct_on + .on_expr + .iter() + .map(|expr| from_proto::parse_expr(expr, ctx)) + .collect::, _>>()?; + let select_expr = distinct_on + .select_expr + .iter() + .map(|expr| from_proto::parse_expr(expr, ctx)) + .collect::, _>>()?; + let sort_expr = match distinct_on.sort_expr.len() { + 0 => None, + _ => Some( + distinct_on + .sort_expr + .iter() + .map(|expr| from_proto::parse_expr(expr, ctx)) + .collect::, _>>()?, + ), + }; + LogicalPlanBuilder::from(input) + .distinct_on(on_expr, select_expr, sort_expr)? + .build() + } LogicalPlanType::ViewScan(scan) => { let schema: Schema = convert_required!(scan.schema)?; @@ -779,6 +832,79 @@ impl AsLogicalPlan for LogicalPlanNode { schema: Arc::new(convert_required!(dropview.schema)?), }), )), + LogicalPlanType::CopyTo(copy) => { + let input: LogicalPlan = + into_logical_plan!(copy.input, ctx, extension_codec)?; + + let copy_options = match ©.copy_options { + Some(copy_to_node::CopyOptions::SqlOptions(opt)) => { + let options = opt + .option + .iter() + .map(|o| (o.key.clone(), o.value.clone())) + .collect(); + CopyOptions::SQLOptions(StatementOptions::from(&options)) + } + Some(copy_to_node::CopyOptions::WriterOptions(opt)) => { + match &opt.file_type { + Some(ft) => match ft { + file_type_writer_options::FileType::CsvOptions( + writer_options, + ) => { + let writer_builder = + csv_writer_options_from_proto(writer_options)?; + CopyOptions::WriterOptions(Box::new( + FileTypeWriterOptions::CSV( + CsvWriterOptions::new( + writer_builder, + CompressionTypeVariant::UNCOMPRESSED, + ), + ), + )) + } + file_type_writer_options::FileType::ParquetOptions( + writer_options, + ) => { + let writer_properties = + match &writer_options.writer_properties { + Some(serialized_writer_options) => { + writer_properties_from_proto( + serialized_writer_options, + )? + } + _ => WriterProperties::default(), + }; + CopyOptions::WriterOptions(Box::new( + FileTypeWriterOptions::Parquet( + ParquetWriterOptions::new(writer_properties), + ), + )) + } + _ => { + return Err(proto_error( + "WriterOptions unsupported file_type", + )) + } + }, + None => { + return Err(proto_error( + "WriterOptions missing file_type", + )) + } + } + } + None => return Err(proto_error("CopyTo missing CopyOptions")), + }; + Ok(datafusion_expr::LogicalPlan::Copy( + datafusion_expr::dml::CopyTo { + input: Arc::new(input), + output_url: copy.output_url.clone(), + file_format: FileType::from_str(©.file_type)?, + single_file_output: copy.single_file_output, + copy_options, + }, + )) + } } } @@ -842,28 +968,49 @@ impl AsLogicalPlan for LogicalPlanNode { if let Some(listing_table) = source.downcast_ref::() { let any = listing_table.options().format.as_any(); - let file_format_type = if any.is::() { - FileFormatType::Parquet(protobuf::ParquetFormat {}) - } else if let Some(csv) = any.downcast_ref::() { - FileFormatType::Csv(protobuf::CsvFormat { - delimiter: byte_to_string(csv.delimiter(), "delimiter")?, - has_header: csv.has_header(), - quote: byte_to_string(csv.quote(), "quote")?, - optional_escape: if let Some(escape) = csv.escape() { - Some(protobuf::csv_format::OptionalEscape::Escape( - byte_to_string(escape, "escape")?, - )) - } else { - None - }, - }) - } else if any.is::() { - FileFormatType::Avro(protobuf::AvroFormat {}) - } else { - return Err(proto_error(format!( + let file_format_type = { + let mut maybe_some_type = None; + + #[cfg(feature = "parquet")] + if any.is::() { + maybe_some_type = + Some(FileFormatType::Parquet(protobuf::ParquetFormat {})) + }; + + if let Some(csv) = any.downcast_ref::() { + maybe_some_type = + Some(FileFormatType::Csv(protobuf::CsvFormat { + delimiter: byte_to_string( + csv.delimiter(), + "delimiter", + )?, + has_header: csv.has_header(), + quote: byte_to_string(csv.quote(), "quote")?, + optional_escape: if let Some(escape) = csv.escape() { + Some( + protobuf::csv_format::OptionalEscape::Escape( + byte_to_string(escape, "escape")?, + ), + ) + } else { + None + }, + })) + } + + if any.is::() { + maybe_some_type = + Some(FileFormatType::Avro(protobuf::AvroFormat {})) + } + + if let Some(file_format_type) = maybe_some_type { + file_format_type + } else { + return Err(proto_error(format!( "Error converting file format, {:?} is invalid as a datafusion format.", listing_table.options().format ))); + } }; let options = listing_table.options(); @@ -976,7 +1123,7 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Distinct(Distinct { input }) => { + LogicalPlan::Distinct(Distinct::All(input)) => { let input: protobuf::LogicalPlanNode = protobuf::LogicalPlanNode::try_from_logical_plan( input.as_ref(), @@ -990,6 +1137,42 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } + LogicalPlan::Distinct(Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + input, + .. + })) => { + let input: protobuf::LogicalPlanNode = + protobuf::LogicalPlanNode::try_from_logical_plan( + input.as_ref(), + extension_codec, + )?; + let sort_expr = match sort_expr { + None => vec![], + Some(sort_expr) => sort_expr + .iter() + .map(|expr| expr.try_into()) + .collect::, _>>()?, + }; + Ok(protobuf::LogicalPlanNode { + logical_plan_type: Some(LogicalPlanType::DistinctOn(Box::new( + protobuf::DistinctOnNode { + on_expr: on_expr + .iter() + .map(|expr| expr.try_into()) + .collect::, _>>()?, + select_expr: select_expr + .iter() + .map(|expr| expr.try_into()) + .collect::, _>>()?, + sort_expr, + input: Some(Box::new(input)), + }, + ))), + }) + } LogicalPlan::Window(Window { input, window_expr, .. }) => { @@ -1205,6 +1388,8 @@ impl AsLogicalPlan for LogicalPlanNode { order_exprs, unbounded, options, + constraints, + column_defaults, }, )) => { let mut converted_order_exprs: Vec = vec![]; @@ -1219,6 +1404,12 @@ impl AsLogicalPlan for LogicalPlanNode { converted_order_exprs.push(temp); } + let mut converted_column_defaults = + HashMap::with_capacity(column_defaults.len()); + for (col_name, expr) in column_defaults { + converted_column_defaults.insert(col_name.clone(), expr.try_into()?); + } + Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CreateExternalTable( protobuf::CreateExternalTableNode { @@ -1235,6 +1426,8 @@ impl AsLogicalPlan for LogicalPlanNode { file_compression_type: file_compression_type.to_string(), unbounded: *unbounded, options: options.clone(), + constraints: Some(constraints.clone().into()), + column_defaults: converted_column_defaults, }, )), }) @@ -1423,12 +1616,163 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlan::Dml(_) => Err(proto_error( "LogicalPlan serde is not yet implemented for Dml", )), - LogicalPlan::Copy(_) => Err(proto_error( - "LogicalPlan serde is not yet implemented for Copy", - )), + LogicalPlan::Copy(dml::CopyTo { + input, + output_url, + single_file_output, + file_format, + copy_options, + }) => { + let input = protobuf::LogicalPlanNode::try_from_logical_plan( + input, + extension_codec, + )?; + + let copy_options_proto: Option = + match copy_options { + CopyOptions::SQLOptions(opt) => { + let options: Vec = opt + .clone() + .into_inner() + .iter() + .map(|(k, v)| SqlOption { + key: k.to_string(), + value: v.to_string(), + }) + .collect(); + Some(copy_to_node::CopyOptions::SqlOptions( + protobuf::SqlOptions { option: options }, + )) + } + CopyOptions::WriterOptions(opt) => { + match opt.as_ref() { + FileTypeWriterOptions::CSV(csv_opts) => { + let csv_options = &csv_opts.writer_options; + let csv_writer_options = csv_writer_options_to_proto( + csv_options, + &csv_opts.compression, + ); + let csv_options = + file_type_writer_options::FileType::CsvOptions( + csv_writer_options, + ); + Some(copy_to_node::CopyOptions::WriterOptions( + protobuf::FileTypeWriterOptions { + file_type: Some(csv_options), + }, + )) + } + FileTypeWriterOptions::Parquet(parquet_opts) => { + let parquet_writer_options = + protobuf::ParquetWriterOptions { + writer_properties: Some( + writer_properties_to_proto( + &parquet_opts.writer_options, + ), + ), + }; + let parquet_options = file_type_writer_options::FileType::ParquetOptions(parquet_writer_options); + Some(copy_to_node::CopyOptions::WriterOptions( + protobuf::FileTypeWriterOptions { + file_type: Some(parquet_options), + }, + )) + } + _ => { + return Err(proto_error( + "Unsupported FileTypeWriterOptions in CopyTo", + )) + } + } + } + }; + + Ok(protobuf::LogicalPlanNode { + logical_plan_type: Some(LogicalPlanType::CopyTo(Box::new( + protobuf::CopyToNode { + input: Some(Box::new(input)), + single_file_output: *single_file_output, + output_url: output_url.to_string(), + file_type: file_format.to_string(), + copy_options: copy_options_proto, + }, + ))), + }) + } LogicalPlan::DescribeTable(_) => Err(proto_error( "LogicalPlan serde is not yet implemented for DescribeTable", )), } } } + +pub(crate) fn csv_writer_options_to_proto( + csv_options: &WriterBuilder, + compression: &CompressionTypeVariant, +) -> protobuf::CsvWriterOptions { + let compression: protobuf::CompressionTypeVariant = compression.into(); + protobuf::CsvWriterOptions { + compression: compression.into(), + delimiter: (csv_options.delimiter() as char).to_string(), + has_header: csv_options.header(), + date_format: csv_options.date_format().unwrap_or("").to_owned(), + datetime_format: csv_options.datetime_format().unwrap_or("").to_owned(), + timestamp_format: csv_options.timestamp_format().unwrap_or("").to_owned(), + time_format: csv_options.time_format().unwrap_or("").to_owned(), + null_value: csv_options.null().to_owned(), + } +} + +pub(crate) fn csv_writer_options_from_proto( + writer_options: &protobuf::CsvWriterOptions, +) -> Result { + let mut builder = WriterBuilder::new(); + if !writer_options.delimiter.is_empty() { + if let Some(delimiter) = writer_options.delimiter.chars().next() { + if delimiter.is_ascii() { + builder = builder.with_delimiter(delimiter as u8); + } else { + return Err(proto_error("CSV Delimiter is not ASCII")); + } + } else { + return Err(proto_error("Error parsing CSV Delimiter")); + } + } + Ok(builder + .with_header(writer_options.has_header) + .with_date_format(writer_options.date_format.clone()) + .with_datetime_format(writer_options.datetime_format.clone()) + .with_timestamp_format(writer_options.timestamp_format.clone()) + .with_time_format(writer_options.time_format.clone()) + .with_null(writer_options.null_value.clone())) +} + +pub(crate) fn writer_properties_to_proto( + props: &WriterProperties, +) -> protobuf::WriterProperties { + protobuf::WriterProperties { + data_page_size_limit: props.data_page_size_limit() as u64, + dictionary_page_size_limit: props.dictionary_page_size_limit() as u64, + data_page_row_count_limit: props.data_page_row_count_limit() as u64, + write_batch_size: props.write_batch_size() as u64, + max_row_group_size: props.max_row_group_size() as u64, + writer_version: format!("{:?}", props.writer_version()), + created_by: props.created_by().to_string(), + } +} + +pub(crate) fn writer_properties_from_proto( + props: &protobuf::WriterProperties, +) -> Result { + let writer_version = + WriterVersion::from_str(&props.writer_version).map_err(proto_error)?; + Ok(WriterProperties::builder() + .set_created_by(props.created_by.clone()) + .set_writer_version(writer_version) + .set_dictionary_page_size_limit(props.dictionary_page_size_limit as usize) + .set_data_page_row_count_limit(props.data_page_row_count_limit as usize) + .set_data_page_size_limit(props.data_page_size_limit as usize) + .set_write_batch_size(props.write_batch_size as usize) + .set_max_row_group_size(props.max_row_group_size as usize) + .build()) +} diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 8a8550d05d13..a162b2389cd1 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -24,40 +24,40 @@ use crate::protobuf::{ arrow_type::ArrowTypeEnum, plan_type::PlanTypeEnum::{ AnalyzedLogicalPlan, FinalAnalyzedLogicalPlan, FinalLogicalPlan, - FinalPhysicalPlan, InitialLogicalPlan, InitialPhysicalPlan, OptimizedLogicalPlan, + FinalPhysicalPlan, FinalPhysicalPlanWithStats, InitialLogicalPlan, + InitialPhysicalPlan, InitialPhysicalPlanWithStats, OptimizedLogicalPlan, OptimizedPhysicalPlan, }, AnalyzedLogicalPlanType, CubeNode, EmptyMessage, GroupingSetNode, LogicalExprList, OptimizedLogicalPlanType, OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, }; -use arrow::datatypes::{ - DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, TimeUnit, - UnionMode, +use arrow::{ + datatypes::{ + DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, + TimeUnit, UnionMode, + }, + ipc::writer::{DictionaryTracker, IpcDataGenerator}, + record_batch::RecordBatch, }; use datafusion_common::{ - Column, DFField, DFSchema, DFSchemaRef, OwnedTableReference, ScalarValue, + Column, Constraint, Constraints, DFField, DFSchema, DFSchemaRef, OwnedTableReference, + ScalarValue, }; use datafusion_expr::expr::{ - self, Alias, Between, BinaryExpr, Cast, GetFieldAccess, GetIndexedField, GroupingSet, - InList, Like, Placeholder, ScalarFunction, ScalarUDF, Sort, + self, AggregateFunctionDefinition, Alias, Between, BinaryExpr, Cast, GetFieldAccess, + GetIndexedField, GroupingSet, InList, Like, Placeholder, ScalarFunction, + ScalarFunctionDefinition, Sort, }; use datafusion_expr::{ logical_plan::PlanType, logical_plan::StringifiedPlan, AggregateFunction, BuiltInWindowFunction, BuiltinScalarFunction, Expr, JoinConstraint, JoinType, - TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction, + TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; #[derive(Debug)] pub enum Error { General(String), - InconsistentListTyping(DataType, DataType), - - InconsistentListDesignated { - value: ScalarValue, - designated: DataType, - }, - InvalidScalarValue(ScalarValue), InvalidScalarType(DataType), @@ -75,18 +75,6 @@ impl std::fmt::Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { Self::General(desc) => write!(f, "General error: {desc}"), - Self::InconsistentListTyping(type1, type2) => { - write!( - f, - "Lists with inconsistent typing; {type1:?} and {type2:?} found within list", - ) - } - Self::InconsistentListDesignated { value, designated } => { - write!( - f, - "Value {value:?} was inconsistent with designated type {designated:?}" - ) - } Self::InvalidScalarValue(value) => { write!(f, "{value:?} is invalid as a DataFusion scalar value") } @@ -120,6 +108,8 @@ impl TryFrom<&Field> for protobuf::Field { nullable: field.is_nullable(), children: Vec::new(), metadata: field.metadata().clone(), + dict_id: field.dict_id().unwrap_or(0), + dict_ordered: field.dict_is_ordered().unwrap_or(false), }) } } @@ -366,6 +356,12 @@ impl From<&StringifiedPlan> for protobuf::StringifiedPlan { PlanType::FinalPhysicalPlan => Some(protobuf::PlanType { plan_type_enum: Some(FinalPhysicalPlan(EmptyMessage {})), }), + PlanType::InitialPhysicalPlanWithStats => Some(protobuf::PlanType { + plan_type_enum: Some(InitialPhysicalPlanWithStats(EmptyMessage {})), + }), + PlanType::FinalPhysicalPlanWithStats => Some(protobuf::PlanType { + plan_type_enum: Some(FinalPhysicalPlanWithStats(EmptyMessage {})), + }), }, plan: stringified_plan.plan.to_string(), } @@ -412,6 +408,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::Median => Self::Median, AggregateFunction::FirstValue => Self::FirstValueAgg, AggregateFunction::LastValue => Self::LastValueAgg, + AggregateFunction::StringAgg => Self::StringAgg, } } } @@ -490,9 +487,17 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { Expr::Column(c) => Self { expr_type: Some(ExprType::Column(c.into())), }, - Expr::Alias(Alias { expr, name, .. }) => { + Expr::Alias(Alias { + expr, + relation, + name, + }) => { let alias = Box::new(protobuf::AliasNode { expr: Some(Box::new(expr.as_ref().try_into()?)), + relation: relation + .to_owned() + .map(|r| vec![r.into()]) + .unwrap_or(vec![]), alias: name.to_owned(), }); Self { @@ -600,24 +605,24 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { ref window_frame, }) => { let window_function = match fun { - WindowFunction::AggregateFunction(fun) => { + WindowFunctionDefinition::AggregateFunction(fun) => { protobuf::window_expr_node::WindowFunction::AggrFunction( protobuf::AggregateFunction::from(fun).into(), ) } - WindowFunction::BuiltInWindowFunction(fun) => { + WindowFunctionDefinition::BuiltInWindowFunction(fun) => { protobuf::window_expr_node::WindowFunction::BuiltInFunction( protobuf::BuiltInWindowFunction::from(fun).into(), ) } - WindowFunction::AggregateUDF(aggr_udf) => { + WindowFunctionDefinition::AggregateUDF(aggr_udf) => { protobuf::window_expr_node::WindowFunction::Udaf( - aggr_udf.name.clone(), + aggr_udf.name().to_string(), ) } - WindowFunction::WindowUDF(window_udf) => { + WindowFunctionDefinition::WindowUDF(window_udf) => { protobuf::window_expr_node::WindowFunction::Udwf( - window_udf.name.clone(), + window_udf.name().to_string(), ) } }; @@ -650,159 +655,178 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { } } Expr::AggregateFunction(expr::AggregateFunction { - ref fun, + ref func_def, ref args, ref distinct, ref filter, ref order_by, }) => { - let aggr_function = match fun { - AggregateFunction::ApproxDistinct => { - protobuf::AggregateFunction::ApproxDistinct - } - AggregateFunction::ApproxPercentileCont => { - protobuf::AggregateFunction::ApproxPercentileCont - } - AggregateFunction::ApproxPercentileContWithWeight => { - protobuf::AggregateFunction::ApproxPercentileContWithWeight - } - AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, - AggregateFunction::Min => protobuf::AggregateFunction::Min, - AggregateFunction::Max => protobuf::AggregateFunction::Max, - AggregateFunction::Sum => protobuf::AggregateFunction::Sum, - AggregateFunction::BitAnd => protobuf::AggregateFunction::BitAnd, - AggregateFunction::BitOr => protobuf::AggregateFunction::BitOr, - AggregateFunction::BitXor => protobuf::AggregateFunction::BitXor, - AggregateFunction::BoolAnd => protobuf::AggregateFunction::BoolAnd, - AggregateFunction::BoolOr => protobuf::AggregateFunction::BoolOr, - AggregateFunction::Avg => protobuf::AggregateFunction::Avg, - AggregateFunction::Count => protobuf::AggregateFunction::Count, - AggregateFunction::Variance => protobuf::AggregateFunction::Variance, - AggregateFunction::VariancePop => { - protobuf::AggregateFunction::VariancePop - } - AggregateFunction::Covariance => { - protobuf::AggregateFunction::Covariance - } - AggregateFunction::CovariancePop => { - protobuf::AggregateFunction::CovariancePop - } - AggregateFunction::Stddev => protobuf::AggregateFunction::Stddev, - AggregateFunction::StddevPop => { - protobuf::AggregateFunction::StddevPop - } - AggregateFunction::Correlation => { - protobuf::AggregateFunction::Correlation - } - AggregateFunction::RegrSlope => { - protobuf::AggregateFunction::RegrSlope - } - AggregateFunction::RegrIntercept => { - protobuf::AggregateFunction::RegrIntercept - } - AggregateFunction::RegrR2 => protobuf::AggregateFunction::RegrR2, - AggregateFunction::RegrAvgx => protobuf::AggregateFunction::RegrAvgx, - AggregateFunction::RegrAvgy => protobuf::AggregateFunction::RegrAvgy, - AggregateFunction::RegrCount => { - protobuf::AggregateFunction::RegrCount - } - AggregateFunction::RegrSXX => protobuf::AggregateFunction::RegrSxx, - AggregateFunction::RegrSYY => protobuf::AggregateFunction::RegrSyy, - AggregateFunction::RegrSXY => protobuf::AggregateFunction::RegrSxy, - AggregateFunction::ApproxMedian => { - protobuf::AggregateFunction::ApproxMedian - } - AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping, - AggregateFunction::Median => protobuf::AggregateFunction::Median, - AggregateFunction::FirstValue => { - protobuf::AggregateFunction::FirstValueAgg - } - AggregateFunction::LastValue => { - protobuf::AggregateFunction::LastValueAgg + match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + let aggr_function = match fun { + AggregateFunction::ApproxDistinct => { + protobuf::AggregateFunction::ApproxDistinct + } + AggregateFunction::ApproxPercentileCont => { + protobuf::AggregateFunction::ApproxPercentileCont + } + AggregateFunction::ApproxPercentileContWithWeight => { + protobuf::AggregateFunction::ApproxPercentileContWithWeight + } + AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, + AggregateFunction::Min => protobuf::AggregateFunction::Min, + AggregateFunction::Max => protobuf::AggregateFunction::Max, + AggregateFunction::Sum => protobuf::AggregateFunction::Sum, + AggregateFunction::BitAnd => protobuf::AggregateFunction::BitAnd, + AggregateFunction::BitOr => protobuf::AggregateFunction::BitOr, + AggregateFunction::BitXor => protobuf::AggregateFunction::BitXor, + AggregateFunction::BoolAnd => protobuf::AggregateFunction::BoolAnd, + AggregateFunction::BoolOr => protobuf::AggregateFunction::BoolOr, + AggregateFunction::Avg => protobuf::AggregateFunction::Avg, + AggregateFunction::Count => protobuf::AggregateFunction::Count, + AggregateFunction::Variance => protobuf::AggregateFunction::Variance, + AggregateFunction::VariancePop => { + protobuf::AggregateFunction::VariancePop + } + AggregateFunction::Covariance => { + protobuf::AggregateFunction::Covariance + } + AggregateFunction::CovariancePop => { + protobuf::AggregateFunction::CovariancePop + } + AggregateFunction::Stddev => protobuf::AggregateFunction::Stddev, + AggregateFunction::StddevPop => { + protobuf::AggregateFunction::StddevPop + } + AggregateFunction::Correlation => { + protobuf::AggregateFunction::Correlation + } + AggregateFunction::RegrSlope => { + protobuf::AggregateFunction::RegrSlope + } + AggregateFunction::RegrIntercept => { + protobuf::AggregateFunction::RegrIntercept + } + AggregateFunction::RegrR2 => protobuf::AggregateFunction::RegrR2, + AggregateFunction::RegrAvgx => protobuf::AggregateFunction::RegrAvgx, + AggregateFunction::RegrAvgy => protobuf::AggregateFunction::RegrAvgy, + AggregateFunction::RegrCount => { + protobuf::AggregateFunction::RegrCount + } + AggregateFunction::RegrSXX => protobuf::AggregateFunction::RegrSxx, + AggregateFunction::RegrSYY => protobuf::AggregateFunction::RegrSyy, + AggregateFunction::RegrSXY => protobuf::AggregateFunction::RegrSxy, + AggregateFunction::ApproxMedian => { + protobuf::AggregateFunction::ApproxMedian + } + AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping, + AggregateFunction::Median => protobuf::AggregateFunction::Median, + AggregateFunction::FirstValue => { + protobuf::AggregateFunction::FirstValueAgg + } + AggregateFunction::LastValue => { + protobuf::AggregateFunction::LastValueAgg + } + AggregateFunction::StringAgg => { + protobuf::AggregateFunction::StringAgg + } + }; + + let aggregate_expr = protobuf::AggregateExprNode { + aggr_function: aggr_function.into(), + expr: args + .iter() + .map(|v| v.try_into()) + .collect::, _>>()?, + distinct: *distinct, + filter: match filter { + Some(e) => Some(Box::new(e.as_ref().try_into()?)), + None => None, + }, + order_by: match order_by { + Some(e) => e + .iter() + .map(|expr| expr.try_into()) + .collect::, _>>()?, + None => vec![], + }, + }; + Self { + expr_type: Some(ExprType::AggregateExpr(Box::new( + aggregate_expr, + ))), + } } - }; - - let aggregate_expr = protobuf::AggregateExprNode { - aggr_function: aggr_function.into(), - expr: args - .iter() - .map(|v| v.try_into()) - .collect::, _>>()?, - distinct: *distinct, - filter: match filter { - Some(e) => Some(Box::new(e.as_ref().try_into()?)), - None => None, - }, - order_by: match order_by { - Some(e) => e - .iter() - .map(|expr| expr.try_into()) - .collect::, _>>()?, - None => vec![], + AggregateFunctionDefinition::UDF(fun) => Self { + expr_type: Some(ExprType::AggregateUdfExpr(Box::new( + protobuf::AggregateUdfExprNode { + fun_name: fun.name().to_string(), + args: args + .iter() + .map(|expr| expr.try_into()) + .collect::, Error>>()?, + filter: match filter { + Some(e) => Some(Box::new(e.as_ref().try_into()?)), + None => None, + }, + order_by: match order_by { + Some(e) => e + .iter() + .map(|expr| expr.try_into()) + .collect::, _>>()?, + None => vec![], + }, + }, + ))), }, - }; - Self { - expr_type: Some(ExprType::AggregateExpr(Box::new(aggregate_expr))), + AggregateFunctionDefinition::Name(_) => { + return Err(Error::NotImplemented( + "Proto serialization error: Trying to serialize a unresolved function" + .to_string(), + )); + } } } + Expr::ScalarVariable(_, _) => { return Err(Error::General( "Proto serialization error: Scalar Variable not supported" .to_string(), )) } - Expr::ScalarFunction(ScalarFunction { fun, args }) => { - let fun: protobuf::ScalarFunction = fun.try_into()?; - let args: Vec = args + Expr::ScalarFunction(ScalarFunction { func_def, args }) => { + let args = args .iter() - .map(|e| e.try_into()) - .collect::, Error>>()?; - Self { - expr_type: Some(ExprType::ScalarFunction( - protobuf::ScalarFunctionNode { - fun: fun.into(), - args, - }, - )), + .map(|expr| expr.try_into()) + .collect::, Error>>()?; + match func_def { + ScalarFunctionDefinition::BuiltIn(fun) => { + let fun: protobuf::ScalarFunction = fun.try_into()?; + Self { + expr_type: Some(ExprType::ScalarFunction( + protobuf::ScalarFunctionNode { + fun: fun.into(), + args, + }, + )), + } + } + ScalarFunctionDefinition::UDF(fun) => Self { + expr_type: Some(ExprType::ScalarUdfExpr( + protobuf::ScalarUdfExprNode { + fun_name: fun.name().to_string(), + args, + }, + )), + }, + ScalarFunctionDefinition::Name(_) => { + return Err(Error::NotImplemented( + "Proto serialization error: Trying to serialize a unresolved function" + .to_string(), + )); + } } } - Expr::ScalarUDF(ScalarUDF { fun, args }) => Self { - expr_type: Some(ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { - fun_name: fun.name.clone(), - args: args - .iter() - .map(|expr| expr.try_into()) - .collect::, Error>>()?, - })), - }, - Expr::AggregateUDF(expr::AggregateUDF { - fun, - args, - filter, - order_by, - }) => Self { - expr_type: Some(ExprType::AggregateUdfExpr(Box::new( - protobuf::AggregateUdfExprNode { - fun_name: fun.name.clone(), - args: args.iter().map(|expr| expr.try_into()).collect::, - Error, - >>( - )?, - filter: match filter { - Some(e) => Some(Box::new(e.as_ref().try_into()?)), - None => None, - }, - order_by: match order_by { - Some(e) => e - .iter() - .map(|expr| expr.try_into()) - .collect::, _>>()?, - None => vec![], - }, - }, - ))), - }, Expr::Not(expr) => { let expr = Box::new(protobuf::Not { expr: Some(Box::new(expr.as_ref().try_into()?)), @@ -974,8 +998,10 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { expr_type: Some(ExprType::InList(expr)), } } - Expr::Wildcard => Self { - expr_type: Some(ExprType::Wildcard(true)), + Expr::Wildcard { qualifier } => Self { + expr_type: Some(ExprType::Wildcard(protobuf::Wildcard { + qualifier: qualifier.clone().unwrap_or("".to_string()), + })), }, Expr::ScalarSubquery(_) | Expr::InSubquery(_) @@ -1066,11 +1092,6 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { })), } } - - Expr::QualifiedWildcard { .. } => return Err(Error::General( - "Proto serialization error: Expr::QualifiedWildcard { .. } not supported" - .to_string(), - )), }; Ok(expr_node) @@ -1136,33 +1157,56 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { Value::LargeUtf8Value(s.to_owned()) }) } - ScalarValue::Fixedsizelist(..) => Err(Error::General( - "Proto serialization error: ScalarValue::Fixedsizelist not supported" - .to_string(), - )), - ScalarValue::List(values, boxed_field) => { - let is_null = values.is_none(); + // ScalarValue::List and ScalarValue::FixedSizeList are serialized using + // Arrow IPC messages as a single column RecordBatch + ScalarValue::List(arr) + | ScalarValue::LargeList(arr) + | ScalarValue::FixedSizeList(arr) => { + // Wrap in a "field_name" column + let batch = RecordBatch::try_from_iter(vec![( + "field_name", + arr.to_owned(), + )]) + .map_err(|e| { + Error::General( format!("Error creating temporary batch while encoding ScalarValue::List: {e}")) + })?; + + let gen = IpcDataGenerator {}; + let mut dict_tracker = DictionaryTracker::new(false); + let (_, encoded_message) = gen + .encoded_batch(&batch, &mut dict_tracker, &Default::default()) + .map_err(|e| { + Error::General(format!( + "Error encoding ScalarValue::List as IPC: {e}" + )) + })?; - let values = if let Some(values) = values.as_ref() { - values - .iter() - .map(|v| v.try_into()) - .collect::, _>>()? - } else { - vec![] - }; + let schema: protobuf::Schema = batch.schema().try_into()?; - let field = boxed_field.as_ref().try_into()?; + let scalar_list_value = protobuf::ScalarListValue { + ipc_message: encoded_message.ipc_message, + arrow_data: encoded_message.arrow_data, + schema: Some(schema), + }; - Ok(protobuf::ScalarValue { - value: Some(protobuf::scalar_value::Value::ListValue( - protobuf::ScalarListValue { - is_null, - field: Some(field), - values, - }, - )), - }) + match val { + ScalarValue::List(_) => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::ListValue( + scalar_list_value, + )), + }), + ScalarValue::LargeList(_) => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::LargeListValue( + scalar_list_value, + )), + }), + ScalarValue::FixedSizeList(_) => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::FixedSizeListValue( + scalar_list_value, + )), + }), + _ => unreachable!(), + } } ScalarValue::Date32(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| Value::Date32Value(*s)) @@ -1460,16 +1504,20 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Rtrim => Self::Rtrim, BuiltinScalarFunction::ToTimestamp => Self::ToTimestamp, BuiltinScalarFunction::ArrayAppend => Self::ArrayAppend, + BuiltinScalarFunction::ArraySort => Self::ArraySort, BuiltinScalarFunction::ArrayConcat => Self::ArrayConcat, BuiltinScalarFunction::ArrayEmpty => Self::ArrayEmpty, + BuiltinScalarFunction::ArrayExcept => Self::ArrayExcept, BuiltinScalarFunction::ArrayHasAll => Self::ArrayHasAll, BuiltinScalarFunction::ArrayHasAny => Self::ArrayHasAny, BuiltinScalarFunction::ArrayHas => Self::ArrayHas, BuiltinScalarFunction::ArrayDims => Self::ArrayDims, + BuiltinScalarFunction::ArrayDistinct => Self::ArrayDistinct, BuiltinScalarFunction::ArrayElement => Self::ArrayElement, BuiltinScalarFunction::Flatten => Self::Flatten, BuiltinScalarFunction::ArrayLength => Self::ArrayLength, BuiltinScalarFunction::ArrayNdims => Self::ArrayNdims, + BuiltinScalarFunction::ArrayPopFront => Self::ArrayPopFront, BuiltinScalarFunction::ArrayPopBack => Self::ArrayPopBack, BuiltinScalarFunction::ArrayPosition => Self::ArrayPosition, BuiltinScalarFunction::ArrayPositions => Self::ArrayPositions, @@ -1483,6 +1531,9 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::ArrayReplaceAll => Self::ArrayReplaceAll, BuiltinScalarFunction::ArraySlice => Self::ArraySlice, BuiltinScalarFunction::ArrayToString => Self::ArrayToString, + BuiltinScalarFunction::ArrayIntersect => Self::ArrayIntersect, + BuiltinScalarFunction::ArrayUnion => Self::ArrayUnion, + BuiltinScalarFunction::Range => Self::Range, BuiltinScalarFunction::Cardinality => Self::Cardinality, BuiltinScalarFunction::MakeArray => Self::Array, BuiltinScalarFunction::NullIf => Self::NullIf, @@ -1524,6 +1575,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Substr => Self::Substr, BuiltinScalarFunction::ToHex => Self::ToHex, BuiltinScalarFunction::ToTimestampMicros => Self::ToTimestampMicros, + BuiltinScalarFunction::ToTimestampNanos => Self::ToTimestampNanos, BuiltinScalarFunction::ToTimestampSeconds => Self::ToTimestampSeconds, BuiltinScalarFunction::Now => Self::Now, BuiltinScalarFunction::CurrentDate => Self::CurrentDate, @@ -1540,6 +1592,10 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Isnan => Self::Isnan, BuiltinScalarFunction::Iszero => Self::Iszero, BuiltinScalarFunction::ArrowTypeof => Self::ArrowTypeof, + BuiltinScalarFunction::OverLay => Self::OverLay, + BuiltinScalarFunction::Levenshtein => Self::Levenshtein, + BuiltinScalarFunction::SubstrIndex => Self::SubstrIndex, + BuiltinScalarFunction::FindInSet => Self::FindInSet, }; Ok(scalar_function) @@ -1623,6 +1679,35 @@ impl From for protobuf::JoinConstraint { } } +impl From for protobuf::Constraints { + fn from(value: Constraints) -> Self { + let constraints = value.into_iter().map(|item| item.into()).collect(); + protobuf::Constraints { constraints } + } +} + +impl From for protobuf::Constraint { + fn from(value: Constraint) -> Self { + let res = match value { + Constraint::PrimaryKey(indices) => { + let indices = indices.into_iter().map(|item| item as u64).collect(); + protobuf::constraint::ConstraintMode::PrimaryKey( + protobuf::PrimaryKeyConstraint { indices }, + ) + } + Constraint::Unique(indices) => { + let indices = indices.into_iter().map(|item| item as u64).collect(); + protobuf::constraint::ConstraintMode::PrimaryKey( + protobuf::PrimaryKeyConstraint { indices }, + ) + } + }; + protobuf::Constraint { + constraint_mode: Some(res), + } + } +} + /// Creates a scalar protobuf value from an optional value (T), and /// encoding None as the appropriate datatype fn create_proto_scalar protobuf::scalar_value::Value>( diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index bb38116e5dba..23ab813ca739 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -17,43 +17,50 @@ //! Serde code to convert from protocol buffers to Rust data structures. -use crate::protobuf; -use arrow::datatypes::DataType; -use chrono::TimeZone; -use chrono::Utc; +use std::convert::{TryFrom, TryInto}; +use std::sync::Arc; + +use arrow::compute::SortOptions; use datafusion::arrow::datatypes::Schema; -use datafusion::datasource::listing::{FileRange, PartitionedFile}; +use datafusion::datasource::file_format::csv::CsvSink; +use datafusion::datasource::file_format::json::JsonSink; +#[cfg(feature = "parquet")] +use datafusion::datasource::file_format::parquet::ParquetSink; +use datafusion::datasource::listing::{FileRange, ListingTableUrl, PartitionedFile}; use datafusion::datasource::object_store::ObjectStoreUrl; -use datafusion::datasource::physical_plan::FileScanConfig; +use datafusion::datasource::physical_plan::{FileScanConfig, FileSinkConfig}; use datafusion::execution::context::ExecutionProps; use datafusion::execution::FunctionRegistry; -use datafusion::logical_expr::window_function::WindowFunction; +use datafusion::logical_expr::WindowFunctionDefinition; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; -use datafusion::physical_plan::expressions::{in_list, LikeExpr}; +use datafusion::physical_plan::expressions::{ + in_list, BinaryExpr, CaseExpr, CastExpr, Column, IsNotNullExpr, IsNullExpr, LikeExpr, + Literal, NegativeExpr, NotExpr, TryCastExpr, +}; use datafusion::physical_plan::expressions::{GetFieldAccessExpr, GetIndexedFieldExpr}; use datafusion::physical_plan::windows::create_window_expr; -use datafusion::physical_plan::WindowExpr; use datafusion::physical_plan::{ - expressions::{ - BinaryExpr, CaseExpr, CastExpr, Column, IsNotNullExpr, IsNullExpr, Literal, - NegativeExpr, NotExpr, TryCastExpr, - }, - functions, Partitioning, + functions, ColumnStatistics, Partitioning, PhysicalExpr, Statistics, WindowExpr, +}; +use datafusion_common::file_options::csv_writer::CsvWriterOptions; +use datafusion_common::file_options::json_writer::JsonWriterOptions; +use datafusion_common::file_options::parquet_writer::ParquetWriterOptions; +use datafusion_common::parsers::CompressionTypeVariant; +use datafusion_common::stats::Precision; +use datafusion_common::{ + not_impl_err, DataFusionError, FileTypeWriterOptions, JoinSide, Result, ScalarValue, }; -use datafusion::physical_plan::{ColumnStatistics, PhysicalExpr, Statistics}; -use datafusion_common::{not_impl_err, DataFusionError, Result}; -use object_store::path::Path; -use object_store::ObjectMeta; -use std::convert::{TryFrom, TryInto}; -use std::ops::Deref; -use std::sync::Arc; use crate::common::proto_error; use crate::convert_required; use crate::logical_plan; +use crate::protobuf; use crate::protobuf::physical_expr_node::ExprType; -use datafusion::physical_plan::joins::utils::JoinSide; -use datafusion::physical_plan::sorts::sort::SortOptions; + +use crate::logical_plan::{csv_writer_options_from_proto, writer_properties_from_proto}; +use chrono::{TimeZone, Utc}; +use object_store::path::Path; +use object_store::ObjectMeta; impl From<&protobuf::PhysicalColumn> for Column { fn from(c: &protobuf::PhysicalColumn) -> Column { @@ -311,12 +318,12 @@ pub fn parse_physical_expr( &e.name, fun_expr, args, - &convert_required!(e.return_type)?, + convert_required!(e.return_type)?, None, )) } ExprType::ScalarUdf(e) => { - let scalar_fun = registry.udf(e.name.as_str())?.deref().clone().fun; + let scalar_fun = registry.udf(e.name.as_str())?.fun().clone(); let args = e .args @@ -328,7 +335,7 @@ pub fn parse_physical_expr( e.name.as_str(), scalar_fun, args, - &convert_required!(e.return_type)?, + convert_required!(e.return_type)?, None, )) } @@ -407,7 +414,9 @@ fn parse_required_physical_expr( }) } -impl TryFrom<&protobuf::physical_window_expr_node::WindowFunction> for WindowFunction { +impl TryFrom<&protobuf::physical_window_expr_node::WindowFunction> + for WindowFunctionDefinition +{ type Error = DataFusionError; fn try_from( @@ -421,7 +430,7 @@ impl TryFrom<&protobuf::physical_window_expr_node::WindowFunction> for WindowFun )) })?; - Ok(WindowFunction::AggregateFunction(f.into())) + Ok(WindowFunctionDefinition::AggregateFunction(f.into())) } protobuf::physical_window_expr_node::WindowFunction::BuiltInFunction(n) => { let f = protobuf::BuiltInWindowFunction::try_from(*n).map_err(|_| { @@ -430,7 +439,7 @@ impl TryFrom<&protobuf::physical_window_expr_node::WindowFunction> for WindowFun )) })?; - Ok(WindowFunction::BuiltInWindowFunction(f.into())) + Ok(WindowFunctionDefinition::BuiltInWindowFunction(f.into())) } } } @@ -490,13 +499,8 @@ pub fn parse_protobuf_file_scan_config( let table_partition_cols = proto .table_partition_cols .iter() - .map(|col| { - Ok(( - col.to_owned(), - schema.field_with_name(col)?.data_type().clone(), - )) - }) - .collect::>>()?; + .map(|col| Ok(schema.field_with_name(col)?.clone())) + .collect::>>()?; let mut output_ordering = vec![]; for node_collection in &proto.output_ordering { @@ -530,7 +534,6 @@ pub fn parse_protobuf_file_scan_config( limit: proto.limit.as_ref().map(|sl| sl.limit as usize), table_partition_cols, output_ordering, - infinite_source: false, }) } @@ -544,6 +547,7 @@ impl TryFrom<&protobuf::PartitionedFile> for PartitionedFile { last_modified: Utc.timestamp_nanos(val.last_modified_ns as i64), size: val.size as usize, e_tag: None, + version: None, }, partition_values: val .partition_values @@ -581,10 +585,96 @@ impl TryFrom<&protobuf::FileGroup> for Vec { impl From<&protobuf::ColumnStats> for ColumnStatistics { fn from(cs: &protobuf::ColumnStats) -> ColumnStatistics { ColumnStatistics { - null_count: Some(cs.null_count as usize), - max_value: cs.max_value.as_ref().map(|m| m.try_into().unwrap()), - min_value: cs.min_value.as_ref().map(|m| m.try_into().unwrap()), - distinct_count: Some(cs.distinct_count as usize), + null_count: if let Some(nc) = &cs.null_count { + nc.clone().into() + } else { + Precision::Absent + }, + max_value: if let Some(max) = &cs.max_value { + max.clone().into() + } else { + Precision::Absent + }, + min_value: if let Some(min) = &cs.min_value { + min.clone().into() + } else { + Precision::Absent + }, + distinct_count: if let Some(dc) = &cs.distinct_count { + dc.clone().into() + } else { + Precision::Absent + }, + } + } +} + +impl From for Precision { + fn from(s: protobuf::Precision) -> Self { + let Ok(precision_type) = s.precision_info.try_into() else { + return Precision::Absent; + }; + match precision_type { + protobuf::PrecisionInfo::Exact => { + if let Some(val) = s.val { + if let Ok(ScalarValue::UInt64(Some(val))) = + ScalarValue::try_from(&val) + { + Precision::Exact(val as usize) + } else { + Precision::Absent + } + } else { + Precision::Absent + } + } + protobuf::PrecisionInfo::Inexact => { + if let Some(val) = s.val { + if let Ok(ScalarValue::UInt64(Some(val))) = + ScalarValue::try_from(&val) + { + Precision::Inexact(val as usize) + } else { + Precision::Absent + } + } else { + Precision::Absent + } + } + protobuf::PrecisionInfo::Absent => Precision::Absent, + } + } +} + +impl From for Precision { + fn from(s: protobuf::Precision) -> Self { + let Ok(precision_type) = s.precision_info.try_into() else { + return Precision::Absent; + }; + match precision_type { + protobuf::PrecisionInfo::Exact => { + if let Some(val) = s.val { + if let Ok(val) = ScalarValue::try_from(&val) { + Precision::Exact(val) + } else { + Precision::Absent + } + } else { + Precision::Absent + } + } + protobuf::PrecisionInfo::Inexact => { + if let Some(val) = s.val { + if let Ok(val) = ScalarValue::try_from(&val) { + Precision::Inexact(val) + } else { + Precision::Absent + } + } else { + Precision::Absent + } + } + protobuf::PrecisionInfo::Absent => Precision::Absent, } } } @@ -603,27 +693,119 @@ impl TryFrom<&protobuf::Statistics> for Statistics { fn try_from(s: &protobuf::Statistics) -> Result { // Keep it sync with Statistics::to_proto - let none_value = -1_i64; - let column_statistics = - s.column_stats.iter().map(|s| s.into()).collect::>(); Ok(Statistics { - num_rows: if s.num_rows == none_value { - None + num_rows: if let Some(nr) = &s.num_rows { + nr.clone().into() } else { - Some(s.num_rows as usize) + Precision::Absent }, - total_byte_size: if s.total_byte_size == none_value { - None + total_byte_size: if let Some(tbs) = &s.total_byte_size { + tbs.clone().into() } else { - Some(s.total_byte_size as usize) + Precision::Absent }, // No column statistic (None) is encoded with empty array - column_statistics: if column_statistics.is_empty() { - None - } else { - Some(column_statistics) - }, - is_exact: s.is_exact, + column_statistics: s.column_stats.iter().map(|s| s.into()).collect(), + }) + } +} + +impl TryFrom<&protobuf::JsonSink> for JsonSink { + type Error = DataFusionError; + + fn try_from(value: &protobuf::JsonSink) -> Result { + Ok(Self::new(convert_required!(value.config)?)) + } +} + +#[cfg(feature = "parquet")] +impl TryFrom<&protobuf::ParquetSink> for ParquetSink { + type Error = DataFusionError; + + fn try_from(value: &protobuf::ParquetSink) -> Result { + Ok(Self::new(convert_required!(value.config)?)) + } +} + +impl TryFrom<&protobuf::CsvSink> for CsvSink { + type Error = DataFusionError; + + fn try_from(value: &protobuf::CsvSink) -> Result { + Ok(Self::new(convert_required!(value.config)?)) + } +} + +impl TryFrom<&protobuf::FileSinkConfig> for FileSinkConfig { + type Error = DataFusionError; + + fn try_from(conf: &protobuf::FileSinkConfig) -> Result { + let file_groups = conf + .file_groups + .iter() + .map(TryInto::try_into) + .collect::>>()?; + let table_paths = conf + .table_paths + .iter() + .map(ListingTableUrl::parse) + .collect::>>()?; + let table_partition_cols = conf + .table_partition_cols + .iter() + .map(|protobuf::PartitionColumn { name, arrow_type }| { + let data_type = convert_required!(arrow_type)?; + Ok((name.clone(), data_type)) + }) + .collect::>>()?; + Ok(Self { + object_store_url: ObjectStoreUrl::parse(&conf.object_store_url)?, + file_groups, + table_paths, + output_schema: Arc::new(convert_required!(conf.output_schema)?), + table_partition_cols, + single_file_output: conf.single_file_output, + overwrite: conf.overwrite, + file_type_writer_options: convert_required!(conf.file_type_writer_options)?, }) } } + +impl From for CompressionTypeVariant { + fn from(value: protobuf::CompressionTypeVariant) -> Self { + match value { + protobuf::CompressionTypeVariant::Gzip => Self::GZIP, + protobuf::CompressionTypeVariant::Bzip2 => Self::BZIP2, + protobuf::CompressionTypeVariant::Xz => Self::XZ, + protobuf::CompressionTypeVariant::Zstd => Self::ZSTD, + protobuf::CompressionTypeVariant::Uncompressed => Self::UNCOMPRESSED, + } + } +} + +impl TryFrom<&protobuf::FileTypeWriterOptions> for FileTypeWriterOptions { + type Error = DataFusionError; + + fn try_from(value: &protobuf::FileTypeWriterOptions) -> Result { + let file_type = value + .file_type + .as_ref() + .ok_or_else(|| proto_error("Missing required file_type field in protobuf"))?; + + match file_type { + protobuf::file_type_writer_options::FileType::JsonOptions(opts) => { + let compression: CompressionTypeVariant = opts.compression().into(); + Ok(Self::JSON(JsonWriterOptions::new(compression))) + } + protobuf::file_type_writer_options::FileType::CsvOptions(opts) => { + let write_options = csv_writer_options_from_proto(opts)?; + let compression: CompressionTypeVariant = opts.compression().into(); + Ok(Self::CSV(CsvWriterOptions::new(write_options, compression))) + } + protobuf::file_type_writer_options::FileType::ParquetOptions(opt) => { + let props = opt.writer_properties.clone().unwrap_or_default(); + let writer_properties = writer_properties_from_proto(&props)?; + Ok(Self::Parquet(ParquetWriterOptions::new(writer_properties))) + } + } + } +} diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 8257f9aa3458..95becb3fe4b3 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -21,8 +21,14 @@ use std::sync::Arc; use datafusion::arrow::compute::SortOptions; use datafusion::arrow::datatypes::SchemaRef; +use datafusion::datasource::file_format::csv::CsvSink; use datafusion::datasource::file_format::file_compression_type::FileCompressionType; -use datafusion::datasource::physical_plan::{AvroExec, CsvExec, ParquetExec}; +use datafusion::datasource::file_format::json::JsonSink; +#[cfg(feature = "parquet")] +use datafusion::datasource::file_format::parquet::ParquetSink; +#[cfg(feature = "parquet")] +use datafusion::datasource::physical_plan::ParquetExec; +use datafusion::datasource::physical_plan::{AvroExec, CsvExec}; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::execution::FunctionRegistry; use datafusion::physical_plan::aggregates::{create_aggregate_expr, AggregateMode}; @@ -34,20 +40,23 @@ use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::explain::ExplainExec; use datafusion::physical_plan::expressions::{Column, PhysicalSortExpr}; use datafusion::physical_plan::filter::FilterExec; +use datafusion::physical_plan::insert::FileSinkExec; use datafusion::physical_plan::joins::utils::{ColumnIndex, JoinFilter}; -use datafusion::physical_plan::joins::{CrossJoinExec, NestedLoopJoinExec}; +use datafusion::physical_plan::joins::{ + CrossJoinExec, NestedLoopJoinExec, StreamJoinPartitionMode, SymmetricHashJoinExec, +}; use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode}; use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; +use datafusion::physical_plan::placeholder_row::PlaceholderRowExec; use datafusion::physical_plan::projection::ProjectionExec; use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; -use datafusion::physical_plan::union::UnionExec; -use datafusion::physical_plan::windows::{ - BoundedWindowAggExec, PartitionSearchMode, WindowAggExec, -}; +use datafusion::physical_plan::union::{InterleaveExec, UnionExec}; +use datafusion::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use datafusion::physical_plan::{ - udaf, AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, WindowExpr, + udaf, AggregateExpr, ExecutionPlan, InputOrderMode, Partitioning, PhysicalExpr, + WindowExpr, }; use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; use prost::bytes::BufMut; @@ -62,7 +71,9 @@ use crate::protobuf::physical_aggregate_expr_node::AggregateFunction; use crate::protobuf::physical_expr_node::ExprType; use crate::protobuf::physical_plan_node::PhysicalPlanType; use crate::protobuf::repartition_exec_node::PartitionMethod; -use crate::protobuf::{self, window_agg_exec_node, PhysicalPlanNode}; +use crate::protobuf::{ + self, window_agg_exec_node, PhysicalPlanNode, PhysicalSortExprNodeCollection, +}; use crate::{convert_required, into_required}; use self::from_proto::parse_physical_window_expr; @@ -151,7 +162,16 @@ impl AsExecutionPlan for PhysicalPlanNode { .to_owned(), ) })?; - Ok(Arc::new(FilterExec::try_new(predicate, input)?)) + let filter_selectivity = filter.default_filter_selectivity.try_into(); + let filter = FilterExec::try_new(predicate, input)?; + match filter_selectivity { + Ok(filter_selectivity) => Ok(Arc::new( + filter.with_default_selectivity(filter_selectivity)?, + )), + Err(_) => Err(DataFusionError::Internal( + "filter_selectivity in PhysicalPlanNode is invalid ".to_owned(), + )), + } } PhysicalPlanType::CsvScan(scan) => Ok(Arc::new(CsvExec::new( parse_protobuf_file_scan_config( @@ -171,6 +191,7 @@ impl AsExecutionPlan for PhysicalPlanNode { }, FileCompressionType::UNCOMPRESSED, ))), + #[cfg(feature = "parquet")] PhysicalPlanType::ParquetScan(scan) => { let base_config = parse_protobuf_file_scan_config( scan.base_conf.as_ref().unwrap(), @@ -282,16 +303,7 @@ impl AsExecutionPlan for PhysicalPlanNode { runtime, extension_codec, )?; - let input_schema = window_agg - .input_schema - .as_ref() - .ok_or_else(|| { - DataFusionError::Internal( - "input_schema in WindowAggrNode is missing.".to_owned(), - ) - })? - .clone(); - let input_schema: SchemaRef = SchemaRef::new((&input_schema).try_into()?); + let input_schema = input.schema(); let physical_window_expr: Vec> = window_agg .window_expr @@ -313,35 +325,31 @@ impl AsExecutionPlan for PhysicalPlanNode { }) .collect::>>>()?; - if let Some(partition_search_mode) = - window_agg.partition_search_mode.as_ref() - { - let partition_search_mode = match partition_search_mode { - window_agg_exec_node::PartitionSearchMode::Linear(_) => { - PartitionSearchMode::Linear + if let Some(input_order_mode) = window_agg.input_order_mode.as_ref() { + let input_order_mode = match input_order_mode { + window_agg_exec_node::InputOrderMode::Linear(_) => { + InputOrderMode::Linear } - window_agg_exec_node::PartitionSearchMode::PartiallySorted( - protobuf::PartiallySortedPartitionSearchMode { columns }, - ) => PartitionSearchMode::PartiallySorted( + window_agg_exec_node::InputOrderMode::PartiallySorted( + protobuf::PartiallySortedInputOrderMode { columns }, + ) => InputOrderMode::PartiallySorted( columns.iter().map(|c| *c as usize).collect(), ), - window_agg_exec_node::PartitionSearchMode::Sorted(_) => { - PartitionSearchMode::Sorted + window_agg_exec_node::InputOrderMode::Sorted(_) => { + InputOrderMode::Sorted } }; Ok(Arc::new(BoundedWindowAggExec::try_new( physical_window_expr, input, - input_schema, partition_keys, - partition_search_mode, + input_order_mode, )?)) } else { Ok(Arc::new(WindowAggExec::try_new( physical_window_expr, input, - input_schema, partition_keys, )?)) } @@ -405,17 +413,12 @@ impl AsExecutionPlan for PhysicalPlanNode { vec![] }; - let input_schema = hash_agg - .input_schema - .as_ref() - .ok_or_else(|| { - DataFusionError::Internal( - "input_schema in AggregateNode is missing.".to_owned(), - ) - })? - .clone(); - let physical_schema: SchemaRef = - SchemaRef::new((&input_schema).try_into()?); + let input_schema = hash_agg.input_schema.as_ref().ok_or_else(|| { + DataFusionError::Internal( + "input_schema in AggregateNode is missing.".to_owned(), + ) + })?; + let physical_schema: SchemaRef = SchemaRef::new(input_schema.try_into()?); let physical_filter_expr = hash_agg .filter_expr @@ -427,19 +430,6 @@ impl AsExecutionPlan for PhysicalPlanNode { .transpose() }) .collect::, _>>()?; - let physical_order_by_expr = hash_agg - .order_by_expr - .iter() - .map(|expr| { - expr.sort_expr - .iter() - .map(|e| { - parse_physical_sort_expr(e, registry, &physical_schema) - }) - .collect::>>() - .map(|exprs| (!exprs.is_empty()).then_some(exprs)) - }) - .collect::>>()?; let physical_aggr_expr: Vec> = hash_agg .aggr_expr @@ -498,9 +488,8 @@ impl AsExecutionPlan for PhysicalPlanNode { PhysicalGroupBy::new(group_expr, null_expr, groups), physical_aggr_expr, physical_filter_expr, - physical_order_by_expr, input, - Arc::new((&input_schema).try_into()?), + physical_schema, )?)) } PhysicalPlanType::HashJoin(hashjoin) => { @@ -546,7 +535,7 @@ impl AsExecutionPlan for PhysicalPlanNode { f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, - registry, &schema + registry, &schema, )?; let column_indices = f.column_indices .iter() @@ -557,7 +546,7 @@ impl AsExecutionPlan for PhysicalPlanNode { i.side)) )?; - Ok(ColumnIndex{ + Ok(ColumnIndex { index: i.index as usize, side: side.into(), }) @@ -592,6 +581,97 @@ impl AsExecutionPlan for PhysicalPlanNode { hashjoin.null_equals_null, )?)) } + PhysicalPlanType::SymmetricHashJoin(sym_join) => { + let left = into_physical_plan( + &sym_join.left, + registry, + runtime, + extension_codec, + )?; + let right = into_physical_plan( + &sym_join.right, + registry, + runtime, + extension_codec, + )?; + let on = sym_join + .on + .iter() + .map(|col| { + let left = into_required!(col.left)?; + let right = into_required!(col.right)?; + Ok((left, right)) + }) + .collect::>()?; + let join_type = protobuf::JoinType::try_from(sym_join.join_type) + .map_err(|_| { + proto_error(format!( + "Received a SymmetricHashJoin message with unknown JoinType {}", + sym_join.join_type + )) + })?; + let filter = sym_join + .filter + .as_ref() + .map(|f| { + let schema = f + .schema + .as_ref() + .ok_or_else(|| proto_error("Missing JoinFilter schema"))? + .try_into()?; + + let expression = parse_physical_expr( + f.expression.as_ref().ok_or_else(|| { + proto_error("Unexpected empty filter expression") + })?, + registry, &schema, + )?; + let column_indices = f.column_indices + .iter() + .map(|i| { + let side = protobuf::JoinSide::try_from(i.side) + .map_err(|_| proto_error(format!( + "Received a HashJoinNode message with JoinSide in Filter {}", + i.side)) + )?; + + Ok(ColumnIndex { + index: i.index as usize, + side: side.into(), + }) + }) + .collect::>()?; + + Ok(JoinFilter::new(expression, column_indices, schema)) + }) + .map_or(Ok(None), |v: Result| v.map(Some))?; + + let partition_mode = + protobuf::StreamPartitionMode::try_from(sym_join.partition_mode).map_err(|_| { + proto_error(format!( + "Received a SymmetricHashJoin message with unknown PartitionMode {}", + sym_join.partition_mode + )) + })?; + let partition_mode = match partition_mode { + protobuf::StreamPartitionMode::SinglePartition => { + StreamJoinPartitionMode::SinglePartition + } + protobuf::StreamPartitionMode::PartitionedExec => { + StreamJoinPartitionMode::Partitioned + } + }; + SymmetricHashJoinExec::try_new( + left, + right, + on, + filter, + &join_type.into(), + sym_join.null_equals_null, + partition_mode, + ) + .map(|e| Arc::new(e) as _) + } PhysicalPlanType::Union(union) => { let mut inputs: Vec> = vec![]; for input in &union.inputs { @@ -603,6 +683,17 @@ impl AsExecutionPlan for PhysicalPlanNode { } Ok(Arc::new(UnionExec::new(inputs))) } + PhysicalPlanType::Interleave(interleave) => { + let mut inputs: Vec> = vec![]; + for input in &interleave.inputs { + inputs.push(input.try_into_physical_plan( + registry, + runtime, + extension_codec, + )?); + } + Ok(Arc::new(InterleaveExec::try_new(inputs)?)) + } PhysicalPlanType::CrossJoin(crossjoin) => { let left: Arc = into_physical_plan( &crossjoin.left, @@ -620,7 +711,11 @@ impl AsExecutionPlan for PhysicalPlanNode { } PhysicalPlanType::Empty(empty) => { let schema = Arc::new(convert_required!(empty.schema)?); - Ok(Arc::new(EmptyExec::new(empty.produce_one_row, schema))) + Ok(Arc::new(EmptyExec::new(schema))) + } + PhysicalPlanType::PlaceholderRow(placeholder) => { + let schema = Arc::new(convert_required!(placeholder.schema)?); + Ok(Arc::new(PlaceholderRowExec::new(schema))) } PhysicalPlanType::Sort(sort) => { let input: Arc = @@ -645,7 +740,7 @@ impl AsExecutionPlan for PhysicalPlanNode { })? .as_ref(); Ok(PhysicalSortExpr { - expr: parse_physical_expr(expr,registry, input.schema().as_ref())?, + expr: parse_physical_expr(expr, registry, input.schema().as_ref())?, options: SortOptions { descending: !sort_expr.asc, nulls_first: sort_expr.nulls_first, @@ -692,7 +787,7 @@ impl AsExecutionPlan for PhysicalPlanNode { })? .as_ref(); Ok(PhysicalSortExpr { - expr: parse_physical_expr(expr,registry, input.schema().as_ref())?, + expr: parse_physical_expr(expr, registry, input.schema().as_ref())?, options: SortOptions { descending: !sort_expr.asc, nulls_first: sort_expr.nulls_first, @@ -755,7 +850,7 @@ impl AsExecutionPlan for PhysicalPlanNode { f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, - registry, &schema + registry, &schema, )?; let column_indices = f.column_indices .iter() @@ -766,7 +861,7 @@ impl AsExecutionPlan for PhysicalPlanNode { i.side)) )?; - Ok(ColumnIndex{ + Ok(ColumnIndex { index: i.index as usize, side: side.into(), }) @@ -795,7 +890,100 @@ impl AsExecutionPlan for PhysicalPlanNode { analyze.verbose, analyze.show_statistics, input, - Arc::new(analyze.schema.as_ref().unwrap().try_into()?), + Arc::new(convert_required!(analyze.schema)?), + ))) + } + PhysicalPlanType::JsonSink(sink) => { + let input = + into_physical_plan(&sink.input, registry, runtime, extension_codec)?; + + let data_sink: JsonSink = sink + .sink + .as_ref() + .ok_or_else(|| proto_error("Missing required field in protobuf"))? + .try_into()?; + let sink_schema = convert_required!(sink.sink_schema)?; + let sort_order = sink + .sort_order + .as_ref() + .map(|collection| { + collection + .physical_sort_expr_nodes + .iter() + .map(|proto| { + parse_physical_sort_expr(proto, registry, &sink_schema) + .map(Into::into) + }) + .collect::>>() + }) + .transpose()?; + Ok(Arc::new(FileSinkExec::new( + input, + Arc::new(data_sink), + Arc::new(sink_schema), + sort_order, + ))) + } + PhysicalPlanType::CsvSink(sink) => { + let input = + into_physical_plan(&sink.input, registry, runtime, extension_codec)?; + + let data_sink: CsvSink = sink + .sink + .as_ref() + .ok_or_else(|| proto_error("Missing required field in protobuf"))? + .try_into()?; + let sink_schema = convert_required!(sink.sink_schema)?; + let sort_order = sink + .sort_order + .as_ref() + .map(|collection| { + collection + .physical_sort_expr_nodes + .iter() + .map(|proto| { + parse_physical_sort_expr(proto, registry, &sink_schema) + .map(Into::into) + }) + .collect::>>() + }) + .transpose()?; + Ok(Arc::new(FileSinkExec::new( + input, + Arc::new(data_sink), + Arc::new(sink_schema), + sort_order, + ))) + } + PhysicalPlanType::ParquetSink(sink) => { + let input = + into_physical_plan(&sink.input, registry, runtime, extension_codec)?; + + let data_sink: ParquetSink = sink + .sink + .as_ref() + .ok_or_else(|| proto_error("Missing required field in protobuf"))? + .try_into()?; + let sink_schema = convert_required!(sink.sink_schema)?; + let sort_order = sink + .sort_order + .as_ref() + .map(|collection| { + collection + .physical_sort_expr_nodes + .iter() + .map(|proto| { + parse_physical_sort_expr(proto, registry, &sink_schema) + .map(Into::into) + }) + .collect::>>() + }) + .transpose()?; + Ok(Arc::new(FileSinkExec::new( + input, + Arc::new(data_sink), + Arc::new(sink_schema), + sort_order, ))) } } @@ -812,7 +1000,7 @@ impl AsExecutionPlan for PhysicalPlanNode { let plan = plan.as_any(); if let Some(exec) = plan.downcast_ref::() { - Ok(protobuf::PhysicalPlanNode { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Explain( protobuf::ExplainExecNode { schema: Some(exec.schema().as_ref().try_into()?), @@ -824,8 +1012,10 @@ impl AsExecutionPlan for PhysicalPlanNode { verbose: exec.verbose(), }, )), - }) - } else if let Some(exec) = plan.downcast_ref::() { + }); + } + + if let Some(exec) = plan.downcast_ref::() { let input = protobuf::PhysicalPlanNode::try_from_physical_plan( exec.input().to_owned(), extension_codec, @@ -836,7 +1026,7 @@ impl AsExecutionPlan for PhysicalPlanNode { .map(|expr| expr.0.clone().try_into()) .collect::>>()?; let expr_name = exec.expr().iter().map(|expr| expr.1.clone()).collect(); - Ok(protobuf::PhysicalPlanNode { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Projection(Box::new( protobuf::ProjectionExecNode { input: Some(Box::new(input)), @@ -844,13 +1034,15 @@ impl AsExecutionPlan for PhysicalPlanNode { expr_name, }, ))), - }) - } else if let Some(exec) = plan.downcast_ref::() { + }); + } + + if let Some(exec) = plan.downcast_ref::() { let input = protobuf::PhysicalPlanNode::try_from_physical_plan( exec.input().to_owned(), extension_codec, )?; - Ok(protobuf::PhysicalPlanNode { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Analyze(Box::new( protobuf::AnalyzeExecNode { verbose: exec.verbose(), @@ -859,27 +1051,32 @@ impl AsExecutionPlan for PhysicalPlanNode { schema: Some(exec.schema().as_ref().try_into()?), }, ))), - }) - } else if let Some(exec) = plan.downcast_ref::() { + }); + } + + if let Some(exec) = plan.downcast_ref::() { let input = protobuf::PhysicalPlanNode::try_from_physical_plan( exec.input().to_owned(), extension_codec, )?; - Ok(protobuf::PhysicalPlanNode { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Filter(Box::new( protobuf::FilterExecNode { input: Some(Box::new(input)), expr: Some(exec.predicate().clone().try_into()?), + default_filter_selectivity: exec.default_selectivity() as u32, }, ))), - }) - } else if let Some(limit) = plan.downcast_ref::() { + }); + } + + if let Some(limit) = plan.downcast_ref::() { let input = protobuf::PhysicalPlanNode::try_from_physical_plan( limit.input().to_owned(), extension_codec, )?; - Ok(protobuf::PhysicalPlanNode { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::GlobalLimit(Box::new( protobuf::GlobalLimitExecNode { input: Some(Box::new(input)), @@ -890,21 +1087,25 @@ impl AsExecutionPlan for PhysicalPlanNode { }, }, ))), - }) - } else if let Some(limit) = plan.downcast_ref::() { + }); + } + + if let Some(limit) = plan.downcast_ref::() { let input = protobuf::PhysicalPlanNode::try_from_physical_plan( limit.input().to_owned(), extension_codec, )?; - Ok(protobuf::PhysicalPlanNode { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::LocalLimit(Box::new( protobuf::LocalLimitExecNode { input: Some(Box::new(input)), fetch: limit.fetch() as u32, }, ))), - }) - } else if let Some(exec) = plan.downcast_ref::() { + }); + } + + if let Some(exec) = plan.downcast_ref::() { let left = protobuf::PhysicalPlanNode::try_from_physical_plan( exec.left().to_owned(), extension_codec, @@ -959,7 +1160,7 @@ impl AsExecutionPlan for PhysicalPlanNode { PartitionMode::Auto => protobuf::PartitionMode::Auto, }; - Ok(protobuf::PhysicalPlanNode { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::HashJoin(Box::new( protobuf::HashJoinExecNode { left: Some(Box::new(left)), @@ -971,8 +1172,83 @@ impl AsExecutionPlan for PhysicalPlanNode { filter, }, ))), - }) - } else if let Some(exec) = plan.downcast_ref::() { + }); + } + + if let Some(exec) = plan.downcast_ref::() { + let left = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.left().to_owned(), + extension_codec, + )?; + let right = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.right().to_owned(), + extension_codec, + )?; + let on = exec + .on() + .iter() + .map(|tuple| protobuf::JoinOn { + left: Some(protobuf::PhysicalColumn { + name: tuple.0.name().to_string(), + index: tuple.0.index() as u32, + }), + right: Some(protobuf::PhysicalColumn { + name: tuple.1.name().to_string(), + index: tuple.1.index() as u32, + }), + }) + .collect(); + let join_type: protobuf::JoinType = exec.join_type().to_owned().into(); + let filter = exec + .filter() + .as_ref() + .map(|f| { + let expression = f.expression().to_owned().try_into()?; + let column_indices = f + .column_indices() + .iter() + .map(|i| { + let side: protobuf::JoinSide = i.side.to_owned().into(); + protobuf::ColumnIndex { + index: i.index as u32, + side: side.into(), + } + }) + .collect(); + let schema = f.schema().try_into()?; + Ok(protobuf::JoinFilter { + expression: Some(expression), + column_indices, + schema: Some(schema), + }) + }) + .map_or(Ok(None), |v: Result| v.map(Some))?; + + let partition_mode = match exec.partition_mode() { + StreamJoinPartitionMode::SinglePartition => { + protobuf::StreamPartitionMode::SinglePartition + } + StreamJoinPartitionMode::Partitioned => { + protobuf::StreamPartitionMode::PartitionedExec + } + }; + + return Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::SymmetricHashJoin(Box::new( + protobuf::SymmetricHashJoinExecNode { + left: Some(Box::new(left)), + right: Some(Box::new(right)), + on, + join_type: join_type.into(), + partition_mode: partition_mode.into(), + null_equals_null: exec.null_equals_null(), + filter, + }, + ))), + }); + } + + if let Some(exec) = plan.downcast_ref::() { let left = protobuf::PhysicalPlanNode::try_from_physical_plan( exec.left().to_owned(), extension_codec, @@ -981,15 +1257,16 @@ impl AsExecutionPlan for PhysicalPlanNode { exec.right().to_owned(), extension_codec, )?; - Ok(protobuf::PhysicalPlanNode { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::CrossJoin(Box::new( protobuf::CrossJoinExecNode { left: Some(Box::new(left)), right: Some(Box::new(right)), }, ))), - }) - } else if let Some(exec) = plan.downcast_ref::() { + }); + } + if let Some(exec) = plan.downcast_ref::() { let groups: Vec = exec .group_expr() .groups() @@ -1011,12 +1288,6 @@ impl AsExecutionPlan for PhysicalPlanNode { .map(|expr| expr.to_owned().try_into()) .collect::>>()?; - let order_by = exec - .order_by_expr() - .iter() - .map(|expr| expr.to_owned().try_into()) - .collect::>>()?; - let agg = exec .aggr_expr() .iter() @@ -1062,14 +1333,13 @@ impl AsExecutionPlan for PhysicalPlanNode { .map(|expr| expr.0.to_owned().try_into()) .collect::>>()?; - Ok(protobuf::PhysicalPlanNode { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Aggregate(Box::new( protobuf::AggregateExecNode { group_expr, group_expr_name: group_names, aggr_expr: agg, filter_expr: filter, - order_by_expr: order_by, aggr_expr_name: agg_names, mode: agg_mode as i32, input: Some(Box::new(input)), @@ -1078,33 +1348,48 @@ impl AsExecutionPlan for PhysicalPlanNode { groups, }, ))), - }) - } else if let Some(empty) = plan.downcast_ref::() { + }); + } + + if let Some(empty) = plan.downcast_ref::() { let schema = empty.schema().as_ref().try_into()?; - Ok(protobuf::PhysicalPlanNode { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Empty( protobuf::EmptyExecNode { - produce_one_row: empty.produce_one_row(), schema: Some(schema), }, )), - }) - } else if let Some(coalesce_batches) = plan.downcast_ref::() - { + }); + } + + if let Some(empty) = plan.downcast_ref::() { + let schema = empty.schema().as_ref().try_into()?; + return Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::PlaceholderRow( + protobuf::PlaceholderRowExecNode { + schema: Some(schema), + }, + )), + }); + } + + if let Some(coalesce_batches) = plan.downcast_ref::() { let input = protobuf::PhysicalPlanNode::try_from_physical_plan( coalesce_batches.input().to_owned(), extension_codec, )?; - Ok(protobuf::PhysicalPlanNode { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::CoalesceBatches(Box::new( protobuf::CoalesceBatchesExecNode { input: Some(Box::new(input)), target_batch_size: coalesce_batches.target_batch_size() as u32, }, ))), - }) - } else if let Some(exec) = plan.downcast_ref::() { - Ok(protobuf::PhysicalPlanNode { + }); + } + + if let Some(exec) = plan.downcast_ref::() { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::CsvScan( protobuf::CsvScanExecNode { base_conf: Some(exec.base_config().try_into()?), @@ -1120,41 +1405,50 @@ impl AsExecutionPlan for PhysicalPlanNode { }, }, )), - }) - } else if let Some(exec) = plan.downcast_ref::() { + }); + } + + #[cfg(feature = "parquet")] + if let Some(exec) = plan.downcast_ref::() { let predicate = exec .predicate() .map(|pred| pred.clone().try_into()) .transpose()?; - Ok(protobuf::PhysicalPlanNode { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::ParquetScan( protobuf::ParquetScanExecNode { base_conf: Some(exec.base_config().try_into()?), predicate, }, )), - }) - } else if let Some(exec) = plan.downcast_ref::() { - Ok(protobuf::PhysicalPlanNode { + }); + } + + if let Some(exec) = plan.downcast_ref::() { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::AvroScan( protobuf::AvroScanExecNode { base_conf: Some(exec.base_config().try_into()?), }, )), - }) - } else if let Some(exec) = plan.downcast_ref::() { + }); + } + + if let Some(exec) = plan.downcast_ref::() { let input = protobuf::PhysicalPlanNode::try_from_physical_plan( exec.input().to_owned(), extension_codec, )?; - Ok(protobuf::PhysicalPlanNode { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Merge(Box::new( protobuf::CoalescePartitionsExecNode { input: Some(Box::new(input)), }, ))), - }) - } else if let Some(exec) = plan.downcast_ref::() { + }); + } + + if let Some(exec) = plan.downcast_ref::() { let input = protobuf::PhysicalPlanNode::try_from_physical_plan( exec.input().to_owned(), extension_codec, @@ -1178,15 +1472,17 @@ impl AsExecutionPlan for PhysicalPlanNode { } }; - Ok(protobuf::PhysicalPlanNode { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Repartition(Box::new( protobuf::RepartitionExecNode { input: Some(Box::new(input)), partition_method: Some(pb_partition_method), }, ))), - }) - } else if let Some(exec) = plan.downcast_ref::() { + }); + } + + if let Some(exec) = plan.downcast_ref::() { let input = protobuf::PhysicalPlanNode::try_from_physical_plan( exec.input().to_owned(), extension_codec, @@ -1207,7 +1503,7 @@ impl AsExecutionPlan for PhysicalPlanNode { }) }) .collect::>>()?; - Ok(protobuf::PhysicalPlanNode { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Sort(Box::new( protobuf::SortExecNode { input: Some(Box::new(input)), @@ -1219,8 +1515,10 @@ impl AsExecutionPlan for PhysicalPlanNode { preserve_partitioning: exec.preserve_partitioning(), }, ))), - }) - } else if let Some(union) = plan.downcast_ref::() { + }); + } + + if let Some(union) = plan.downcast_ref::() { let mut inputs: Vec = vec![]; for input in union.inputs() { inputs.push(protobuf::PhysicalPlanNode::try_from_physical_plan( @@ -1228,12 +1526,29 @@ impl AsExecutionPlan for PhysicalPlanNode { extension_codec, )?); } - Ok(protobuf::PhysicalPlanNode { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Union( protobuf::UnionExecNode { inputs }, )), - }) - } else if let Some(exec) = plan.downcast_ref::() { + }); + } + + if let Some(interleave) = plan.downcast_ref::() { + let mut inputs: Vec = vec![]; + for input in interleave.inputs() { + inputs.push(protobuf::PhysicalPlanNode::try_from_physical_plan( + input.to_owned(), + extension_codec, + )?); + } + return Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Interleave( + protobuf::InterleaveExecNode { inputs }, + )), + }); + } + + if let Some(exec) = plan.downcast_ref::() { let input = protobuf::PhysicalPlanNode::try_from_physical_plan( exec.input().to_owned(), extension_codec, @@ -1254,7 +1569,7 @@ impl AsExecutionPlan for PhysicalPlanNode { }) }) .collect::>>()?; - Ok(protobuf::PhysicalPlanNode { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::SortPreservingMerge( Box::new(protobuf::SortPreservingMergeExecNode { input: Some(Box::new(input)), @@ -1262,8 +1577,10 @@ impl AsExecutionPlan for PhysicalPlanNode { fetch: exec.fetch().map(|f| f as i64).unwrap_or(-1), }), )), - }) - } else if let Some(exec) = plan.downcast_ref::() { + }); + } + + if let Some(exec) = plan.downcast_ref::() { let left = protobuf::PhysicalPlanNode::try_from_physical_plan( exec.left().to_owned(), extension_codec, @@ -1299,7 +1616,7 @@ impl AsExecutionPlan for PhysicalPlanNode { }) .map_or(Ok(None), |v: Result| v.map(Some))?; - Ok(protobuf::PhysicalPlanNode { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::NestedLoopJoin(Box::new( protobuf::NestedLoopJoinExecNode { left: Some(Box::new(left)), @@ -1308,15 +1625,15 @@ impl AsExecutionPlan for PhysicalPlanNode { filter, }, ))), - }) - } else if let Some(exec) = plan.downcast_ref::() { + }); + } + + if let Some(exec) = plan.downcast_ref::() { let input = protobuf::PhysicalPlanNode::try_from_physical_plan( exec.input().to_owned(), extension_codec, )?; - let input_schema = protobuf::Schema::try_from(exec.input_schema().as_ref())?; - let window_expr = exec.window_expr() .iter() @@ -1329,25 +1646,24 @@ impl AsExecutionPlan for PhysicalPlanNode { .map(|e| e.clone().try_into()) .collect::>>()?; - Ok(protobuf::PhysicalPlanNode { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Window(Box::new( protobuf::WindowAggExecNode { input: Some(Box::new(input)), window_expr, - input_schema: Some(input_schema), partition_keys, - partition_search_mode: None, + input_order_mode: None, }, ))), - }) - } else if let Some(exec) = plan.downcast_ref::() { + }); + } + + if let Some(exec) = plan.downcast_ref::() { let input = protobuf::PhysicalPlanNode::try_from_physical_plan( exec.input().to_owned(), extension_codec, )?; - let input_schema = protobuf::Schema::try_from(exec.input_schema().as_ref())?; - let window_expr = exec.window_expr() .iter() @@ -1360,62 +1676,125 @@ impl AsExecutionPlan for PhysicalPlanNode { .map(|e| e.clone().try_into()) .collect::>>()?; - let partition_search_mode = match &exec.partition_search_mode { - PartitionSearchMode::Linear => { - window_agg_exec_node::PartitionSearchMode::Linear( - protobuf::EmptyMessage {}, - ) - } - PartitionSearchMode::PartiallySorted(columns) => { - window_agg_exec_node::PartitionSearchMode::PartiallySorted( - protobuf::PartiallySortedPartitionSearchMode { + let input_order_mode = match &exec.input_order_mode { + InputOrderMode::Linear => window_agg_exec_node::InputOrderMode::Linear( + protobuf::EmptyMessage {}, + ), + InputOrderMode::PartiallySorted(columns) => { + window_agg_exec_node::InputOrderMode::PartiallySorted( + protobuf::PartiallySortedInputOrderMode { columns: columns.iter().map(|c| *c as u64).collect(), }, ) } - PartitionSearchMode::Sorted => { - window_agg_exec_node::PartitionSearchMode::Sorted( - protobuf::EmptyMessage {}, - ) - } + InputOrderMode::Sorted => window_agg_exec_node::InputOrderMode::Sorted( + protobuf::EmptyMessage {}, + ), }; - Ok(protobuf::PhysicalPlanNode { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Window(Box::new( protobuf::WindowAggExecNode { input: Some(Box::new(input)), window_expr, - input_schema: Some(input_schema), partition_keys, - partition_search_mode: Some(partition_search_mode), + input_order_mode: Some(input_order_mode), }, ))), - }) - } else { - let mut buf: Vec = vec![]; - match extension_codec.try_encode(plan_clone.clone(), &mut buf) { - Ok(_) => { - let inputs: Vec = plan_clone - .children() - .into_iter() - .map(|i| { - protobuf::PhysicalPlanNode::try_from_physical_plan( - i, - extension_codec, - ) - }) - .collect::>()?; + }); + } - Ok(protobuf::PhysicalPlanNode { - physical_plan_type: Some(PhysicalPlanType::Extension( - protobuf::PhysicalExtensionNode { node: buf, inputs }, - )), + if let Some(exec) = plan.downcast_ref::() { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.input().to_owned(), + extension_codec, + )?; + let sort_order = match exec.sort_order() { + Some(requirements) => { + let expr = requirements + .iter() + .map(|requirement| { + let expr: PhysicalSortExpr = requirement.to_owned().into(); + let sort_expr = protobuf::PhysicalSortExprNode { + expr: Some(Box::new(expr.expr.to_owned().try_into()?)), + asc: !expr.options.descending, + nulls_first: expr.options.nulls_first, + }; + Ok(sort_expr) + }) + .collect::>>()?; + Some(PhysicalSortExprNodeCollection { + physical_sort_expr_nodes: expr, }) } - Err(e) => internal_err!( - "Unsupported plan and extension codec failed with [{e}]. Plan: {plan_clone:?}" - ), + None => None, + }; + + if let Some(sink) = exec.sink().as_any().downcast_ref::() { + return Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::JsonSink(Box::new( + protobuf::JsonSinkExecNode { + input: Some(Box::new(input)), + sink: Some(sink.try_into()?), + sink_schema: Some(exec.schema().as_ref().try_into()?), + sort_order, + }, + ))), + }); + } + + if let Some(sink) = exec.sink().as_any().downcast_ref::() { + return Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::CsvSink(Box::new( + protobuf::CsvSinkExecNode { + input: Some(Box::new(input)), + sink: Some(sink.try_into()?), + sink_schema: Some(exec.schema().as_ref().try_into()?), + sort_order, + }, + ))), + }); + } + + if let Some(sink) = exec.sink().as_any().downcast_ref::() { + return Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::ParquetSink(Box::new( + protobuf::ParquetSinkExecNode { + input: Some(Box::new(input)), + sink: Some(sink.try_into()?), + sink_schema: Some(exec.schema().as_ref().try_into()?), + sort_order, + }, + ))), + }); + } + + // If unknown DataSink then let extension handle it + } + + let mut buf: Vec = vec![]; + match extension_codec.try_encode(plan_clone.clone(), &mut buf) { + Ok(_) => { + let inputs: Vec = plan_clone + .children() + .into_iter() + .map(|i| { + protobuf::PhysicalPlanNode::try_from_physical_plan( + i, + extension_codec, + ) + }) + .collect::>()?; + + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Extension( + protobuf::PhysicalExtensionNode { node: buf, inputs }, + )), + }) } + Err(e) => internal_err!( + "Unsupported plan and extension codec failed with [{e}]. Plan: {plan_clone:?}" + ), } } } diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index cf3dbe26190a..f4e3f9e4dca7 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -22,58 +22,53 @@ use std::{ sync::Arc, }; -use datafusion::physical_plan::{ - expressions::{ - ApproxDistinct, ApproxMedian, ApproxPercentileCont, - ApproxPercentileContWithWeight, ArrayAgg, Correlation, Covariance, CovariancePop, - DistinctArrayAgg, DistinctBitXor, DistinctSum, FirstValue, Grouping, LastValue, - Median, OrderSensitiveArrayAgg, Regr, RegrType, Stddev, StddevPop, Variance, - VariancePop, - }, - windows::BuiltInWindowExpr, - ColumnStatistics, -}; -use datafusion::{ - physical_expr::window::NthValueKind, - physical_plan::{ - expressions::{ - CaseExpr, CumeDist, InListExpr, IsNotNullExpr, IsNullExpr, NegativeExpr, - NotExpr, NthValue, Ntile, Rank, RankType, RowNumber, WindowShift, - }, - Statistics, - }, -}; -use datafusion::{ - physical_expr::window::SlidingAggregateWindowExpr, - physical_plan::{ - expressions::{CastExpr, TryCastExpr}, - windows::PlainAggregateWindowExpr, - WindowExpr, - }, -}; - -use datafusion::datasource::listing::{FileRange, PartitionedFile}; -use datafusion::datasource::physical_plan::FileScanConfig; - -use datafusion::physical_plan::expressions::{Count, DistinctCount, Literal}; - -use datafusion::physical_plan::expressions::{ - Avg, BinaryExpr, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, Column, LikeExpr, Max, Min, - Sum, -}; -use datafusion::physical_plan::{AggregateExpr, PhysicalExpr}; - -use crate::protobuf::{self, physical_window_expr_node}; +use crate::protobuf::{self, physical_window_expr_node, scalar_value::Value}; use crate::protobuf::{ physical_aggregate_expr_node, PhysicalSortExprNode, PhysicalSortExprNodeCollection, ScalarValue, }; + +#[cfg(feature = "parquet")] +use datafusion::datasource::file_format::parquet::ParquetSink; + +use crate::logical_plan::{csv_writer_options_to_proto, writer_properties_to_proto}; +use datafusion::datasource::{ + file_format::csv::CsvSink, + file_format::json::JsonSink, + listing::{FileRange, PartitionedFile}, + physical_plan::FileScanConfig, + physical_plan::FileSinkConfig, +}; use datafusion::logical_expr::BuiltinScalarFunction; use datafusion::physical_expr::expressions::{GetFieldAccessExpr, GetIndexedFieldExpr}; +use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr}; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; -use datafusion::physical_plan::joins::utils::JoinSide; +use datafusion::physical_plan::expressions::{ + ApproxDistinct, ApproxMedian, ApproxPercentileCont, ApproxPercentileContWithWeight, + ArrayAgg, Avg, BinaryExpr, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, CaseExpr, + CastExpr, Column, Correlation, Count, Covariance, CovariancePop, CumeDist, + DistinctArrayAgg, DistinctBitXor, DistinctCount, DistinctSum, FirstValue, Grouping, + InListExpr, IsNotNullExpr, IsNullExpr, LastValue, LikeExpr, Literal, Max, Median, + Min, NegativeExpr, NotExpr, NthValue, Ntile, OrderSensitiveArrayAgg, Rank, RankType, + Regr, RegrType, RowNumber, Stddev, StddevPop, Sum, TryCastExpr, Variance, + VariancePop, WindowShift, +}; use datafusion::physical_plan::udaf::AggregateFunctionExpr; -use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; +use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; +use datafusion::physical_plan::{ + AggregateExpr, ColumnStatistics, PhysicalExpr, Statistics, WindowExpr, +}; +use datafusion_common::{ + file_options::{ + arrow_writer::ArrowWriterOptions, avro_writer::AvroWriterOptions, + csv_writer::CsvWriterOptions, json_writer::JsonWriterOptions, + parquet_writer::ParquetWriterOptions, + }, + internal_err, not_impl_err, + parsers::CompressionTypeVariant, + stats::Precision, + DataFusionError, FileTypeWriterOptions, JoinSide, Result, +}; impl TryFrom> for protobuf::PhysicalExprNode { type Error = DataFusionError; @@ -93,10 +88,11 @@ impl TryFrom> for protobuf::PhysicalExprNode { .collect::>>()?; if let Some(a) = a.as_any().downcast_ref::() { + let name = a.fun().name().to_string(); return Ok(protobuf::PhysicalExprNode { expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr( protobuf::PhysicalAggregateExprNode { - aggregate_function: Some(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(a.fun().name.clone())), + aggregate_function: Some(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(name)), expr: expressions, ordering_req, distinct: false, @@ -189,7 +185,7 @@ impl TryFrom> for protobuf::PhysicalWindowExprNode { args.insert( 1, Arc::new(Literal::new( - datafusion_common::ScalarValue::Int64(Some(n as i64)), + datafusion_common::ScalarValue::Int64(Some(n)), )), ); protobuf::BuiltInWindowFunction::NthValue @@ -644,29 +640,66 @@ impl TryFrom<&[PartitionedFile]> for protobuf::FileGroup { } } -impl From<&ColumnStatistics> for protobuf::ColumnStats { - fn from(cs: &ColumnStatistics) -> protobuf::ColumnStats { - protobuf::ColumnStats { - min_value: cs.min_value.as_ref().map(|m| m.try_into().unwrap()), - max_value: cs.max_value.as_ref().map(|m| m.try_into().unwrap()), - null_count: cs.null_count.map(|n| n as u32).unwrap_or(0), - distinct_count: cs.distinct_count.map(|n| n as u32).unwrap_or(0), +impl From<&Precision> for protobuf::Precision { + fn from(s: &Precision) -> protobuf::Precision { + match s { + Precision::Exact(val) => protobuf::Precision { + precision_info: protobuf::PrecisionInfo::Exact.into(), + val: Some(ScalarValue { + value: Some(Value::Uint64Value(*val as u64)), + }), + }, + Precision::Inexact(val) => protobuf::Precision { + precision_info: protobuf::PrecisionInfo::Inexact.into(), + val: Some(ScalarValue { + value: Some(Value::Uint64Value(*val as u64)), + }), + }, + Precision::Absent => protobuf::Precision { + precision_info: protobuf::PrecisionInfo::Absent.into(), + val: Some(ScalarValue { value: None }), + }, + } + } +} + +impl From<&Precision> for protobuf::Precision { + fn from(s: &Precision) -> protobuf::Precision { + match s { + Precision::Exact(val) => protobuf::Precision { + precision_info: protobuf::PrecisionInfo::Exact.into(), + val: val.try_into().ok(), + }, + Precision::Inexact(val) => protobuf::Precision { + precision_info: protobuf::PrecisionInfo::Inexact.into(), + val: val.try_into().ok(), + }, + Precision::Absent => protobuf::Precision { + precision_info: protobuf::PrecisionInfo::Absent.into(), + val: Some(ScalarValue { value: None }), + }, } } } impl From<&Statistics> for protobuf::Statistics { fn from(s: &Statistics) -> protobuf::Statistics { - let none_value = -1_i64; - let column_stats = match &s.column_statistics { - None => vec![], - Some(column_stats) => column_stats.iter().map(|s| s.into()).collect(), - }; + let column_stats = s.column_statistics.iter().map(|s| s.into()).collect(); protobuf::Statistics { - num_rows: s.num_rows.map(|n| n as i64).unwrap_or(none_value), - total_byte_size: s.total_byte_size.map(|n| n as i64).unwrap_or(none_value), + num_rows: Some(protobuf::Precision::from(&s.num_rows)), + total_byte_size: Some(protobuf::Precision::from(&s.total_byte_size)), column_stats, - is_exact: s.is_exact, + } + } +} + +impl From<&ColumnStatistics> for protobuf::ColumnStats { + fn from(s: &ColumnStatistics) -> protobuf::ColumnStats { + protobuf::ColumnStats { + min_value: Some(protobuf::Precision::from(&s.min_value)), + max_value: Some(protobuf::Precision::from(&s.max_value)), + null_count: Some(protobuf::Precision::from(&s.null_count)), + distinct_count: Some(protobuf::Precision::from(&s.distinct_count)), } } } @@ -713,7 +746,7 @@ impl TryFrom<&FileScanConfig> for protobuf::FileScanExecConf { table_partition_cols: conf .table_partition_cols .iter() - .map(|x| x.0.clone()) + .map(|x| x.name().clone()) .collect::>(), object_store_url: conf.object_store_url.to_string(), output_ordering: output_orderings @@ -775,3 +808,126 @@ impl TryFrom for protobuf::PhysicalSortExprNode { }) } } + +impl TryFrom<&JsonSink> for protobuf::JsonSink { + type Error = DataFusionError; + + fn try_from(value: &JsonSink) -> Result { + Ok(Self { + config: Some(value.config().try_into()?), + }) + } +} + +impl TryFrom<&CsvSink> for protobuf::CsvSink { + type Error = DataFusionError; + + fn try_from(value: &CsvSink) -> Result { + Ok(Self { + config: Some(value.config().try_into()?), + }) + } +} + +#[cfg(feature = "parquet")] +impl TryFrom<&ParquetSink> for protobuf::ParquetSink { + type Error = DataFusionError; + + fn try_from(value: &ParquetSink) -> Result { + Ok(Self { + config: Some(value.config().try_into()?), + }) + } +} + +impl TryFrom<&FileSinkConfig> for protobuf::FileSinkConfig { + type Error = DataFusionError; + + fn try_from(conf: &FileSinkConfig) -> Result { + let file_groups = conf + .file_groups + .iter() + .map(TryInto::try_into) + .collect::>>()?; + let table_paths = conf + .table_paths + .iter() + .map(ToString::to_string) + .collect::>(); + let table_partition_cols = conf + .table_partition_cols + .iter() + .map(|(name, data_type)| { + Ok(protobuf::PartitionColumn { + name: name.to_owned(), + arrow_type: Some(data_type.try_into()?), + }) + }) + .collect::>>()?; + let file_type_writer_options = &conf.file_type_writer_options; + Ok(Self { + object_store_url: conf.object_store_url.to_string(), + file_groups, + table_paths, + output_schema: Some(conf.output_schema.as_ref().try_into()?), + table_partition_cols, + single_file_output: conf.single_file_output, + overwrite: conf.overwrite, + file_type_writer_options: Some(file_type_writer_options.try_into()?), + }) + } +} + +impl From<&CompressionTypeVariant> for protobuf::CompressionTypeVariant { + fn from(value: &CompressionTypeVariant) -> Self { + match value { + CompressionTypeVariant::GZIP => Self::Gzip, + CompressionTypeVariant::BZIP2 => Self::Bzip2, + CompressionTypeVariant::XZ => Self::Xz, + CompressionTypeVariant::ZSTD => Self::Zstd, + CompressionTypeVariant::UNCOMPRESSED => Self::Uncompressed, + } + } +} + +impl TryFrom<&FileTypeWriterOptions> for protobuf::FileTypeWriterOptions { + type Error = DataFusionError; + + fn try_from(opts: &FileTypeWriterOptions) -> Result { + let file_type = match opts { + #[cfg(feature = "parquet")] + FileTypeWriterOptions::Parquet(ParquetWriterOptions { writer_options }) => { + protobuf::file_type_writer_options::FileType::ParquetOptions( + protobuf::ParquetWriterOptions { + writer_properties: Some(writer_properties_to_proto( + writer_options, + )), + }, + ) + } + FileTypeWriterOptions::CSV(CsvWriterOptions { + writer_options, + compression, + }) => protobuf::file_type_writer_options::FileType::CsvOptions( + csv_writer_options_to_proto(writer_options, compression), + ), + FileTypeWriterOptions::JSON(JsonWriterOptions { compression }) => { + let compression: protobuf::CompressionTypeVariant = compression.into(); + protobuf::file_type_writer_options::FileType::JsonOptions( + protobuf::JsonWriterOptions { + compression: compression.into(), + }, + ) + } + FileTypeWriterOptions::Avro(AvroWriterOptions {}) => { + return not_impl_err!("Avro file sink protobuf serialization") + } + FileTypeWriterOptions::Arrow(ArrowWriterOptions {}) => { + return not_impl_err!("Arrow file sink protobuf serialization") + } + }; + Ok(Self { + file_type: Some(file_type), + }) + } +} diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index cd294b0e535f..03daf535f201 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -15,37 +15,47 @@ // specific language governing permissions and limitations // under the License. +use std::any::Any; use std::collections::HashMap; use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; -use arrow::array::ArrayRef; +use arrow::array::{ArrayRef, FixedSizeListArray}; +use arrow::csv::WriterBuilder; use arrow::datatypes::{ - DataType, Field, Fields, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, - Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, + DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType, + IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, }; + use prost::Message; use datafusion::datasource::provider::TableProviderFactory; use datafusion::datasource::TableProvider; use datafusion::execution::context::SessionState; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use datafusion::parquet::file::properties::{WriterProperties, WriterVersion}; use datafusion::physical_plan::functions::make_scalar_function; use datafusion::prelude::{create_udf, CsvReadOptions, SessionConfig, SessionContext}; use datafusion::test_util::{TestTableFactory, TestTableProvider}; -use datafusion_common::Result; -use datafusion_common::{internal_err, not_impl_err, plan_err}; +use datafusion_common::file_options::csv_writer::CsvWriterOptions; +use datafusion_common::file_options::parquet_writer::ParquetWriterOptions; +use datafusion_common::file_options::StatementOptions; +use datafusion_common::parsers::CompressionTypeVariant; +use datafusion_common::{internal_err, not_impl_err, plan_err, FileTypeWriterOptions}; use datafusion_common::{DFField, DFSchema, DFSchemaRef, DataFusionError, ScalarValue}; +use datafusion_common::{FileType, Result}; +use datafusion_expr::dml::{CopyOptions, CopyTo}; use datafusion_expr::expr::{ self, Between, BinaryExpr, Case, Cast, GroupingSet, InList, Like, ScalarFunction, - ScalarUDF, Sort, + Sort, }; use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; use datafusion_expr::{ col, create_udaf, lit, Accumulator, AggregateFunction, BuiltinScalarFunction::{Sqrt, Substr}, Expr, LogicalPlan, Operator, PartitionEvaluator, Signature, TryCast, Volatility, - WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction, WindowUDF, + WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, + WindowUDFImpl, }; use datafusion_proto::bytes::{ logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec, @@ -168,10 +178,10 @@ async fn roundtrip_custom_tables() -> Result<()> { let cfg = RuntimeConfig::new(); let env = RuntimeEnv::new(cfg).unwrap(); let ses = SessionConfig::new(); - let mut state = SessionState::with_config_rt(ses, Arc::new(env)); + let mut state = SessionState::new_with_config_rt(ses, Arc::new(env)); // replace factories *state.table_factories_mut() = table_factories; - let ctx = SessionContext::with_state(state); + let ctx = SessionContext::new_with_state(state); let sql = "CREATE EXTERNAL TABLE t STORED AS testtable LOCATION 's3://bucket/schema/table';"; ctx.sql(sql).await.unwrap(); @@ -185,6 +195,95 @@ async fn roundtrip_custom_tables() -> Result<()> { Ok(()) } +#[tokio::test] +async fn roundtrip_custom_memory_tables() -> Result<()> { + let ctx = SessionContext::new(); + // Make sure during round-trip, constraint information is preserved + let query = "CREATE TABLE sales_global_with_pk (zip_code INT, + country VARCHAR(3), + sn INT, + ts TIMESTAMP, + currency VARCHAR(3), + amount FLOAT, + primary key(sn) + ) as VALUES + (0, 'GRC', 0, '2022-01-01 06:00:00'::timestamp, 'EUR', 30.0), + (1, 'FRA', 1, '2022-01-01 08:00:00'::timestamp, 'EUR', 50.0), + (1, 'TUR', 2, '2022-01-01 11:30:00'::timestamp, 'TRY', 75.0), + (1, 'FRA', 3, '2022-01-02 12:00:00'::timestamp, 'EUR', 200.0), + (1, 'TUR', 4, '2022-01-03 10:00:00'::timestamp, 'TRY', 100.0)"; + + let plan = ctx.sql(query).await?.into_optimized_plan()?; + + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + + Ok(()) +} + +#[tokio::test] +async fn roundtrip_custom_listing_tables() -> Result<()> { + let ctx = SessionContext::new(); + + let query = "CREATE EXTERNAL TABLE multiple_ordered_table_with_pk ( + a0 INTEGER, + a INTEGER DEFAULT 1*2 + 3, + b INTEGER DEFAULT NULL, + c INTEGER, + d INTEGER, + primary key(c) + ) + STORED AS CSV + WITH HEADER ROW + WITH ORDER (a ASC, b ASC) + WITH ORDER (c ASC) + LOCATION '../core/tests/data/window_2.csv';"; + + let plan = ctx.state().create_logical_plan(query).await?; + + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + // Use exact matching to verify everything. Make sure during round-trip, + // information like constraints, column defaults, and other aspects of the plan are preserved. + assert_eq!(plan, logical_round_trip); + + Ok(()) +} + +#[tokio::test] +async fn roundtrip_logical_plan_aggregation_with_pk() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.sql( + "CREATE EXTERNAL TABLE multiple_ordered_table_with_pk ( + a0 INTEGER, + a INTEGER, + b INTEGER, + c INTEGER, + d INTEGER, + primary key(c) + ) + STORED AS CSV + WITH HEADER ROW + WITH ORDER (a ASC, b ASC) + WITH ORDER (c ASC) + LOCATION '../core/tests/data/window_2.csv';", + ) + .await?; + + let query = "SELECT c, b, SUM(d) + FROM multiple_ordered_table_with_pk + GROUP BY c"; + let plan = ctx.sql(query).await?.into_optimized_plan()?; + + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + + Ok(()) +} + #[tokio::test] async fn roundtrip_logical_plan_aggregation() -> Result<()> { let ctx = SessionContext::new(); @@ -211,6 +310,184 @@ async fn roundtrip_logical_plan_aggregation() -> Result<()> { Ok(()) } +#[tokio::test] +async fn roundtrip_logical_plan_copy_to_sql_options() -> Result<()> { + let ctx = SessionContext::new(); + + let input = create_csv_scan(&ctx).await?; + + let mut options = HashMap::new(); + options.insert("foo".to_string(), "bar".to_string()); + + let plan = LogicalPlan::Copy(CopyTo { + input: Arc::new(input), + output_url: "test.csv".to_string(), + file_format: FileType::CSV, + single_file_output: true, + copy_options: CopyOptions::SQLOptions(StatementOptions::from(&options)), + }); + + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + + Ok(()) +} + +#[tokio::test] +async fn roundtrip_logical_plan_copy_to_writer_options() -> Result<()> { + let ctx = SessionContext::new(); + + let input = create_csv_scan(&ctx).await?; + + let writer_properties = WriterProperties::builder() + .set_bloom_filter_enabled(true) + .set_created_by("DataFusion Test".to_string()) + .set_writer_version(WriterVersion::PARQUET_2_0) + .set_write_batch_size(111) + .set_data_page_size_limit(222) + .set_data_page_row_count_limit(333) + .set_dictionary_page_size_limit(444) + .set_max_row_group_size(555) + .build(); + let plan = LogicalPlan::Copy(CopyTo { + input: Arc::new(input), + output_url: "test.parquet".to_string(), + file_format: FileType::PARQUET, + single_file_output: true, + copy_options: CopyOptions::WriterOptions(Box::new( + FileTypeWriterOptions::Parquet(ParquetWriterOptions::new(writer_properties)), + )), + }); + + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + + match logical_round_trip { + LogicalPlan::Copy(copy_to) => { + assert_eq!("test.parquet", copy_to.output_url); + assert_eq!(FileType::PARQUET, copy_to.file_format); + assert!(copy_to.single_file_output); + match ©_to.copy_options { + CopyOptions::WriterOptions(y) => match y.as_ref() { + FileTypeWriterOptions::Parquet(p) => { + let props = &p.writer_options; + assert_eq!("DataFusion Test", props.created_by()); + assert_eq!( + "PARQUET_2_0", + format!("{:?}", props.writer_version()) + ); + assert_eq!(111, props.write_batch_size()); + assert_eq!(222, props.data_page_size_limit()); + assert_eq!(333, props.data_page_row_count_limit()); + assert_eq!(444, props.dictionary_page_size_limit()); + assert_eq!(555, props.max_row_group_size()); + } + _ => panic!(), + }, + _ => panic!(), + } + } + _ => panic!(), + } + Ok(()) +} + +#[tokio::test] +async fn roundtrip_logical_plan_copy_to_csv() -> Result<()> { + let ctx = SessionContext::new(); + + let input = create_csv_scan(&ctx).await?; + + let writer_properties = WriterBuilder::new() + .with_delimiter(b'*') + .with_date_format("dd/MM/yyyy".to_string()) + .with_datetime_format("dd/MM/yyyy HH:mm:ss".to_string()) + .with_timestamp_format("HH:mm:ss.SSSSSS".to_string()) + .with_time_format("HH:mm:ss".to_string()) + .with_null("NIL".to_string()); + + let plan = LogicalPlan::Copy(CopyTo { + input: Arc::new(input), + output_url: "test.csv".to_string(), + file_format: FileType::CSV, + single_file_output: true, + copy_options: CopyOptions::WriterOptions(Box::new(FileTypeWriterOptions::CSV( + CsvWriterOptions::new( + writer_properties, + CompressionTypeVariant::UNCOMPRESSED, + ), + ))), + }); + + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + + match logical_round_trip { + LogicalPlan::Copy(copy_to) => { + assert_eq!("test.csv", copy_to.output_url); + assert_eq!(FileType::CSV, copy_to.file_format); + assert!(copy_to.single_file_output); + match ©_to.copy_options { + CopyOptions::WriterOptions(y) => match y.as_ref() { + FileTypeWriterOptions::CSV(p) => { + let props = &p.writer_options; + assert_eq!(b'*', props.delimiter()); + assert_eq!("dd/MM/yyyy", props.date_format().unwrap()); + assert_eq!( + "dd/MM/yyyy HH:mm:ss", + props.datetime_format().unwrap() + ); + assert_eq!("HH:mm:ss.SSSSSS", props.timestamp_format().unwrap()); + assert_eq!("HH:mm:ss", props.time_format().unwrap()); + assert_eq!("NIL", props.null()); + } + _ => panic!(), + }, + _ => panic!(), + } + } + _ => panic!(), + } + + Ok(()) +} +async fn create_csv_scan(ctx: &SessionContext) -> Result { + ctx.register_csv("t1", "tests/testdata/test.csv", CsvReadOptions::default()) + .await?; + + let input = ctx.table("t1").await?.into_optimized_plan()?; + Ok(input) +} + +#[tokio::test] +async fn roundtrip_logical_plan_distinct_on() -> Result<()> { + let ctx = SessionContext::new(); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Decimal128(15, 2), true), + ]); + + ctx.register_csv( + "t1", + "tests/testdata/test.csv", + CsvReadOptions::default().schema(&schema), + ) + .await?; + + let query = "SELECT DISTINCT ON (a % 2) a, b * 2 FROM t1 ORDER BY a % 2 DESC, b"; + let plan = ctx.sql(query).await?.into_optimized_plan()?; + + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + + Ok(()) +} + #[tokio::test] async fn roundtrip_single_count_distinct() -> Result<()> { let ctx = SessionContext::new(); @@ -424,59 +701,6 @@ fn scalar_values_error_serialization() { Some(vec![]), vec![Field::new("item", DataType::Int16, true)].into(), ), - // Should fail due to inconsistent types in the list - ScalarValue::new_list( - Some(vec![ - ScalarValue::Int16(None), - ScalarValue::Float32(Some(32.0)), - ]), - DataType::List(new_arc_field("item", DataType::Int16, true)), - ), - ScalarValue::new_list( - Some(vec![ - ScalarValue::Float32(None), - ScalarValue::Float32(Some(32.0)), - ]), - DataType::List(new_arc_field("item", DataType::Int16, true)), - ), - ScalarValue::new_list( - Some(vec![ - ScalarValue::Float32(None), - ScalarValue::Float32(Some(32.0)), - ]), - DataType::Int16, - ), - ScalarValue::new_list( - Some(vec![ - ScalarValue::new_list( - None, - DataType::List(new_arc_field("level2", DataType::Float32, true)), - ), - ScalarValue::new_list( - Some(vec![ - ScalarValue::Float32(Some(-213.1)), - ScalarValue::Float32(None), - ScalarValue::Float32(Some(5.5)), - ScalarValue::Float32(Some(2.0)), - ScalarValue::Float32(Some(1.0)), - ]), - DataType::List(new_arc_field("level2", DataType::Float32, true)), - ), - ScalarValue::new_list( - None, - DataType::List(new_arc_field( - "lists are typed inconsistently", - DataType::Int16, - true, - )), - ), - ]), - DataType::List(new_arc_field( - "level1", - DataType::List(new_arc_field("level2", DataType::Float32, true)), - true, - )), - ), ]; for test_case in should_fail_on_seralize.into_iter() { @@ -511,7 +735,8 @@ fn round_trip_scalar_values() { ScalarValue::UInt64(None), ScalarValue::Utf8(None), ScalarValue::LargeUtf8(None), - ScalarValue::new_list(None, DataType::Boolean), + ScalarValue::List(ScalarValue::new_list(&[], &DataType::Boolean)), + ScalarValue::LargeList(ScalarValue::new_large_list(&[], &DataType::Boolean)), ScalarValue::Date32(None), ScalarValue::Boolean(Some(true)), ScalarValue::Boolean(Some(false)), @@ -602,35 +827,72 @@ fn round_trip_scalar_values() { i64::MAX, ))), ScalarValue::IntervalMonthDayNano(None), - ScalarValue::new_list( - Some(vec![ + ScalarValue::List(ScalarValue::new_list( + &[ ScalarValue::Float32(Some(-213.1)), ScalarValue::Float32(None), ScalarValue::Float32(Some(5.5)), ScalarValue::Float32(Some(2.0)), ScalarValue::Float32(Some(1.0)), - ]), - DataType::Float32, - ), - ScalarValue::new_list( - Some(vec![ - ScalarValue::new_list(None, DataType::Float32), - ScalarValue::new_list( - Some(vec![ + ], + &DataType::Float32, + )), + ScalarValue::LargeList(ScalarValue::new_large_list( + &[ + ScalarValue::Float32(Some(-213.1)), + ScalarValue::Float32(None), + ScalarValue::Float32(Some(5.5)), + ScalarValue::Float32(Some(2.0)), + ScalarValue::Float32(Some(1.0)), + ], + &DataType::Float32, + )), + ScalarValue::List(ScalarValue::new_list( + &[ + ScalarValue::List(ScalarValue::new_list(&[], &DataType::Float32)), + ScalarValue::List(ScalarValue::new_list( + &[ ScalarValue::Float32(Some(-213.1)), ScalarValue::Float32(None), ScalarValue::Float32(Some(5.5)), ScalarValue::Float32(Some(2.0)), ScalarValue::Float32(Some(1.0)), - ]), - DataType::Float32, - ), - ]), - DataType::List(new_arc_field("item", DataType::Float32, true)), - ), + ], + &DataType::Float32, + )), + ], + &DataType::List(new_arc_field("item", DataType::Float32, true)), + )), + ScalarValue::LargeList(ScalarValue::new_large_list( + &[ + ScalarValue::LargeList(ScalarValue::new_large_list( + &[], + &DataType::Float32, + )), + ScalarValue::LargeList(ScalarValue::new_large_list( + &[ + ScalarValue::Float32(Some(-213.1)), + ScalarValue::Float32(None), + ScalarValue::Float32(Some(5.5)), + ScalarValue::Float32(Some(2.0)), + ScalarValue::Float32(Some(1.0)), + ], + &DataType::Float32, + )), + ], + &DataType::LargeList(new_arc_field("item", DataType::Float32, true)), + )), + ScalarValue::FixedSizeList(Arc::new(FixedSizeListArray::from_iter_primitive::< + Int32Type, + _, + _, + >( + vec![Some(vec![Some(1), Some(2), Some(3)])], + 3, + ))), ScalarValue::Dictionary( Box::new(DataType::Int32), - Box::new(ScalarValue::Utf8(Some("foo".into()))), + Box::new(ScalarValue::from("foo")), ), ScalarValue::Dictionary( Box::new(DataType::Int32), @@ -871,6 +1133,45 @@ fn round_trip_datatype() { } } +#[test] +fn roundtrip_dict_id() -> Result<()> { + let dict_id = 42; + let field = Field::new( + "keys", + DataType::List(Arc::new(Field::new_dict( + "item", + DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)), + true, + dict_id, + false, + ))), + false, + ); + let schema = Arc::new(Schema::new(vec![field])); + + // encode + let mut buf: Vec = vec![]; + let schema_proto: datafusion_proto::generated::datafusion::Schema = + schema.try_into().unwrap(); + schema_proto.encode(&mut buf).unwrap(); + + // decode + let schema_proto = + datafusion_proto::generated::datafusion::Schema::decode(buf.as_slice()).unwrap(); + let decoded: Schema = (&schema_proto).try_into()?; + + // assert + let keys = decoded.fields().iter().last().unwrap(); + match keys.data_type() { + DataType::List(field) => { + assert_eq!(field.dict_id(), Some(dict_id), "dict_id should be retained"); + } + _ => panic!("Invalid type"), + } + + Ok(()) +} + #[test] fn roundtrip_null_scalar_values() { let test_types = vec![ @@ -890,7 +1191,6 @@ fn roundtrip_null_scalar_values() { ScalarValue::Date32(None), ScalarValue::TimestampMicrosecond(None, None), ScalarValue::TimestampNanosecond(None, None), - ScalarValue::List(None, Arc::new(Field::new("item", DataType::Boolean, false))), ]; for test_case in test_types.into_iter() { @@ -1112,7 +1412,17 @@ fn roundtrip_inlist() { #[test] fn roundtrip_wildcard() { - let test_expr = Expr::Wildcard; + let test_expr = Expr::Wildcard { qualifier: None }; + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); +} + +#[test] +fn roundtrip_qualified_wildcard() { + let test_expr = Expr::Wildcard { + qualifier: Some("foo".into()), + }; let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); @@ -1266,9 +1576,10 @@ fn roundtrip_aggregate_udf() { Arc::new(vec![DataType::Float64, DataType::UInt32]), ); - let test_expr = Expr::AggregateUDF(expr::AggregateUDF::new( + let test_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( Arc::new(dummy_agg.clone()), vec![lit(1.0_f64)], + false, Some(Box::new(lit(true))), None, )); @@ -1293,7 +1604,10 @@ fn roundtrip_scalar_udf() { scalar_fn, ); - let test_expr = Expr::ScalarUDF(ScalarUDF::new(Arc::new(udf.clone()), vec![lit("")])); + let test_expr = Expr::ScalarFunction(ScalarFunction::new_udf( + Arc::new(udf.clone()), + vec![lit("")], + )); let ctx = SessionContext::new(); ctx.register_udf(udf); @@ -1351,8 +1665,8 @@ fn roundtrip_window() { // 1. without window_frame let test_expr1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::BuiltInWindowFunction( - datafusion_expr::window_function::BuiltInWindowFunction::Rank, + WindowFunctionDefinition::BuiltInWindowFunction( + datafusion_expr::BuiltInWindowFunction::Rank, ), vec![], vec![col("col1")], @@ -1362,8 +1676,8 @@ fn roundtrip_window() { // 2. with default window_frame let test_expr2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::BuiltInWindowFunction( - datafusion_expr::window_function::BuiltInWindowFunction::Rank, + WindowFunctionDefinition::BuiltInWindowFunction( + datafusion_expr::BuiltInWindowFunction::Rank, ), vec![], vec![col("col1")], @@ -1379,8 +1693,8 @@ fn roundtrip_window() { }; let test_expr3 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::BuiltInWindowFunction( - datafusion_expr::window_function::BuiltInWindowFunction::Rank, + WindowFunctionDefinition::BuiltInWindowFunction( + datafusion_expr::BuiltInWindowFunction::Rank, ), vec![], vec![col("col1")], @@ -1396,7 +1710,7 @@ fn roundtrip_window() { }; let test_expr4 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("col1")], vec![col("col1")], vec![col("col2")], @@ -1447,7 +1761,7 @@ fn roundtrip_window() { ); let test_expr5 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateUDF(Arc::new(dummy_agg.clone())), + WindowFunctionDefinition::AggregateUDF(Arc::new(dummy_agg.clone())), vec![col("col1")], vec![col("col1")], vec![col("col2")], @@ -1473,30 +1787,56 @@ fn roundtrip_window() { } } - fn return_type(arg_types: &[DataType]) -> Result> { - if arg_types.len() != 1 { - return plan_err!( - "dummy_udwf expects 1 argument, got {}: {:?}", - arg_types.len(), - arg_types - ); + #[derive(Debug, Clone)] + struct SimpleWindowUDF { + signature: Signature, + } + + impl SimpleWindowUDF { + fn new() -> Self { + let signature = + Signature::exact(vec![DataType::Float64], Volatility::Immutable); + Self { signature } + } + } + + impl WindowUDFImpl for SimpleWindowUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "dummy_udwf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if arg_types.len() != 1 { + return plan_err!( + "dummy_udwf expects 1 argument, got {}: {:?}", + arg_types.len(), + arg_types + ); + } + Ok(arg_types[0].clone()) + } + + fn partition_evaluator(&self) -> Result> { + make_partition_evaluator() } - Ok(Arc::new(arg_types[0].clone())) } fn make_partition_evaluator() -> Result> { Ok(Box::new(DummyWindow {})) } - let dummy_window_udf = WindowUDF { - name: String::from("dummy_udwf"), - signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable), - return_type: Arc::new(return_type), - partition_evaluator_factory: Arc::new(make_partition_evaluator), - }; + let dummy_window_udf = WindowUDF::from(SimpleWindowUDF::new()); let test_expr6 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::WindowUDF(Arc::new(dummy_window_udf.clone())), + WindowFunctionDefinition::WindowUDF(Arc::new(dummy_window_udf.clone())), vec![col("col1")], vec![col("col1")], vec![col("col2")], diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 77e77630bcb2..27ac5d122f83 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -15,45 +15,63 @@ // specific language governing permissions and limitations // under the License. +use arrow::csv::WriterBuilder; use std::ops::Deref; use std::sync::Arc; use datafusion::arrow::array::ArrayRef; use datafusion::arrow::compute::kernels::sort::SortOptions; use datafusion::arrow::datatypes::{DataType, Field, Fields, IntervalUnit, Schema}; -use datafusion::datasource::listing::PartitionedFile; +use datafusion::datasource::file_format::csv::CsvSink; +use datafusion::datasource::file_format::json::JsonSink; +use datafusion::datasource::file_format::parquet::ParquetSink; +use datafusion::datasource::listing::{ListingTableUrl, PartitionedFile}; use datafusion::datasource::object_store::ObjectStoreUrl; -use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec}; +use datafusion::datasource::physical_plan::{ + FileScanConfig, FileSinkConfig, ParquetExec, +}; use datafusion::execution::context::ExecutionProps; -use datafusion::logical_expr::create_udf; -use datafusion::logical_expr::{BuiltinScalarFunction, Volatility}; -use datafusion::logical_expr::{JoinType, Operator}; -use datafusion::physical_expr::expressions::GetFieldAccessExpr; -use datafusion::physical_expr::expressions::{cast, in_list}; +use datafusion::logical_expr::{ + create_udf, BuiltinScalarFunction, JoinType, Operator, Volatility, +}; +use datafusion::parquet::file::properties::WriterProperties; use datafusion::physical_expr::window::SlidingAggregateWindowExpr; -use datafusion::physical_expr::ScalarFunctionExpr; -use datafusion::physical_plan::aggregates::PhysicalGroupBy; -use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode}; +use datafusion::physical_expr::{PhysicalSortRequirement, ScalarFunctionExpr}; +use datafusion::physical_plan::aggregates::{ + AggregateExec, AggregateMode, PhysicalGroupBy, +}; use datafusion::physical_plan::analyze::AnalyzeExec; use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::expressions::{ - binary, col, like, lit, Avg, BinaryExpr, Column, DistinctCount, GetIndexedFieldExpr, - NotExpr, NthValue, PhysicalSortExpr, Sum, + binary, cast, col, in_list, like, lit, Avg, BinaryExpr, Column, DistinctCount, + GetFieldAccessExpr, GetIndexedFieldExpr, NotExpr, NthValue, PhysicalSortExpr, Sum, }; use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::functions::make_scalar_function; -use datafusion::physical_plan::joins::{HashJoinExec, NestedLoopJoinExec, PartitionMode}; +use datafusion::physical_plan::insert::FileSinkExec; +use datafusion::physical_plan::joins::{ + HashJoinExec, NestedLoopJoinExec, PartitionMode, StreamJoinPartitionMode, +}; use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; +use datafusion::physical_plan::placeholder_row::PlaceholderRowExec; use datafusion::physical_plan::projection::ProjectionExec; +use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::sorts::sort::SortExec; +use datafusion::physical_plan::union::{InterleaveExec, UnionExec}; use datafusion::physical_plan::windows::{ BuiltInWindowExpr, PlainAggregateWindowExpr, WindowAggExec, }; -use datafusion::physical_plan::{functions, udaf}; -use datafusion::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr, Statistics}; +use datafusion::physical_plan::{ + functions, udaf, AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, Statistics, +}; use datafusion::prelude::SessionContext; use datafusion::scalar::ScalarValue; -use datafusion_common::Result; +use datafusion_common::file_options::csv_writer::CsvWriterOptions; +use datafusion_common::file_options::json_writer::JsonWriterOptions; +use datafusion_common::file_options::parquet_writer::ParquetWriterOptions; +use datafusion_common::parsers::CompressionTypeVariant; +use datafusion_common::stats::Precision; +use datafusion_common::{FileTypeWriterOptions, Result}; use datafusion_expr::{ Accumulator, AccumulatorFactoryFunction, AggregateUDF, ReturnTypeFunction, Signature, StateTypeFunction, WindowFrame, WindowFrameBound, @@ -61,7 +79,23 @@ use datafusion_expr::{ use datafusion_proto::physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec}; use datafusion_proto::protobuf; +/// Perform a serde roundtrip and assert that the string representation of the before and after plans +/// are identical. Note that this often isn't sufficient to guarantee that no information is +/// lost during serde because the string representation of a plan often only shows a subset of state. fn roundtrip_test(exec_plan: Arc) -> Result<()> { + let _ = roundtrip_test_and_return(exec_plan); + Ok(()) +} + +/// Perform a serde roundtrip and assert that the string representation of the before and after plans +/// are identical. Note that this often isn't sufficient to guarantee that no information is +/// lost during serde because the string representation of a plan often only shows a subset of state. +/// +/// This version of the roundtrip_test method returns the final plan after serde so that it can be inspected +/// farther in tests. +fn roundtrip_test_and_return( + exec_plan: Arc, +) -> Result> { let ctx = SessionContext::new(); let codec = DefaultPhysicalExtensionCodec {}; let proto: protobuf::PhysicalPlanNode = @@ -72,9 +106,15 @@ fn roundtrip_test(exec_plan: Arc) -> Result<()> { .try_into_physical_plan(&ctx, runtime.deref(), &codec) .expect("from proto"); assert_eq!(format!("{exec_plan:?}"), format!("{result_exec_plan:?}")); - Ok(()) + Ok(result_exec_plan) } +/// Perform a serde roundtrip and assert that the string representation of the before and after plans +/// are identical. Note that this often isn't sufficient to guarantee that no information is +/// lost during serde because the string representation of a plan often only shows a subset of state. +/// +/// This version of the roundtrip_test function accepts a SessionContext, which is required when +/// performing serde on some plans. fn roundtrip_test_with_context( exec_plan: Arc, ctx: SessionContext, @@ -93,7 +133,7 @@ fn roundtrip_test_with_context( #[test] fn roundtrip_empty() -> Result<()> { - roundtrip_test(Arc::new(EmptyExec::new(false, Arc::new(Schema::empty())))) + roundtrip_test(Arc::new(EmptyExec::new(Arc::new(Schema::empty())))) } #[test] @@ -106,7 +146,7 @@ fn roundtrip_date_time_interval() -> Result<()> { false, ), ]); - let input = Arc::new(EmptyExec::new(false, Arc::new(schema.clone()))); + let input = Arc::new(EmptyExec::new(Arc::new(schema.clone()))); let date_expr = col("some_date", &schema)?; let literal_expr = col("some_interval", &schema)?; let date_time_interval_expr = @@ -121,7 +161,7 @@ fn roundtrip_date_time_interval() -> Result<()> { #[test] fn roundtrip_local_limit() -> Result<()> { roundtrip_test(Arc::new(LocalLimitExec::new( - Arc::new(EmptyExec::new(false, Arc::new(Schema::empty()))), + Arc::new(EmptyExec::new(Arc::new(Schema::empty()))), 25, ))) } @@ -129,7 +169,7 @@ fn roundtrip_local_limit() -> Result<()> { #[test] fn roundtrip_global_limit() -> Result<()> { roundtrip_test(Arc::new(GlobalLimitExec::new( - Arc::new(EmptyExec::new(false, Arc::new(Schema::empty()))), + Arc::new(EmptyExec::new(Arc::new(Schema::empty()))), 0, Some(25), ))) @@ -138,7 +178,7 @@ fn roundtrip_global_limit() -> Result<()> { #[test] fn roundtrip_global_skip_no_limit() -> Result<()> { roundtrip_test(Arc::new(GlobalLimitExec::new( - Arc::new(EmptyExec::new(false, Arc::new(Schema::empty()))), + Arc::new(EmptyExec::new(Arc::new(Schema::empty()))), 10, None, // no limit ))) @@ -168,8 +208,8 @@ fn roundtrip_hash_join() -> Result<()> { ] { for partition_mode in &[PartitionMode::Partitioned, PartitionMode::CollectLeft] { roundtrip_test(Arc::new(HashJoinExec::try_new( - Arc::new(EmptyExec::new(false, schema_left.clone())), - Arc::new(EmptyExec::new(false, schema_right.clone())), + Arc::new(EmptyExec::new(schema_left.clone())), + Arc::new(EmptyExec::new(schema_right.clone())), on.clone(), None, join_type, @@ -200,8 +240,8 @@ fn roundtrip_nested_loop_join() -> Result<()> { JoinType::RightSemi, ] { roundtrip_test(Arc::new(NestedLoopJoinExec::try_new( - Arc::new(EmptyExec::new(false, schema_left.clone())), - Arc::new(EmptyExec::new(false, schema_right.clone())), + Arc::new(EmptyExec::new(schema_left.clone())), + Arc::new(EmptyExec::new(schema_right.clone())), None, join_type, )?))?; @@ -222,21 +262,21 @@ fn roundtrip_window() -> Result<()> { }; let builtin_window_expr = Arc::new(BuiltInWindowExpr::new( - Arc::new(NthValue::first( - "FIRST_VALUE(a) PARTITION BY [b] ORDER BY [a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", - col("a", &schema)?, - DataType::Int64, - )), - &[col("b", &schema)?], - &[PhysicalSortExpr { - expr: col("a", &schema)?, - options: SortOptions { - descending: false, - nulls_first: false, - }, - }], - Arc::new(window_frame), - )); + Arc::new(NthValue::first( + "FIRST_VALUE(a) PARTITION BY [b] ORDER BY [a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", + col("a", &schema)?, + DataType::Int64, + )), + &[col("b", &schema)?], + &[PhysicalSortExpr { + expr: col("a", &schema)?, + options: SortOptions { + descending: false, + nulls_first: false, + }, + }], + Arc::new(window_frame), + )); let plain_aggr_window_expr = Arc::new(PlainAggregateWindowExpr::new( Arc::new(Avg::new( @@ -266,7 +306,7 @@ fn roundtrip_window() -> Result<()> { Arc::new(window_frame), )); - let input = Arc::new(EmptyExec::new(false, schema.clone())); + let input = Arc::new(EmptyExec::new(schema.clone())); roundtrip_test(Arc::new(WindowAggExec::try_new( vec![ @@ -275,7 +315,6 @@ fn roundtrip_window() -> Result<()> { sliding_aggr_window_expr, ], input, - schema.clone(), vec![col("b", &schema)?], )?)) } @@ -300,8 +339,7 @@ fn rountrip_aggregate() -> Result<()> { PhysicalGroupBy::new_single(groups.clone()), aggregates.clone(), vec![None], - vec![None], - Arc::new(EmptyExec::new(false, schema.clone())), + Arc::new(EmptyExec::new(schema.clone())), schema, )?)) } @@ -368,8 +406,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> { PhysicalGroupBy::new_single(groups.clone()), aggregates.clone(), vec![None], - vec![None], - Arc::new(EmptyExec::new(false, schema.clone())), + Arc::new(EmptyExec::new(schema.clone())), schema, )?), ctx, @@ -395,7 +432,7 @@ fn roundtrip_filter_with_not_and_in_list() -> Result<()> { let and = binary(not, Operator::And, in_list, &schema)?; roundtrip_test(Arc::new(FilterExec::try_new( and, - Arc::new(EmptyExec::new(false, schema.clone())), + Arc::new(EmptyExec::new(schema.clone())), )?)) } @@ -422,7 +459,7 @@ fn roundtrip_sort() -> Result<()> { ]; roundtrip_test(Arc::new(SortExec::new( sort_exprs, - Arc::new(EmptyExec::new(false, schema)), + Arc::new(EmptyExec::new(schema)), ))) } @@ -450,11 +487,11 @@ fn roundtrip_sort_preserve_partitioning() -> Result<()> { roundtrip_test(Arc::new(SortExec::new( sort_exprs.clone(), - Arc::new(EmptyExec::new(false, schema.clone())), + Arc::new(EmptyExec::new(schema.clone())), )))?; roundtrip_test(Arc::new( - SortExec::new(sort_exprs, Arc::new(EmptyExec::new(false, schema))) + SortExec::new(sort_exprs, Arc::new(EmptyExec::new(schema))) .with_preserve_partitioning(true), )) } @@ -473,16 +510,16 @@ fn roundtrip_parquet_exec_with_pruning_predicate() -> Result<()> { 1024, )]], statistics: Statistics { - num_rows: Some(100), - total_byte_size: Some(1024), - column_statistics: None, - is_exact: false, + num_rows: Precision::Inexact(100), + total_byte_size: Precision::Inexact(1024), + column_statistics: Statistics::unknown_column(&Arc::new(Schema::new(vec![ + Field::new("col", DataType::Utf8, false), + ]))), }, projection: None, limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }; let predicate = Arc::new(BinaryExpr::new( @@ -503,7 +540,7 @@ fn roundtrip_builtin_scalar_function() -> Result<()> { let field_b = Field::new("b", DataType::Int64, false); let schema = Arc::new(Schema::new(vec![field_a, field_b])); - let input = Arc::new(EmptyExec::new(false, schema.clone())); + let input = Arc::new(EmptyExec::new(schema.clone())); let execution_props = ExecutionProps::new(); @@ -514,7 +551,7 @@ fn roundtrip_builtin_scalar_function() -> Result<()> { "acos", fun_expr, vec![col("a", &schema)?], - &DataType::Int64, + DataType::Int64, None, ); @@ -530,7 +567,7 @@ fn roundtrip_scalar_udf() -> Result<()> { let field_b = Field::new("b", DataType::Int64, false); let schema = Arc::new(Schema::new(vec![field_a, field_b])); - let input = Arc::new(EmptyExec::new(false, schema.clone())); + let input = Arc::new(EmptyExec::new(schema.clone())); let fn_impl = |args: &[ArrayRef]| Ok(Arc::new(args[0].clone()) as ArrayRef); @@ -548,7 +585,7 @@ fn roundtrip_scalar_udf() -> Result<()> { "dummy", scalar_fn, vec![col("a", &schema)?], - &DataType::Int64, + DataType::Int64, None, ); @@ -582,8 +619,7 @@ fn roundtrip_distinct_count() -> Result<()> { PhysicalGroupBy::new_single(groups), aggregates.clone(), vec![None], - vec![None], - Arc::new(EmptyExec::new(false, schema.clone())), + Arc::new(EmptyExec::new(schema.clone())), schema, )?)) } @@ -594,7 +630,7 @@ fn roundtrip_like() -> Result<()> { Field::new("a", DataType::Utf8, false), Field::new("b", DataType::Utf8, false), ]); - let input = Arc::new(EmptyExec::new(false, Arc::new(schema.clone()))); + let input = Arc::new(EmptyExec::new(Arc::new(schema.clone()))); let like_expr = like( false, false, @@ -621,13 +657,13 @@ fn roundtrip_get_indexed_field_named_struct_field() -> Result<()> { ]; let schema = Schema::new(fields); - let input = Arc::new(EmptyExec::new(false, Arc::new(schema.clone()))); + let input = Arc::new(EmptyExec::new(Arc::new(schema.clone()))); let col_arg = col("arg", &schema)?; let get_indexed_field_expr = Arc::new(GetIndexedFieldExpr::new( col_arg, GetFieldAccessExpr::NamedStructField { - name: ScalarValue::Utf8(Some(String::from("name"))), + name: ScalarValue::from("name"), }, )); @@ -648,7 +684,7 @@ fn roundtrip_get_indexed_field_list_index() -> Result<()> { ]; let schema = Schema::new(fields); - let input = Arc::new(EmptyExec::new(true, Arc::new(schema.clone()))); + let input = Arc::new(PlaceholderRowExec::new(Arc::new(schema.clone()))); let col_arg = col("arg", &schema)?; let col_key = col("key", &schema)?; @@ -675,7 +711,7 @@ fn roundtrip_get_indexed_field_list_range() -> Result<()> { ]; let schema = Schema::new(fields); - let input = Arc::new(EmptyExec::new(false, Arc::new(schema.clone()))); + let input = Arc::new(EmptyExec::new(Arc::new(schema.clone()))); let col_arg = col("arg", &schema)?; let col_start = col("start", &schema)?; @@ -697,11 +733,11 @@ fn roundtrip_get_indexed_field_list_range() -> Result<()> { } #[test] -fn rountrip_analyze() -> Result<()> { +fn roundtrip_analyze() -> Result<()> { let field_a = Field::new("plan_type", DataType::Utf8, false); let field_b = Field::new("plan", DataType::Utf8, false); let schema = Schema::new(vec![field_a, field_b]); - let input = Arc::new(EmptyExec::new(true, Arc::new(schema.clone()))); + let input = Arc::new(PlaceholderRowExec::new(Arc::new(schema.clone()))); roundtrip_test(Arc::new(AnalyzeExec::new( false, @@ -710,3 +746,207 @@ fn rountrip_analyze() -> Result<()> { Arc::new(schema), ))) } + +#[test] +fn roundtrip_json_sink() -> Result<()> { + let field_a = Field::new("plan_type", DataType::Utf8, false); + let field_b = Field::new("plan", DataType::Utf8, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + let input = Arc::new(PlaceholderRowExec::new(schema.clone())); + + let file_sink_config = FileSinkConfig { + object_store_url: ObjectStoreUrl::local_filesystem(), + file_groups: vec![PartitionedFile::new("/tmp".to_string(), 1)], + table_paths: vec![ListingTableUrl::parse("file:///")?], + output_schema: schema.clone(), + table_partition_cols: vec![("plan_type".to_string(), DataType::Utf8)], + single_file_output: true, + overwrite: true, + file_type_writer_options: FileTypeWriterOptions::JSON(JsonWriterOptions::new( + CompressionTypeVariant::UNCOMPRESSED, + )), + }; + let data_sink = Arc::new(JsonSink::new(file_sink_config)); + let sort_order = vec![PhysicalSortRequirement::new( + Arc::new(Column::new("plan_type", 0)), + Some(SortOptions { + descending: true, + nulls_first: false, + }), + )]; + + roundtrip_test(Arc::new(FileSinkExec::new( + input, + data_sink, + schema.clone(), + Some(sort_order), + ))) +} + +#[test] +fn roundtrip_csv_sink() -> Result<()> { + let field_a = Field::new("plan_type", DataType::Utf8, false); + let field_b = Field::new("plan", DataType::Utf8, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + let input = Arc::new(PlaceholderRowExec::new(schema.clone())); + + let file_sink_config = FileSinkConfig { + object_store_url: ObjectStoreUrl::local_filesystem(), + file_groups: vec![PartitionedFile::new("/tmp".to_string(), 1)], + table_paths: vec![ListingTableUrl::parse("file:///")?], + output_schema: schema.clone(), + table_partition_cols: vec![("plan_type".to_string(), DataType::Utf8)], + single_file_output: true, + overwrite: true, + file_type_writer_options: FileTypeWriterOptions::CSV(CsvWriterOptions::new( + WriterBuilder::default(), + CompressionTypeVariant::ZSTD, + )), + }; + let data_sink = Arc::new(CsvSink::new(file_sink_config)); + let sort_order = vec![PhysicalSortRequirement::new( + Arc::new(Column::new("plan_type", 0)), + Some(SortOptions { + descending: true, + nulls_first: false, + }), + )]; + + let roundtrip_plan = roundtrip_test_and_return(Arc::new(FileSinkExec::new( + input, + data_sink, + schema.clone(), + Some(sort_order), + ))) + .unwrap(); + + let roundtrip_plan = roundtrip_plan + .as_any() + .downcast_ref::() + .unwrap(); + let csv_sink = roundtrip_plan + .sink() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + CompressionTypeVariant::ZSTD, + csv_sink + .config() + .file_type_writer_options + .try_into_csv() + .unwrap() + .compression + ); + + Ok(()) +} + +#[test] +fn roundtrip_parquet_sink() -> Result<()> { + let field_a = Field::new("plan_type", DataType::Utf8, false); + let field_b = Field::new("plan", DataType::Utf8, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + let input = Arc::new(PlaceholderRowExec::new(schema.clone())); + + let file_sink_config = FileSinkConfig { + object_store_url: ObjectStoreUrl::local_filesystem(), + file_groups: vec![PartitionedFile::new("/tmp".to_string(), 1)], + table_paths: vec![ListingTableUrl::parse("file:///")?], + output_schema: schema.clone(), + table_partition_cols: vec![("plan_type".to_string(), DataType::Utf8)], + single_file_output: true, + overwrite: true, + file_type_writer_options: FileTypeWriterOptions::Parquet( + ParquetWriterOptions::new(WriterProperties::default()), + ), + }; + let data_sink = Arc::new(ParquetSink::new(file_sink_config)); + let sort_order = vec![PhysicalSortRequirement::new( + Arc::new(Column::new("plan_type", 0)), + Some(SortOptions { + descending: true, + nulls_first: false, + }), + )]; + + roundtrip_test(Arc::new(FileSinkExec::new( + input, + data_sink, + schema.clone(), + Some(sort_order), + ))) +} + +#[test] +fn roundtrip_sym_hash_join() -> Result<()> { + let field_a = Field::new("col", DataType::Int64, false); + let schema_left = Schema::new(vec![field_a.clone()]); + let schema_right = Schema::new(vec![field_a]); + let on = vec![( + Column::new("col", schema_left.index_of("col")?), + Column::new("col", schema_right.index_of("col")?), + )]; + + let schema_left = Arc::new(schema_left); + let schema_right = Arc::new(schema_right); + for join_type in &[ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::Full, + JoinType::LeftAnti, + JoinType::RightAnti, + JoinType::LeftSemi, + JoinType::RightSemi, + ] { + for partition_mode in &[ + StreamJoinPartitionMode::Partitioned, + StreamJoinPartitionMode::SinglePartition, + ] { + roundtrip_test(Arc::new( + datafusion::physical_plan::joins::SymmetricHashJoinExec::try_new( + Arc::new(EmptyExec::new(schema_left.clone())), + Arc::new(EmptyExec::new(schema_right.clone())), + on.clone(), + None, + join_type, + false, + *partition_mode, + )?, + ))?; + } + } + Ok(()) +} + +#[test] +fn roundtrip_union() -> Result<()> { + let field_a = Field::new("col", DataType::Int64, false); + let schema_left = Schema::new(vec![field_a.clone()]); + let schema_right = Schema::new(vec![field_a]); + let left = EmptyExec::new(Arc::new(schema_left)); + let right = EmptyExec::new(Arc::new(schema_right)); + let inputs: Vec> = vec![Arc::new(left), Arc::new(right)]; + let union = UnionExec::new(inputs); + roundtrip_test(Arc::new(union)) +} + +#[test] +fn roundtrip_interleave() -> Result<()> { + let field_a = Field::new("col", DataType::Int64, false); + let schema_left = Schema::new(vec![field_a.clone()]); + let schema_right = Schema::new(vec![field_a]); + let partition = Partitioning::Hash(vec![], 3); + let left = RepartitionExec::try_new( + Arc::new(EmptyExec::new(Arc::new(schema_left))), + partition.clone(), + )?; + let right = RepartitionExec::try_new( + Arc::new(EmptyExec::new(Arc::new(schema_right))), + partition.clone(), + )?; + let inputs: Vec> = vec![Arc::new(left), Arc::new(right)]; + let interleave = InterleaveExec::try_new(inputs)?; + roundtrip_test(Arc::new(interleave)) +} diff --git a/datafusion/proto/tests/cases/serialize.rs b/datafusion/proto/tests/cases/serialize.rs index f32c81527925..5b890accd81f 100644 --- a/datafusion/proto/tests/cases/serialize.rs +++ b/datafusion/proto/tests/cases/serialize.rs @@ -128,6 +128,12 @@ fn exact_roundtrip_linearized_binary_expr() { } } +#[test] +fn roundtrip_qualified_alias() { + let qual_alias = col("c1").alias_qualified(Some("my_table"), "my_column"); + assert_eq!(qual_alias, roundtrip_expr(&qual_alias)); +} + #[test] fn roundtrip_deeply_nested_binary_expr() { // We need more stack space so this doesn't overflow in dev builds diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index c2cdc4c52dbd..b91a2ac1fbd7 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -19,9 +19,9 @@ name = "datafusion-sql" description = "DataFusion SQL Query Planner" keywords = ["datafusion", "sql", "parser", "planner"] +readme = "README.md" version = { workspace = true } edition = { workspace = true } -readme = { workspace = true } homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } @@ -39,13 +39,13 @@ unicode_expressions = [] [dependencies] arrow = { workspace = true } arrow-schema = { workspace = true } -datafusion-common = { path = "../common", version = "31.0.0", default-features = false } -datafusion-expr = { path = "../expr", version = "31.0.0" } -log = "^0.4" +datafusion-common = { workspace = true } +datafusion-expr = { workspace = true } +log = { workspace = true } sqlparser = { workspace = true } [dev-dependencies] -ctor = "0.2.0" -env_logger = "0.10" +ctor = { workspace = true } +env_logger = { workspace = true } paste = "^1.0" rstest = "0.18" diff --git a/datafusion/sql/README.md b/datafusion/sql/README.md index 2ad994e4eba5..256fa774b410 100644 --- a/datafusion/sql/README.md +++ b/datafusion/sql/README.md @@ -20,7 +20,7 @@ # DataFusion SQL Query Planner This crate provides a general purpose SQL query planner that can parse SQL and translate queries into logical -plans. Although this crate is used by the [DataFusion](df) query engine, it was designed to be easily usable from any +plans. Although this crate is used by the [DataFusion][df] query engine, it was designed to be easily usable from any project that requires a SQL query planner and does not make any assumptions about how the resulting logical plan will be translated to a physical plan. For example, there is no concept of row-based versus columnar execution in the logical plan. diff --git a/datafusion/sql/examples/sql.rs b/datafusion/sql/examples/sql.rs index 8a12cc32b641..9df65b99a748 100644 --- a/datafusion/sql/examples/sql.rs +++ b/datafusion/sql/examples/sql.rs @@ -49,20 +49,20 @@ fn main() { let statement = &ast[0]; // create a logical query plan - let schema_provider = MySchemaProvider::new(); - let sql_to_rel = SqlToRel::new(&schema_provider); + let context_provider = MyContextProvider::new(); + let sql_to_rel = SqlToRel::new(&context_provider); let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap(); // show the plan println!("{plan:?}"); } -struct MySchemaProvider { +struct MyContextProvider { options: ConfigOptions, tables: HashMap>, } -impl MySchemaProvider { +impl MyContextProvider { fn new() -> Self { let mut tables = HashMap::new(); tables.insert( @@ -104,8 +104,8 @@ fn create_table_source(fields: Vec) -> Arc { ))) } -impl ContextProvider for MySchemaProvider { - fn get_table_provider(&self, name: TableReference) -> Result> { +impl ContextProvider for MyContextProvider { + fn get_table_source(&self, name: TableReference) -> Result> { match self.tables.get(name.table()) { Some(table) => Ok(table.clone()), _ => plan_err!("Table not found: {}", name.table()), diff --git a/datafusion/sql/src/expr/arrow_cast.rs b/datafusion/sql/src/expr/arrow_cast.rs index 549d46c5e277..ade8b96b5cc2 100644 --- a/datafusion/sql/src/expr/arrow_cast.rs +++ b/datafusion/sql/src/expr/arrow_cast.rs @@ -21,7 +21,9 @@ use std::{fmt::Display, iter::Peekable, str::Chars, sync::Arc}; use arrow_schema::{DataType, Field, IntervalUnit, TimeUnit}; -use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue}; +use datafusion_common::{ + plan_datafusion_err, DFSchema, DataFusionError, Result, ScalarValue, +}; use datafusion_common::plan_err; use datafusion_expr::{Expr, ExprSchemable}; @@ -98,9 +100,7 @@ pub fn parse_data_type(val: &str) -> Result { } fn make_error(val: &str, msg: &str) -> DataFusionError { - DataFusionError::Plan( - format!("Unsupported type '{val}'. Must be a supported arrow type name such as 'Int32' or 'Timestamp(Nanosecond, None)'. Error {msg}" ) - ) + plan_datafusion_err!("Unsupported type '{val}'. Must be a supported arrow type name such as 'Int32' or 'Timestamp(Nanosecond, None)'. Error {msg}" ) } fn make_error_expected(val: &str, expected: &Token, actual: &Token) -> DataFusionError { @@ -149,6 +149,7 @@ impl<'a> Parser<'a> { Token::Decimal256 => self.parse_decimal_256(), Token::Dictionary => self.parse_dictionary(), Token::List => self.parse_list(), + Token::LargeList => self.parse_large_list(), tok => Err(make_error( self.val, &format!("finding next type, got unexpected '{tok}'"), @@ -166,6 +167,16 @@ impl<'a> Parser<'a> { )))) } + /// Parses the LargeList type + fn parse_large_list(&mut self) -> Result { + self.expect_token(Token::LParen)?; + let data_type = self.parse_next_type()?; + self.expect_token(Token::RParen)?; + Ok(DataType::LargeList(Arc::new(Field::new( + "item", data_type, true, + )))) + } + /// Parses the next timeunit fn parse_time_unit(&mut self, context: &str) -> Result { match self.next_token()? { @@ -496,6 +507,7 @@ impl<'a> Tokenizer<'a> { "Date64" => Token::SimpleType(DataType::Date64), "List" => Token::List, + "LargeList" => Token::LargeList, "Second" => Token::TimeUnit(TimeUnit::Second), "Millisecond" => Token::TimeUnit(TimeUnit::Millisecond), @@ -585,6 +597,7 @@ enum Token { Integer(i64), DoubleQuotedString(String), List, + LargeList, } impl Display for Token { @@ -592,6 +605,7 @@ impl Display for Token { match self { Token::SimpleType(t) => write!(f, "{t}"), Token::List => write!(f, "List"), + Token::LargeList => write!(f, "LargeList"), Token::Timestamp => write!(f, "Timestamp"), Token::Time32 => write!(f, "Time32"), Token::Time64 => write!(f, "Time64"), diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 05f80fcfafa9..395f10b6f783 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -16,13 +16,15 @@ // under the License. use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use datafusion_common::{not_impl_err, plan_err, DFSchema, DataFusionError, Result}; -use datafusion_expr::expr::{ScalarFunction, ScalarUDF}; +use datafusion_common::{ + not_impl_err, plan_datafusion_err, plan_err, DFSchema, DataFusionError, Result, +}; +use datafusion_expr::expr::ScalarFunction; use datafusion_expr::function::suggest_valid_function; -use datafusion_expr::window_frame::regularize; +use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ - expr, window_function, AggregateFunction, BuiltinScalarFunction, Expr, WindowFrame, - WindowFunction, + expr, AggregateFunction, BuiltinScalarFunction, Expr, WindowFrame, + WindowFunctionDefinition, }; use sqlparser::ast::{ Expr as SQLExpr, Function as SQLFunction, FunctionArg, FunctionArgExpr, WindowType, @@ -34,75 +36,97 @@ use super::arrow_cast::ARROW_CAST_NAME; impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(super) fn sql_function_to_expr( &self, - mut function: SQLFunction, + function: SQLFunction, schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - let name = if function.name.0.len() > 1 { + let SQLFunction { + name, + args, + over, + distinct, + filter, + null_treatment, + special: _, // true if not called with trailing parens + order_by, + } = function; + + if let Some(null_treatment) = null_treatment { + return not_impl_err!("Null treatment in aggregate functions is not supported: {null_treatment}"); + } + + let name = if name.0.len() > 1 { // DF doesn't handle compound identifiers // (e.g. "foo.bar") for function names yet - function.name.to_string() + name.to_string() } else { - crate::utils::normalize_ident(function.name.0[0].clone()) + crate::utils::normalize_ident(name.0[0].clone()) }; // user-defined function (UDF) should have precedence in case it has the same name as a scalar built-in function - if let Some(fm) = self.schema_provider.get_function_meta(&name) { - let args = - self.function_args_to_expr(function.args, schema, planner_context)?; - return Ok(Expr::ScalarUDF(ScalarUDF::new(fm, args))); + if let Some(fm) = self.context_provider.get_function_meta(&name) { + let args = self.function_args_to_expr(args, schema, planner_context)?; + return Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fm, args))); } // next, scalar built-in if let Ok(fun) = BuiltinScalarFunction::from_str(&name) { - let args = - self.function_args_to_expr(function.args, schema, planner_context)?; + let args = self.function_args_to_expr(args, schema, planner_context)?; return Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))); }; // If function is a window function (it has an OVER clause), // it shouldn't have ordering requirement as function argument // required ordering should be defined in OVER clause. - let is_function_window = function.over.is_some(); - if !function.order_by.is_empty() && is_function_window { + let is_function_window = over.is_some(); + if !order_by.is_empty() && is_function_window { return plan_err!( "Aggregate ORDER BY is not implemented for window functions" ); } // then, window function - if let Some(WindowType::WindowSpec(window)) = function.over.take() { + if let Some(WindowType::WindowSpec(window)) = over { let partition_by = window .partition_by .into_iter() + // ignore window spec PARTITION BY for scalar values + // as they do not change and thus do not generate new partitions + .filter(|e| !matches!(e, sqlparser::ast::Expr::Value { .. },)) .map(|e| self.sql_expr_to_logical_expr(e, schema, planner_context)) .collect::>>()?; - let order_by = - self.order_by_to_sort_expr(&window.order_by, schema, planner_context)?; + let mut order_by = self.order_by_to_sort_expr( + &window.order_by, + schema, + planner_context, + // Numeric literals in window function ORDER BY are treated as constants + false, + )?; let window_frame = window .window_frame .as_ref() .map(|window_frame| { let window_frame = window_frame.clone().try_into()?; - regularize(window_frame, order_by.len()) + check_window_frame(&window_frame, order_by.len()) + .map(|_| window_frame) }) .transpose()?; + let window_frame = if let Some(window_frame) = window_frame { + regularize_window_order_by(&window_frame, &mut order_by)?; window_frame } else { WindowFrame::new(!order_by.is_empty()) }; + if let Ok(fun) = self.find_window_func(&name) { let expr = match fun { - WindowFunction::AggregateFunction(aggregate_fun) => { - let args = self.function_args_to_expr( - function.args, - schema, - planner_context, - )?; + WindowFunctionDefinition::AggregateFunction(aggregate_fun) => { + let args = + self.function_args_to_expr(args, schema, planner_context)?; Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(aggregate_fun), + WindowFunctionDefinition::AggregateFunction(aggregate_fun), args, partition_by, order_by, @@ -111,11 +135,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } _ => Expr::WindowFunction(expr::WindowFunction::new( fun, - self.function_args_to_expr( - function.args, - schema, - planner_context, - )?, + self.function_args_to_expr(args, schema, planner_context)?, partition_by, order_by, window_frame, @@ -124,36 +144,33 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { return Ok(expr); } } else { + // User defined aggregate functions (UDAF) have precedence in case it has the same name as a scalar built-in function + if let Some(fm) = self.context_provider.get_aggregate_meta(&name) { + let args = self.function_args_to_expr(args, schema, planner_context)?; + return Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( + fm, args, false, None, None, + ))); + } + // next, aggregate built-ins if let Ok(fun) = AggregateFunction::from_str(&name) { - let distinct = function.distinct; - let order_by = self.order_by_to_sort_expr( - &function.order_by, - schema, - planner_context, - )?; + let order_by = + self.order_by_to_sort_expr(&order_by, schema, planner_context, true)?; let order_by = (!order_by.is_empty()).then_some(order_by); - let args = - self.function_args_to_expr(function.args, schema, planner_context)?; + let args = self.function_args_to_expr(args, schema, planner_context)?; + let filter: Option> = filter + .map(|e| self.sql_expr_to_logical_expr(*e, schema, planner_context)) + .transpose()? + .map(Box::new); return Ok(Expr::AggregateFunction(expr::AggregateFunction::new( - fun, args, distinct, None, order_by, + fun, args, distinct, filter, order_by, ))); }; - // User defined aggregate functions (UDAF) - if let Some(fm) = self.schema_provider.get_aggregate_meta(&name) { - let args = - self.function_args_to_expr(function.args, schema, planner_context)?; - return Ok(Expr::AggregateUDF(expr::AggregateUDF::new( - fm, args, None, None, - ))); - } - // Special case arrow_cast (as its type is dependent on its argument value) if name == ARROW_CAST_NAME { - let args = - self.function_args_to_expr(function.args, schema, planner_context)?; + let args = self.function_args_to_expr(args, schema, planner_context)?; return super::arrow_cast::create_arrow_cast(args, schema); } } @@ -174,22 +191,25 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))) } - pub(super) fn find_window_func(&self, name: &str) -> Result { - window_function::find_df_window_func(name) + pub(super) fn find_window_func( + &self, + name: &str, + ) -> Result { + expr::find_df_window_func(name) // next check user defined aggregates .or_else(|| { - self.schema_provider + self.context_provider .get_aggregate_meta(name) - .map(WindowFunction::AggregateUDF) + .map(WindowFunctionDefinition::AggregateUDF) }) // next check user defined window functions .or_else(|| { - self.schema_provider + self.context_provider .get_window_meta(name) - .map(WindowFunction::WindowUDF) + .map(WindowFunctionDefinition::WindowUDF) }) .ok_or_else(|| { - DataFusionError::Plan(format!("There is no window function named {name}")) + plan_datafusion_err!("There is no window function named {name}") }) } @@ -207,11 +227,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { FunctionArg::Named { name: _, arg: FunctionArgExpr::Wildcard, - } => Ok(Expr::Wildcard), + } => Ok(Expr::Wildcard { qualifier: None }), FunctionArg::Unnamed(FunctionArgExpr::Expr(arg)) => { self.sql_expr_to_logical_expr(arg, schema, planner_context) } - FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => Ok(Expr::Wildcard), + FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => { + Ok(Expr::Wildcard { qualifier: None }) + } _ => not_impl_err!("Unsupported qualified wildcard argument: {sql:?}"), } } diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index 5e03f14e5337..9f53ff579e7c 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -17,7 +17,8 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::{ - internal_err, Column, DFField, DFSchema, DataFusionError, Result, TableReference, + internal_err, plan_datafusion_err, Column, DFField, DFSchema, DataFusionError, + Result, TableReference, }; use datafusion_expr::{Case, Expr}; use sqlparser::ast::{Expr as SQLExpr, Ident}; @@ -33,12 +34,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // TODO: figure out if ScalarVariables should be insensitive. let var_names = vec![id.value]; let ty = self - .schema_provider + .context_provider .get_variable_type(&var_names) .ok_or_else(|| { - DataFusionError::Plan(format!( - "variable {var_names:?} has no type information" - )) + plan_datafusion_err!("variable {var_names:?} has no type information") })?; Ok(Expr::ScalarVariable(ty, var_names)) } else { @@ -99,7 +98,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .map(|id| self.normalizer.normalize(id)) .collect(); let ty = self - .schema_provider + .context_provider .get_variable_type(&var_names) .ok_or_else(|| { DataFusionError::Execution(format!( diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index be6cee9885aa..9fded63af3fc 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -29,13 +29,14 @@ mod value; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use arrow_schema::DataType; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use arrow_schema::TimeUnit; use datafusion_common::{ internal_err, not_impl_err, plan_err, Column, DFSchema, DataFusionError, Result, ScalarValue, }; +use datafusion_expr::expr::AggregateFunctionDefinition; +use datafusion_expr::expr::InList; use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::expr::{InList, Placeholder}; use datafusion_expr::{ col, expr, lit, AggregateFunction, Between, BinaryExpr, BuiltinScalarFunction, Cast, Expr, ExprSchemable, GetFieldAccess, GetIndexedField, Like, Operator, TryCast, @@ -97,11 +98,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { StackEntry::Operator(op) => { let right = eval_stack.pop().unwrap(); let left = eval_stack.pop().unwrap(); + let expr = Expr::BinaryExpr(BinaryExpr::new( Box::new(left), op, Box::new(right), )); + eval_stack.push(expr); } } @@ -122,7 +125,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let mut expr = self.sql_expr_to_logical_expr(sql, schema, planner_context)?; expr = self.rewrite_partial_qualifier(expr, schema); self.validate_schema_satisfies_exprs(schema, &[expr.clone()])?; - let expr = infer_placeholder_types(expr, schema)?; + let expr = expr.infer_placeholder_types(schema)?; Ok(expr) } @@ -170,7 +173,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok(Expr::ScalarFunction(ScalarFunction::new( BuiltinScalarFunction::DatePart, vec![ - Expr::Literal(ScalarValue::Utf8(Some(format!("{field}")))), + Expr::Literal(ScalarValue::from(format!("{field}"))), self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, ], ))) @@ -223,16 +226,33 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { planner_context, ), - SQLExpr::Cast { expr, data_type } => Ok(Expr::Cast(Cast::new( - Box::new(self.sql_expr_to_logical_expr( - *expr, - schema, - planner_context, - )?), - self.convert_data_type(&data_type)?, - ))), + SQLExpr::Cast { + expr, data_type, .. + } => { + let dt = self.convert_data_type(&data_type)?; + let expr = + self.sql_expr_to_logical_expr(*expr, schema, planner_context)?; + + // numeric constants are treated as seconds (rather as nanoseconds) + // to align with postgres / duckdb semantics + let expr = match &dt { + DataType::Timestamp(TimeUnit::Nanosecond, tz) + if expr.get_type(schema)? == DataType::Int64 => + { + Expr::Cast(Cast::new( + Box::new(expr), + DataType::Timestamp(TimeUnit::Second, tz.clone()), + )) + } + _ => expr, + }; - SQLExpr::TryCast { expr, data_type } => Ok(Expr::TryCast(TryCast::new( + Ok(Expr::Cast(Cast::new(Box::new(expr), dt))) + } + + SQLExpr::TryCast { + expr, data_type, .. + } => Ok(Expr::TryCast(TryCast::new( Box::new(self.sql_expr_to_logical_expr( *expr, schema, @@ -413,6 +433,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { expr, trim_where, trim_what, + .. } => self.sql_trim_to_expr( *expr, trim_where, @@ -455,7 +476,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema, planner_context, ), - + SQLExpr::Overlay { + expr, + overlay_what, + overlay_from, + overlay_for, + } => self.sql_overlay_to_expr( + *expr, + *overlay_what, + *overlay_from, + overlay_for, + schema, + planner_context, + ), SQLExpr::Nested(e) => { self.sql_expr_to_logical_expr(*e, schema, planner_context) } @@ -478,10 +511,36 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { self.parse_array_agg(array_agg, schema, planner_context) } + SQLExpr::Struct { values, fields } => { + self.parse_struct(values, fields, schema, planner_context) + } + _ => not_impl_err!("Unsupported ast node in sqltorel: {sql:?}"), } } + fn parse_struct( + &self, + values: Vec, + fields: Vec, + input_schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result { + if !fields.is_empty() { + return not_impl_err!("Struct fields are not supported yet"); + } + let args = values + .into_iter() + .map(|value| { + self.sql_expr_to_logical_expr(value, input_schema, planner_context) + }) + .collect::>>()?; + Ok(Expr::ScalarFunction(ScalarFunction::new( + BuiltinScalarFunction::Struct, + args, + ))) + } + fn parse_array_agg( &self, array_agg: ArrayAgg, @@ -498,7 +557,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } = array_agg; let order_by = if let Some(order_by) = order_by { - Some(self.order_by_to_sort_expr(&order_by, input_schema, planner_context)?) + Some(self.order_by_to_sort_expr( + &order_by, + input_schema, + planner_context, + true, + )?) } else { None }; @@ -615,6 +679,32 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))) } + fn sql_overlay_to_expr( + &self, + expr: SQLExpr, + overlay_what: SQLExpr, + overlay_from: SQLExpr, + overlay_for: Option>, + schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result { + let fun = BuiltinScalarFunction::OverLay; + let arg = self.sql_expr_to_logical_expr(expr, schema, planner_context)?; + let what_arg = + self.sql_expr_to_logical_expr(overlay_what, schema, planner_context)?; + let from_arg = + self.sql_expr_to_logical_expr(overlay_from, schema, planner_context)?; + let args = match overlay_for { + Some(for_expr) => { + let for_expr = + self.sql_expr_to_logical_expr(*for_expr, schema, planner_context)?; + vec![arg, what_arg, from_arg, for_expr] + } + None => vec![arg, what_arg, from_arg], + }; + Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))) + } + fn sql_agg_with_filter_to_expr( &self, expr: SQLExpr, @@ -624,7 +714,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result { match self.sql_expr_to_logical_expr(expr, schema, planner_context)? { Expr::AggregateFunction(expr::AggregateFunction { - fun, + func_def: AggregateFunctionDefinition::BuiltIn(fun), args, distinct, order_by, @@ -656,7 +746,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::Value( Value::SingleQuotedString(s) | Value::DoubleQuotedString(s), ) => GetFieldAccess::NamedStructField { - name: ScalarValue::Utf8(Some(s)), + name: ScalarValue::from(s), }, SQLExpr::JsonAccess { left, @@ -712,39 +802,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } -// modifies expr if it is a placeholder with datatype of right -fn rewrite_placeholder(expr: &mut Expr, other: &Expr, schema: &DFSchema) -> Result<()> { - if let Expr::Placeholder(Placeholder { id: _, data_type }) = expr { - if data_type.is_none() { - let other_dt = other.get_type(schema); - match other_dt { - Err(e) => { - return Err(e.context(format!( - "Can not find type of {other} needed to infer type of {expr}" - )))?; - } - Ok(dt) => { - *data_type = Some(dt); - } - } - }; - } - Ok(()) -} - -/// Find all [`Expr::Placeholder`] tokens in a logical plan, and try -/// to infer their [`DataType`] from the context of their use. -fn infer_placeholder_types(expr: Expr, schema: &DFSchema) -> Result { - expr.transform(&|mut expr| { - // Default to assuming the arguments are the same type - if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = &mut expr { - rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?; - rewrite_placeholder(right.as_mut(), left.as_ref(), schema)?; - }; - Ok(Transformed::Yes(expr)) - }) -} - #[cfg(test)] mod tests { use super::*; @@ -762,12 +819,12 @@ mod tests { use crate::TableReference; - struct TestSchemaProvider { + struct TestContextProvider { options: ConfigOptions, tables: HashMap>, } - impl TestSchemaProvider { + impl TestContextProvider { pub fn new() -> Self { let mut tables = HashMap::new(); tables.insert( @@ -786,11 +843,8 @@ mod tests { } } - impl ContextProvider for TestSchemaProvider { - fn get_table_provider( - &self, - name: TableReference, - ) -> Result> { + impl ContextProvider for TestContextProvider { + fn get_table_source(&self, name: TableReference) -> Result> { match self.tables.get(name.table()) { Some(table) => Ok(table.clone()), _ => plan_err!("Table not found: {}", name.table()), @@ -843,8 +897,8 @@ mod tests { .unwrap(); let sql_expr = parser.parse_expr().unwrap(); - let schema_provider = TestSchemaProvider::new(); - let sql_to_rel = SqlToRel::new(&schema_provider); + let context_provider = TestContextProvider::new(); + let sql_to_rel = SqlToRel::new(&context_provider); // Should not stack overflow sql_to_rel.sql_expr_to_logical_expr( diff --git a/datafusion/sql/src/expr/order_by.rs b/datafusion/sql/src/expr/order_by.rs index b32388f1bcdf..772255bd9773 100644 --- a/datafusion/sql/src/expr/order_by.rs +++ b/datafusion/sql/src/expr/order_by.rs @@ -16,18 +16,25 @@ // under the License. use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use datafusion_common::{DFSchema, DataFusionError, Result}; +use datafusion_common::{ + plan_datafusion_err, plan_err, DFSchema, DataFusionError, Result, +}; use datafusion_expr::expr::Sort; use datafusion_expr::Expr; use sqlparser::ast::{Expr as SQLExpr, OrderByExpr, Value}; impl<'a, S: ContextProvider> SqlToRel<'a, S> { - /// convert sql [OrderByExpr] to `Vec` + /// Convert sql [OrderByExpr] to `Vec`. + /// + /// If `literal_to_column` is true, treat any numeric literals (e.g. `2`) as a 1 based index + /// into the SELECT list (e.g. `SELECT a, b FROM table ORDER BY 2`). + /// If false, interpret numeric literals as constant values. pub(crate) fn order_by_to_sort_expr( &self, exprs: &[OrderByExpr], schema: &DFSchema, planner_context: &mut PlannerContext, + literal_to_column: bool, ) -> Result> { let mut expr_vec = vec![]; for e in exprs { @@ -38,21 +45,21 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } = e; let expr = match expr { - SQLExpr::Value(Value::Number(v, _)) => { + SQLExpr::Value(Value::Number(v, _)) if literal_to_column => { let field_index = v .parse::() - .map_err(|err| DataFusionError::Plan(err.to_string()))?; + .map_err(|err| plan_datafusion_err!("{}", err))?; if field_index == 0 { - return Err(DataFusionError::Plan( - "Order by index starts at 1 for column indexes".to_string(), - )); + return plan_err!( + "Order by index starts at 1 for column indexes" + ); } else if schema.fields().len() < field_index { - return Err(DataFusionError::Plan(format!( + return plan_err!( "Order by column out of bounds, specified: {}, max: {}", field_index, schema.fields().len() - ))); + ); } let field = schema.field(field_index - 1); diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index c949904cd84c..9f88318ab21a 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -22,13 +22,14 @@ use arrow_schema::DataType; use datafusion_common::{ not_impl_err, plan_err, DFSchema, DataFusionError, Result, ScalarValue, }; +use datafusion_expr::expr::ScalarFunction; use datafusion_expr::expr::{BinaryExpr, Placeholder}; use datafusion_expr::{lit, Expr, Operator}; +use datafusion_expr::{BuiltinScalarFunction, ScalarFunctionDefinition}; use log::debug; use sqlparser::ast::{BinaryOperator, Expr as SQLExpr, Interval, Value}; use sqlparser::parser::ParserError::ParserError; use std::borrow::Cow; -use std::collections::HashSet; impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(crate) fn parse_value( @@ -107,7 +108,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } Ok(index) => index - 1, Err(_) => { - return plan_err!("Invalid placeholder, not a number: {param}"); + return if param_data_types.is_empty() { + Ok(Expr::Placeholder(Placeholder::new(param, None))) + } else { + // when PREPARE Statement, param_data_types length is always 0 + plan_err!("Invalid placeholder, not a number: {param}") + }; } }; // Check if the placeholder is in the parameter list @@ -137,9 +143,22 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema, &mut PlannerContext::new(), )?; + match value { - Expr::Literal(scalar) => { - values.push(scalar); + Expr::Literal(_) => { + values.push(value); + } + Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn(fun), + .. + }) => { + if fun == BuiltinScalarFunction::MakeArray { + values.push(value); + } else { + return not_impl_err!( + "ScalarFunctions without MakeArray are not supported: {value}" + ); + } } _ => { return not_impl_err!( @@ -149,18 +168,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - let data_types: HashSet = - values.iter().map(|e| e.data_type()).collect(); - - if data_types.is_empty() { - Ok(lit(ScalarValue::new_list(None, DataType::Utf8))) - } else if data_types.len() > 1 { - not_impl_err!("Arrays with different types are not supported: {data_types:?}") - } else { - let data_type = values[0].data_type(); - - Ok(lit(ScalarValue::new_list(Some(values), data_type))) - } + Ok(Expr::ScalarFunction(ScalarFunction::new( + BuiltinScalarFunction::MakeArray, + values, + ))) } /// Convert a SQL interval expression to a DataFusion logical plan @@ -332,6 +343,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // TODO make interval parsing better in arrow-rs / expose `IntervalType` fn has_units(val: &str) -> bool { + let val = val.to_lowercase(); val.ends_with("century") || val.ends_with("centuries") || val.ends_with("decade") diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs index 0e3e47508904..dbd72ec5eb7a 100644 --- a/datafusion/sql/src/parser.rs +++ b/datafusion/sql/src/parser.rs @@ -197,6 +197,8 @@ pub struct CreateExternalTable { pub unbounded: bool, /// Table(provider) specific options pub options: HashMap, + /// A table-level constraint + pub constraints: Vec, } impl fmt::Display for CreateExternalTable { @@ -211,13 +213,6 @@ impl fmt::Display for CreateExternalTable { } } -/// DataFusion extension DDL for `DESCRIBE TABLE` -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct DescribeTableStmt { - /// Table name - pub table_name: ObjectName, -} - /// DataFusion SQL Statement. /// /// This can either be a [`Statement`] from [`sqlparser`] from a @@ -231,8 +226,6 @@ pub enum Statement { Statement(Box), /// Extension: `CREATE EXTERNAL TABLE` CreateExternalTable(CreateExternalTable), - /// Extension: `DESCRIBE TABLE` - DescribeTableStmt(DescribeTableStmt), /// Extension: `COPY TO` CopyTo(CopyToStatement), /// EXPLAIN for extensions @@ -244,7 +237,6 @@ impl fmt::Display for Statement { match self { Statement::Statement(stmt) => write!(f, "{stmt}"), Statement::CreateExternalTable(stmt) => write!(f, "{stmt}"), - Statement::DescribeTableStmt(_) => write!(f, "DESCRIBE TABLE ..."), Statement::CopyTo(stmt) => write!(f, "{stmt}"), Statement::Explain(stmt) => write!(f, "{stmt}"), } @@ -253,8 +245,7 @@ impl fmt::Display for Statement { /// Datafusion SQL Parser based on [`sqlparser`] /// -/// Parses DataFusion's SQL dialect, often delegating to [`sqlparser`]'s -/// [`Parser`](sqlparser::parser::Parser). +/// Parses DataFusion's SQL dialect, often delegating to [`sqlparser`]'s [`Parser`]. /// /// DataFusion mostly follows existing SQL dialects via /// `sqlparser`. However, certain statements such as `COPY` and @@ -344,10 +335,6 @@ impl<'a> DFParser<'a> { self.parser.next_token(); // COPY self.parse_copy() } - Keyword::DESCRIBE => { - self.parser.next_token(); // DESCRIBE - self.parse_describe() - } Keyword::EXPLAIN => { // (TODO parse all supported statements) self.parser.next_token(); // EXPLAIN @@ -370,14 +357,6 @@ impl<'a> DFParser<'a> { } } - /// Parse a SQL `DESCRIBE` statement - pub fn parse_describe(&mut self) -> Result { - let table_name = self.parser.parse_object_name()?; - Ok(Statement::DescribeTableStmt(DescribeTableStmt { - table_name, - })) - } - /// Parse a SQL `COPY TO` statement pub fn parse_copy(&mut self) -> Result { // parse as a query @@ -630,7 +609,7 @@ impl<'a> DFParser<'a> { self.parser .parse_keywords(&[Keyword::IF, Keyword::NOT, Keyword::EXISTS]); let table_name = self.parser.parse_object_name()?; - let (columns, _) = self.parse_columns()?; + let (columns, constraints) = self.parse_columns()?; #[derive(Default)] struct Builder { @@ -749,6 +728,7 @@ impl<'a> DFParser<'a> { .unwrap_or(CompressionTypeVariant::UNCOMPRESSED), unbounded, options: builder.options.unwrap_or(HashMap::new()), + constraints, }; Ok(Statement::CreateExternalTable(create)) } @@ -900,6 +880,7 @@ mod tests { file_compression_type: UNCOMPRESSED, unbounded: false, options: HashMap::new(), + constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -918,6 +899,7 @@ mod tests { file_compression_type: UNCOMPRESSED, unbounded: false, options: HashMap::new(), + constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -937,6 +919,7 @@ mod tests { file_compression_type: UNCOMPRESSED, unbounded: false, options: HashMap::new(), + constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -956,6 +939,7 @@ mod tests { file_compression_type: UNCOMPRESSED, unbounded: false, options: HashMap::new(), + constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -975,6 +959,7 @@ mod tests { file_compression_type: UNCOMPRESSED, unbounded: false, options: HashMap::new(), + constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -997,6 +982,7 @@ mod tests { file_compression_type: UNCOMPRESSED, unbounded: false, options: HashMap::new(), + constraints: vec![], }); expect_parse_ok(sql, expected)?; } @@ -1024,6 +1010,7 @@ mod tests { )?, unbounded: false, options: HashMap::new(), + constraints: vec![], }); expect_parse_ok(sql, expected)?; } @@ -1043,6 +1030,7 @@ mod tests { file_compression_type: UNCOMPRESSED, unbounded: false, options: HashMap::new(), + constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -1061,6 +1049,7 @@ mod tests { file_compression_type: UNCOMPRESSED, unbounded: false, options: HashMap::new(), + constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -1079,6 +1068,7 @@ mod tests { file_compression_type: UNCOMPRESSED, unbounded: false, options: HashMap::new(), + constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -1098,6 +1088,7 @@ mod tests { file_compression_type: UNCOMPRESSED, unbounded: false, options: HashMap::new(), + constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -1122,6 +1113,7 @@ mod tests { file_compression_type: UNCOMPRESSED, unbounded: false, options: HashMap::from([("k1".into(), "v1".into())]), + constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -1144,6 +1136,7 @@ mod tests { ("k1".into(), "v1".into()), ("k2".into(), "v2".into()), ]), + constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -1189,6 +1182,7 @@ mod tests { file_compression_type: UNCOMPRESSED, unbounded: false, options: HashMap::new(), + constraints: vec![], }); expect_parse_ok(sql, expected)?; } @@ -1229,6 +1223,7 @@ mod tests { file_compression_type: UNCOMPRESSED, unbounded: false, options: HashMap::new(), + constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -1265,6 +1260,7 @@ mod tests { file_compression_type: UNCOMPRESSED, unbounded: false, options: HashMap::new(), + constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -1313,6 +1309,7 @@ mod tests { ("ROW_GROUP_SIZE".into(), "1024".into()), ("TRUNCATE".into(), "NO".into()), ]), + constraints: vec![], }); expect_parse_ok(sql, expected)?; diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index a2d790d438cc..a04df5589b85 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -21,11 +21,12 @@ use std::sync::Arc; use std::vec; use arrow_schema::*; -use datafusion_common::field_not_found; -use datafusion_common::internal_err; +use datafusion_common::{ + field_not_found, internal_err, plan_datafusion_err, SchemaError, +}; use datafusion_expr::WindowUDF; -use sqlparser::ast::ExactNumberInfo; use sqlparser::ast::TimezoneInfo; +use sqlparser::ast::{ArrayElemTypeDef, ExactNumberInfo}; use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption}; use sqlparser::ast::{DataType as SQLDataType, Ident, ObjectName, TableAlias}; @@ -45,8 +46,21 @@ use crate::utils::make_decimal_type; /// The ContextProvider trait allows the query planner to obtain meta-data about tables and /// functions referenced in SQL statements pub trait ContextProvider { + #[deprecated(since = "32.0.0", note = "please use `get_table_source` instead")] + fn get_table_provider(&self, name: TableReference) -> Result> { + self.get_table_source(name) + } /// Getter for a datasource - fn get_table_provider(&self, name: TableReference) -> Result>; + fn get_table_source(&self, name: TableReference) -> Result>; + /// Getter for a table function + fn get_table_function_source( + &self, + _name: &str, + _args: Vec, + ) -> Result> { + not_impl_err!("Table Functions are not supported") + } + /// Getter for a UDF description fn get_function_meta(&self, name: &str) -> Option>; /// Getter for a UDAF description @@ -186,22 +200,22 @@ impl PlannerContext { /// SQL query planner pub struct SqlToRel<'a, S: ContextProvider> { - pub(crate) schema_provider: &'a S, + pub(crate) context_provider: &'a S, pub(crate) options: ParserOptions, pub(crate) normalizer: IdentNormalizer, } impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// Create a new query planner - pub fn new(schema_provider: &'a S) -> Self { - Self::new_with_options(schema_provider, ParserOptions::default()) + pub fn new(context_provider: &'a S) -> Self { + Self::new_with_options(context_provider, ParserOptions::default()) } /// Create a new query planner - pub fn new_with_options(schema_provider: &'a S, options: ParserOptions) -> Self { + pub fn new_with_options(context_provider: &'a S, options: ParserOptions) -> Self { let normalize = options.enable_ident_normalization; SqlToRel { - schema_provider, + context_provider, options, normalizer: IdentNormalizer::new(normalize), } @@ -226,6 +240,42 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok(Schema::new(fields)) } + /// Returns a vector of (column_name, default_expr) pairs + pub(super) fn build_column_defaults( + &self, + columns: &Vec, + planner_context: &mut PlannerContext, + ) -> Result> { + let mut column_defaults = vec![]; + // Default expressions are restricted, column references are not allowed + let empty_schema = DFSchema::empty(); + let error_desc = |e: DataFusionError| match e { + DataFusionError::SchemaError(SchemaError::FieldNotFound { .. }, _) => { + plan_datafusion_err!( + "Column reference is not allowed in the DEFAULT expression : {}", + e + ) + } + _ => e, + }; + + for column in columns { + if let Some(default_sql_expr) = + column.options.iter().find_map(|o| match &o.option { + ColumnOption::Default(expr) => Some(expr), + _ => None, + }) + { + let default_expr = self + .sql_to_expr(default_sql_expr.clone(), &empty_schema, planner_context) + .map_err(error_desc)?; + column_defaults + .push((self.normalizer.normalize(column.name.clone()), default_expr)); + } + } + Ok(column_defaults) + } + /// Apply the given TableAlias to the input plan pub(crate) fn apply_table_alias( &self, @@ -293,14 +343,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(crate) fn convert_data_type(&self, sql_type: &SQLDataType) -> Result { match sql_type { - SQLDataType::Array(Some(inner_sql_type)) => { + SQLDataType::Array(ArrayElemTypeDef::AngleBracket(inner_sql_type)) + | SQLDataType::Array(ArrayElemTypeDef::SquareBracket(inner_sql_type)) => { let data_type = self.convert_simple_data_type(inner_sql_type)?; Ok(DataType::List(Arc::new(Field::new( "field", data_type, true, )))) } - SQLDataType::Array(None) => { + SQLDataType::Array(ArrayElemTypeDef::None) => { not_impl_err!("Arrays with unspecified type is not supported") } other => self.convert_simple_data_type(other), @@ -326,7 +377,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLDataType::Char(_) | SQLDataType::Varchar(_) | SQLDataType::Text - | SQLDataType::String => Ok(DataType::Utf8), + | SQLDataType::String(_) => Ok(DataType::Utf8), SQLDataType::Timestamp(None, tz_info) => { let tz = if matches!(tz_info, TimezoneInfo::Tz) || matches!(tz_info, TimezoneInfo::WithTimeZone) @@ -334,7 +385,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Timestamp With Time Zone // INPUT : [SQLDataType] TimestampTz + [RuntimeConfig] Time Zone // OUTPUT: [ArrowDataType] Timestamp - self.schema_provider.options().execution.time_zone.clone() + self.context_provider.options().execution.time_zone.clone() } else { // Timestamp Without Time zone None @@ -396,7 +447,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { | SQLDataType::Dec(_) | SQLDataType::BigNumeric(_) | SQLDataType::BigDecimal(_) - | SQLDataType::Clob(_) => not_impl_err!( + | SQLDataType::Clob(_) + | SQLDataType::Bytes(_) + | SQLDataType::Int64 + | SQLDataType::Float64 + | SQLDataType::Struct(_) + => not_impl_err!( "Unsupported SQL type {sql_type:?}" ), } diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index fc2a3fb9a57b..388377e3ee6b 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -23,7 +23,7 @@ use datafusion_common::{ not_impl_err, plan_err, sql_err, Constraints, DataFusionError, Result, ScalarValue, }; use datafusion_expr::{ - CreateMemoryTable, DdlStatement, Expr, LogicalPlan, LogicalPlanBuilder, + CreateMemoryTable, DdlStatement, Distinct, Expr, LogicalPlan, LogicalPlanBuilder, }; use sqlparser::ast::{ Expr as SQLExpr, Offset as SQLOffset, OrderByExpr, Query, SetExpr, Value, @@ -54,7 +54,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Process CTEs from top to bottom // do not allow self-references if with.recursive { - return not_impl_err!("Recursive CTEs are not supported"); + if self + .context_provider + .options() + .execution + .enable_recursive_ctes + { + return plan_err!( + "Recursive CTEs are enabled but are not yet supported" + ); + } else { + return not_impl_err!("Recursive CTEs are not supported"); + } } for cte in with.cte_tables { @@ -90,6 +101,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { input: Arc::new(plan), if_not_exists: false, or_replace: false, + column_defaults: vec![], })) } _ => plan, @@ -160,7 +172,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } let order_by_rex = - self.order_by_to_sort_expr(&order_by, plan.schema(), planner_context)?; - LogicalPlanBuilder::from(plan).sort(order_by_rex)?.build() + self.order_by_to_sort_expr(&order_by, plan.schema(), planner_context, true)?; + + if let LogicalPlan::Distinct(Distinct::On(ref distinct_on)) = plan { + // In case of `DISTINCT ON` we must capture the sort expressions since during the plan + // optimization we're effectively doing a `first_value` aggregation according to them. + let distinct_on = distinct_on.clone().with_sort_expr(order_by_rex)?; + Ok(LogicalPlan::Distinct(Distinct::On(distinct_on))) + } else { + LogicalPlanBuilder::from(plan).sort(order_by_rex)?.build() + } } } diff --git a/datafusion/sql/src/relation/join.rs b/datafusion/sql/src/relation/join.rs index 0113f337e6dc..b119672eae5f 100644 --- a/datafusion/sql/src/relation/join.rs +++ b/datafusion/sql/src/relation/join.rs @@ -132,12 +132,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // parse ON expression let expr = self.sql_to_expr(sql_expr, &join_schema, planner_context)?; LogicalPlanBuilder::from(left) - .join( - right, - join_type, - (Vec::::new(), Vec::::new()), - Some(expr), - )? + .join_on(right, join_type, Some(expr))? .build() } JoinConstraint::Using(idents) => { diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index a01a9a2fb8db..b233f47a058f 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -16,9 +16,11 @@ // under the License. use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use datafusion_common::{not_impl_err, DataFusionError, Result}; +use datafusion_common::{ + not_impl_err, plan_err, DFSchema, DataFusionError, Result, TableReference, +}; use datafusion_expr::{LogicalPlan, LogicalPlanBuilder}; -use sqlparser::ast::TableFactor; +use sqlparser::ast::{FunctionArg, FunctionArgExpr, TableFactor}; mod join; @@ -30,24 +32,58 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { planner_context: &mut PlannerContext, ) -> Result { let (plan, alias) = match relation { - TableFactor::Table { name, alias, .. } => { - // normalize name and alias - let table_ref = self.object_name_to_table_reference(name)?; - let table_name = table_ref.to_string(); - let cte = planner_context.get_cte(&table_name); - ( - match ( - cte, - self.schema_provider.get_table_provider(table_ref.clone()), - ) { - (Some(cte_plan), _) => Ok(cte_plan.clone()), - (_, Ok(provider)) => { - LogicalPlanBuilder::scan(table_ref, provider, None)?.build() - } - (None, Err(e)) => Err(e), - }?, - alias, - ) + TableFactor::Table { + name, alias, args, .. + } => { + if let Some(func_args) = args { + let tbl_func_name = name.0.first().unwrap().value.to_string(); + let args = func_args + .into_iter() + .flat_map(|arg| { + if let FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) = arg + { + self.sql_expr_to_logical_expr( + expr, + &DFSchema::empty(), + planner_context, + ) + } else { + plan_err!("Unsupported function argument type: {:?}", arg) + } + }) + .collect::>(); + let provider = self + .context_provider + .get_table_function_source(&tbl_func_name, args)?; + let plan = LogicalPlanBuilder::scan( + TableReference::Bare { + table: std::borrow::Cow::Borrowed("tmp_table"), + }, + provider, + None, + )? + .build()?; + (plan, alias) + } else { + // normalize name and alias + let table_ref = self.object_name_to_table_reference(name)?; + let table_name = table_ref.to_string(); + let cte = planner_context.get_cte(&table_name); + ( + match ( + cte, + self.context_provider.get_table_source(table_ref.clone()), + ) { + (Some(cte_plan), _) => Ok(cte_plan.clone()), + (_, Ok(provider)) => { + LogicalPlanBuilder::scan(table_ref, provider, None)? + .build() + } + (None, Err(e)) => Err(e), + }?, + alias, + ) + } } TableFactor::Derived { subquery, alias, .. diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 2062afabfc1a..a0819e4aaf8e 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -25,10 +25,7 @@ use crate::utils::{ }; use datafusion_common::Column; -use datafusion_common::{ - get_target_functional_dependencies, not_impl_err, plan_err, DFSchemaRef, - DataFusionError, Result, -}; +use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result}; use datafusion_expr::expr::Alias; use datafusion_expr::expr_rewriter::{ normalize_col, normalize_col_with_schemas_and_ambiguity_check, @@ -76,7 +73,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let empty_from = matches!(plan, LogicalPlan::EmptyRelation(_)); // process `where` clause - let plan = self.plan_selection(select.selection, plan, planner_context)?; + let base_plan = self.plan_selection(select.selection, plan, planner_context)?; // handle named windows before processing the projection expression check_conflicting_windows(&select.named_window)?; @@ -84,16 +81,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // process the SELECT expressions, with wildcards expanded. let select_exprs = self.prepare_select_exprs( - &plan, + &base_plan, select.projection, empty_from, planner_context, )?; // having and group by clause may reference aliases defined in select projection - let projected_plan = self.project(plan.clone(), select_exprs.clone())?; + let projected_plan = self.project(base_plan.clone(), select_exprs.clone())?; let mut combined_schema = (**projected_plan.schema()).clone(); - combined_schema.merge(plan.schema()); + combined_schema.merge(base_plan.schema()); // this alias map is resolved and looked up in both having exprs and group by exprs let alias_map = extract_aliases(&select_exprs); @@ -148,7 +145,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { )?; // aliases from the projection can conflict with same-named expressions in the input let mut alias_map = alias_map.clone(); - for f in plan.schema().fields() { + for f in base_plan.schema().fields() { alias_map.remove(f.name()); } let group_by_expr = @@ -158,7 +155,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .unwrap_or(group_by_expr); let group_by_expr = normalize_col(group_by_expr, &projected_plan)?; self.validate_schema_satisfies_exprs( - plan.schema(), + base_plan.schema(), &[group_by_expr.clone()], )?; Ok(group_by_expr) @@ -170,11 +167,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { select_exprs .iter() .filter(|select_expr| match select_expr { - Expr::AggregateFunction(_) | Expr::AggregateUDF(_) => false, - Expr::Alias(Alias { expr, name: _ }) => !matches!( - **expr, - Expr::AggregateFunction(_) | Expr::AggregateUDF(_) - ), + Expr::AggregateFunction(_) => false, + Expr::Alias(Alias { expr, name: _, .. }) => { + !matches!(**expr, Expr::AggregateFunction(_)) + } _ => true, }) .cloned() @@ -187,16 +183,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { || !aggr_exprs.is_empty() { self.aggregate( - plan, + &base_plan, &select_exprs, having_expr_opt.as_ref(), - group_by_exprs, - aggr_exprs, + &group_by_exprs, + &aggr_exprs, )? } else { match having_expr_opt { Some(having_expr) => return plan_err!("HAVING clause references: {having_expr} must appear in the GROUP BY clause or be used in an aggregate function"), - None => (plan, select_exprs, having_expr_opt) + None => (base_plan.clone(), select_exprs.clone(), having_expr_opt) } }; @@ -229,19 +225,31 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let plan = project(plan, select_exprs_post_aggr)?; // process distinct clause - let distinct = select - .distinct - .map(|distinct| match distinct { - Distinct::Distinct => Ok(true), - Distinct::On(_) => not_impl_err!("DISTINCT ON Exprs not supported"), - }) - .transpose()? - .unwrap_or(false); + let plan = match select.distinct { + None => Ok(plan), + Some(Distinct::Distinct) => { + LogicalPlanBuilder::from(plan).distinct()?.build() + } + Some(Distinct::On(on_expr)) => { + if !aggr_exprs.is_empty() + || !group_by_exprs.is_empty() + || !window_func_exprs.is_empty() + { + return not_impl_err!("DISTINCT ON expressions with GROUP BY, aggregation or window functions are not supported "); + } - let plan = if distinct { - LogicalPlanBuilder::from(plan).distinct()?.build() - } else { - Ok(plan) + let on_expr = on_expr + .into_iter() + .map(|e| { + self.sql_expr_to_logical_expr(e, plan.schema(), planner_context) + }) + .collect::>>()?; + + // Build the final plan + return LogicalPlanBuilder::from(base_plan) + .distinct_on(on_expr, select_exprs, None)? + .build(); + } }?; // DISTRIBUTE BY @@ -373,7 +381,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &[&[plan.schema()]], &plan.using_columns()?, )?; - let expr = Expr::Alias(Alias::new(col, self.normalizer.normalize(alias))); + let name = self.normalizer.normalize(alias); + // avoiding adding an alias if the column name is the same. + let expr = match &col { + Expr::Column(column) if column.name.eq(&name) => col, + _ => col.alias(name), + }; Ok(vec![expr]) } SelectItem::Wildcard(options) => { @@ -471,6 +484,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .clone(); *expr = Expr::Alias(Alias { expr: Box::new(new_expr), + relation: None, name: name.clone(), }); } @@ -511,20 +525,23 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// the aggregate fn aggregate( &self, - input: LogicalPlan, + input: &LogicalPlan, select_exprs: &[Expr], having_expr_opt: Option<&Expr>, - group_by_exprs: Vec, - aggr_exprs: Vec, + group_by_exprs: &[Expr], + aggr_exprs: &[Expr], ) -> Result<(LogicalPlan, Vec, Option)> { - let group_by_exprs = - get_updated_group_by_exprs(&group_by_exprs, select_exprs, input.schema())?; - // create the aggregate plan let plan = LogicalPlanBuilder::from(input.clone()) - .aggregate(group_by_exprs.clone(), aggr_exprs.clone())? + .aggregate(group_by_exprs.to_vec(), aggr_exprs.to_vec())? .build()?; + let group_by_exprs = if let LogicalPlan::Aggregate(agg) = &plan { + &agg.group_expr + } else { + unreachable!(); + }; + // in this next section of code we are re-writing the projection to refer to columns // output by the aggregate plan. For example, if the projection contains the expression // `SUM(a)` then we replace that with a reference to a column `SUM(a)` produced by @@ -533,7 +550,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // combine the original grouping and aggregate expressions into one list (note that // we do not add the "having" expression since that is not part of the projection) let mut aggr_projection_exprs = vec![]; - for expr in &group_by_exprs { + for expr in group_by_exprs { match expr { Expr::GroupingSet(GroupingSet::Rollup(exprs)) => { aggr_projection_exprs.extend_from_slice(exprs) @@ -549,25 +566,25 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { _ => aggr_projection_exprs.push(expr.clone()), } } - aggr_projection_exprs.extend_from_slice(&aggr_exprs); + aggr_projection_exprs.extend_from_slice(aggr_exprs); // now attempt to resolve columns and replace with fully-qualified columns let aggr_projection_exprs = aggr_projection_exprs .iter() - .map(|expr| resolve_columns(expr, &input)) + .map(|expr| resolve_columns(expr, input)) .collect::>>()?; // next we replace any expressions that are not a column with a column referencing // an output column from the aggregate schema let column_exprs_post_aggr = aggr_projection_exprs .iter() - .map(|expr| expr_as_column_expr(expr, &input)) + .map(|expr| expr_as_column_expr(expr, input)) .collect::>>()?; // next we re-write the projection let select_exprs_post_aggr = select_exprs .iter() - .map(|expr| rebase_expr(expr, &aggr_projection_exprs, &input)) + .map(|expr| rebase_expr(expr, &aggr_projection_exprs, input)) .collect::>>()?; // finally, we have some validation that the re-written projection can be resolved @@ -582,7 +599,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // aggregation. let having_expr_post_aggr = if let Some(having_expr) = having_expr_opt { let having_expr_post_aggr = - rebase_expr(having_expr, &aggr_projection_exprs, &input)?; + rebase_expr(having_expr, &aggr_projection_exprs, input)?; check_columns_satisfy_exprs( &column_exprs_post_aggr, @@ -642,61 +659,3 @@ fn match_window_definitions( } Ok(()) } - -/// Update group by exprs, according to functional dependencies -/// The query below -/// -/// SELECT sn, amount -/// FROM sales_global -/// GROUP BY sn -/// -/// cannot be calculated, because it has a column(`amount`) which is not -/// part of group by expression. -/// However, if we know that, `sn` is determinant of `amount`. We can -/// safely, determine value of `amount` for each distinct `sn`. For these cases -/// we rewrite the query above as -/// -/// SELECT sn, amount -/// FROM sales_global -/// GROUP BY sn, amount -/// -/// Both queries, are functionally same. \[Because, (`sn`, `amount`) and (`sn`) -/// defines the identical groups. \] -/// This function updates group by expressions such that select expressions that are -/// not in group by expression, are added to the group by expressions if they are dependent -/// of the sub-set of group by expressions. -fn get_updated_group_by_exprs( - group_by_exprs: &[Expr], - select_exprs: &[Expr], - schema: &DFSchemaRef, -) -> Result> { - let mut new_group_by_exprs = group_by_exprs.to_vec(); - let fields = schema.fields(); - let group_by_expr_names = group_by_exprs - .iter() - .map(|group_by_expr| group_by_expr.display_name()) - .collect::>>()?; - // Get targets that can be used in a select, even if they do not occur in aggregation: - if let Some(target_indices) = - get_target_functional_dependencies(schema, &group_by_expr_names) - { - // Calculate dependent fields names with determinant GROUP BY expression: - let associated_field_names = target_indices - .iter() - .map(|idx| fields[*idx].qualified_name()) - .collect::>(); - // Expand GROUP BY expressions with select expressions: If a GROUP - // BY expression is a determinant key, we can use its dependent - // columns in select statements also. - for expr in select_exprs { - let expr_name = format!("{}", expr); - if !new_group_by_exprs.contains(expr) - && associated_field_names.contains(&expr_name) - { - new_group_by_exprs.push(expr.clone()); - } - } - } - - Ok(new_group_by_exprs) -} diff --git a/datafusion/sql/src/set_expr.rs b/datafusion/sql/src/set_expr.rs index e771a5ba3de4..7300d49be0f5 100644 --- a/datafusion/sql/src/set_expr.rs +++ b/datafusion/sql/src/set_expr.rs @@ -44,6 +44,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SetQuantifier::AllByName => { return not_impl_err!("UNION ALL BY NAME not implemented") } + SetQuantifier::DistinctByName => { + return not_impl_err!("UNION DISTINCT BY NAME not implemented") + } }; let left_plan = self.set_expr_to_plan(*left, planner_context)?; diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index ab19fa716c9b..b9fb4c65dc2c 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -15,9 +15,12 @@ // specific language governing permissions and limitations // under the License. +use std::collections::{BTreeMap, HashMap, HashSet}; +use std::sync::Arc; + use crate::parser::{ - CopyToSource, CopyToStatement, CreateExternalTable, DFParser, DescribeTableStmt, - ExplainStatement, LexOrdering, Statement as DFStatement, + CopyToSource, CopyToStatement, CreateExternalTable, DFParser, ExplainStatement, + LexOrdering, Statement as DFStatement, }; use crate::planner::{ object_name_to_qualifier, ContextProvider, PlannerContext, SqlToRel, @@ -28,12 +31,12 @@ use arrow_schema::DataType; use datafusion_common::file_options::StatementOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{ - not_impl_err, unqualified_field_not_found, Column, Constraints, DFField, DFSchema, - DFSchemaRef, DataFusionError, ExprSchema, OwnedTableReference, Result, - SchemaReference, TableReference, ToDFSchema, + not_impl_err, plan_datafusion_err, plan_err, schema_err, unqualified_field_not_found, + Column, Constraints, DFField, DFSchema, DFSchemaRef, DataFusionError, + OwnedTableReference, Result, ScalarValue, SchemaError, SchemaReference, + TableReference, ToDFSchema, }; use datafusion_expr::dml::{CopyOptions, CopyTo}; -use datafusion_expr::expr::Placeholder; use datafusion_expr::expr_rewriter::normalize_col_with_schemas_and_ambiguity_check; use datafusion_expr::logical_plan::builder::project; use datafusion_expr::logical_plan::DdlStatement; @@ -49,16 +52,12 @@ use datafusion_expr::{ }; use sqlparser::ast; use sqlparser::ast::{ - Assignment, Expr as SQLExpr, Expr, Ident, ObjectName, ObjectType, Query, SchemaName, - SetExpr, ShowCreateObject, ShowStatementFilter, Statement, TableFactor, - TableWithJoins, TransactionMode, UnaryOperator, Value, + Assignment, ColumnDef, Expr as SQLExpr, Expr, Ident, ObjectName, ObjectType, Query, + SchemaName, SetExpr, ShowCreateObject, ShowStatementFilter, Statement, + TableConstraint, TableFactor, TableWithJoins, TransactionMode, UnaryOperator, Value, }; use sqlparser::parser::ParserError::ParserError; -use datafusion_common::plan_err; -use std::collections::{BTreeMap, HashMap, HashSet}; -use std::sync::Arc; - fn ident_to_string(ident: &Ident) -> String { normalize_ident(ident.to_owned()) } @@ -84,13 +83,60 @@ fn get_schema_name(schema_name: &SchemaName) -> String { } } +/// Construct `TableConstraint`(s) for the given columns by iterating over +/// `columns` and extracting individual inline constraint definitions. +fn calc_inline_constraints_from_columns(columns: &[ColumnDef]) -> Vec { + let mut constraints = vec![]; + for column in columns { + for ast::ColumnOptionDef { name, option } in &column.options { + match option { + ast::ColumnOption::Unique { is_primary } => { + constraints.push(ast::TableConstraint::Unique { + name: name.clone(), + columns: vec![column.name.clone()], + is_primary: *is_primary, + }) + } + ast::ColumnOption::ForeignKey { + foreign_table, + referred_columns, + on_delete, + on_update, + } => constraints.push(ast::TableConstraint::ForeignKey { + name: name.clone(), + columns: vec![], + foreign_table: foreign_table.clone(), + referred_columns: referred_columns.to_vec(), + on_delete: *on_delete, + on_update: *on_update, + }), + ast::ColumnOption::Check(expr) => { + constraints.push(ast::TableConstraint::Check { + name: name.clone(), + expr: Box::new(expr.clone()), + }) + } + // Other options are not constraint related. + ast::ColumnOption::Default(_) + | ast::ColumnOption::Null + | ast::ColumnOption::NotNull + | ast::ColumnOption::DialectSpecific(_) + | ast::ColumnOption::CharacterSet(_) + | ast::ColumnOption::Generated { .. } + | ast::ColumnOption::Comment(_) + | ast::ColumnOption::OnUpdate(_) => {} + } + } + } + constraints +} + impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// Generate a logical plan from an DataFusion SQL statement pub fn statement_to_plan(&self, statement: DFStatement) -> Result { match statement { DFStatement::CreateExternalTable(s) => self.external_table_to_plan(s), DFStatement::Statement(s) => self.sql_statement_to_plan(*s), - DFStatement::DescribeTableStmt(s) => self.describe_table_to_plan(s), DFStatement::CopyTo(s) => self.copy_to_plan(s), DFStatement::Explain(ExplainStatement { verbose, @@ -124,6 +170,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result { let sql = Some(statement.to_string()); match statement { + Statement::ExplainTable { + describe_alias: true, // only parse 'DESCRIBE table_name' and not 'EXPLAIN table_name' + table_name, + } => self.describe_table_to_plan(table_name), Statement::Explain { verbose, statement, @@ -154,18 +204,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { or_replace, .. } if table_properties.is_empty() && with_options.is_empty() => { - let mut constraints = constraints; - for column in &columns { - for option in &column.options { - if let ast::ColumnOption::Unique { is_primary } = option.option { - constraints.push(ast::TableConstraint::Unique { - name: None, - columns: vec![column.name.clone()], - is_primary, - }) - } - } - } + // Merge inline constraints and existing constraints + let mut all_constraints = constraints; + let inline_constraints = calc_inline_constraints_from_columns(&columns); + all_constraints.extend(inline_constraints); + // Build column default values + let column_defaults = + self.build_column_defaults(&columns, planner_context)?; match query { Some(query) => { let plan = self.query_to_plan(*query, planner_context)?; @@ -201,7 +246,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }; let constraints = Constraints::new_from_table_constraints( - &constraints, + &all_constraints, plan.schema(), )?; @@ -212,6 +257,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { input: Arc::new(plan), if_not_exists, or_replace, + column_defaults, }, ))) } @@ -224,7 +270,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }; let plan = LogicalPlan::EmptyRelation(plan); let constraints = Constraints::new_from_table_constraints( - &constraints, + &all_constraints, plan.schema(), )?; Ok(LogicalPlan::Ddl(DdlStatement::CreateMemoryTable( @@ -234,6 +280,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { input: Arc::new(plan), if_not_exists, or_replace, + column_defaults, }, ))) } @@ -392,6 +439,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { table, on, returning, + ignore, } => { if or.is_some() { plan_err!("Inserts with or clauses not supported")?; @@ -411,6 +459,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if returning.is_some() { plan_err!("Insert-returning clause not supported")?; } + if ignore { + plan_err!("Insert-ignore clause not supported")?; + } + let Some(source) = source else { + plan_err!("Inserts without a source not supported")? + }; let _ = into; // optional keyword doesn't change behavior self.insert_to_plan(table_name, columns, source, overwrite) } @@ -433,6 +487,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { selection, returning, from, + order_by, + limit, } => { if !tables.is_empty() { plan_err!("DELETE not supported")?; @@ -445,6 +501,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if returning.is_some() { plan_err!("Delete-returning clause not yet supported")?; } + + if !order_by.is_empty() { + plan_err!("Delete-order-by clause not yet supported")?; + } + + if limit.is_some() { + plan_err!("Delete-limit clause not yet supported")?; + } + let table_name = self.get_delete_target(from)?; self.delete_to_plan(table_name, selection) } @@ -452,7 +517,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Statement::StartTransaction { modes, begin: false, + modifier, } => { + if let Some(modifier) = modifier { + return not_impl_err!( + "Transaction modifier not supported: {modifier}" + ); + } let isolation_level: ast::TransactionIsolationLevel = modes .iter() .filter_map(|m: &ast::TransactionMode| match m { @@ -508,7 +579,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }); Ok(LogicalPlan::Statement(statement)) } - Statement::Rollback { chain } => { + Statement::Rollback { chain, savepoint } => { + if savepoint.is_some() { + plan_err!("Savepoints not supported")?; + } let statement = PlanStatement::TransactionEnd(TransactionEnd { conclusion: TransactionConclusion::Rollback, chain, @@ -565,14 +639,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - fn describe_table_to_plan( - &self, - statement: DescribeTableStmt, - ) -> Result { - let DescribeTableStmt { table_name } = statement; + fn describe_table_to_plan(&self, table_name: ObjectName) -> Result { let table_ref = self.object_name_to_table_reference(table_name)?; - let table_source = self.schema_provider.get_table_provider(table_ref)?; + let table_source = self.context_provider.get_table_source(table_ref)?; let schema = table_source.schema(); @@ -591,7 +661,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { CopyToSource::Relation(object_name) => { let table_ref = self.object_name_to_table_reference(object_name.clone())?; - let table_source = self.schema_provider.get_table_provider(table_ref)?; + let table_source = self.context_provider.get_table_source(table_ref)?; LogicalPlanBuilder::scan( object_name_to_string(&object_name), table_source, @@ -646,7 +716,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let mut all_results = vec![]; for expr in order_exprs { // Convert each OrderByExpr to a SortExpr: - let expr_vec = self.order_by_to_sort_expr(&expr, schema, planner_context)?; + let expr_vec = + self.order_by_to_sort_expr(&expr, schema, planner_context, true)?; // Verify that columns of all SortExprs exist in the schema: for expr in expr_vec.iter() { for column in expr.to_columns()?.iter() { @@ -681,8 +752,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { order_exprs, unbounded, options, + constraints, } = statement; + // Merge inline constraints and existing constraints + let mut all_constraints = constraints; + let inline_constraints = calc_inline_constraints_from_columns(&columns); + all_constraints.extend(inline_constraints); + if (file_type == "PARQUET" || file_type == "AVRO" || file_type == "ARROW") && file_compression_type != CompressionTypeVariant::UNCOMPRESSED { @@ -691,15 +768,23 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { )?; } + let mut planner_context = PlannerContext::new(); + + let column_defaults = self + .build_column_defaults(&columns, &mut planner_context)? + .into_iter() + .collect(); + let schema = self.build_schema(columns)?; let df_schema = schema.to_dfschema_ref()?; let ordered_exprs = - self.build_order_by(order_exprs, &df_schema, &mut PlannerContext::new())?; + self.build_order_by(order_exprs, &df_schema, &mut planner_context)?; // External tables do not support schemas at the moment, so the name is just a table name let name = OwnedTableReference::bare(name); - + let constraints = + Constraints::new_from_table_constraints(&all_constraints, &df_schema)?; Ok(LogicalPlan::Ddl(DdlStatement::CreateExternalTable( PlanCreateExternalTable { schema: df_schema, @@ -715,6 +800,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { order_exprs: ordered_exprs, unbounded, options, + constraints, + column_defaults, }, ))) } @@ -757,28 +844,34 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } fn show_variable_to_plan(&self, variable: &[Ident]) -> Result { - let variable = object_name_to_string(&ObjectName(variable.to_vec())); - if !self.has_table("information_schema", "df_settings") { return plan_err!( "SHOW [VARIABLE] is not supported unless information_schema is enabled" ); } - let variable_lower = variable.to_lowercase(); + let verbose = variable + .last() + .map(|s| ident_to_string(s) == "verbose") + .unwrap_or(false); + let mut variable_vec = variable.to_vec(); + let mut columns: String = "name, value".to_owned(); - let query = if variable_lower == "all" { + if verbose { + columns = format!("{columns}, description"); + variable_vec = variable_vec.split_at(variable_vec.len() - 1).0.to_vec(); + } + + let variable = object_name_to_string(&ObjectName(variable_vec)); + let base_query = format!("SELECT {columns} FROM information_schema.df_settings"); + let query = if variable == "all" { // Add an ORDER BY so the output comes out in a consistent order - String::from( - "SELECT name, setting FROM information_schema.df_settings ORDER BY name", - ) - } else if variable_lower == "timezone" || variable_lower == "time.zone" { + format!("{base_query} ORDER BY name") + } else if variable == "timezone" || variable == "time.zone" { // we could introduce alias in OptionDefinition if this string matching thing grows - String::from("SELECT name, setting FROM information_schema.df_settings WHERE name = 'datafusion.execution.time_zone'") + format!("{base_query} WHERE name = 'datafusion.execution.time_zone'") } else { - format!( - "SELECT name, setting FROM information_schema.df_settings WHERE name = '{variable}'" - ) + format!("{base_query} WHERE name = '{variable}'") }; let mut rewrite = DFParser::parse_sql(&query)?; @@ -859,12 +952,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result { // Do a table lookup to verify the table exists let table_ref = self.object_name_to_table_reference(table_name.clone())?; - let provider = self.schema_provider.get_table_provider(table_ref.clone())?; - let schema = (*provider.schema()).clone(); + let table_source = self.context_provider.get_table_source(table_ref.clone())?; + let schema = (*table_source.schema()).clone(); let schema = DFSchema::try_from(schema)?; - let scan = - LogicalPlanBuilder::scan(object_name_to_string(&table_name), provider, None)? - .build()?; + let scan = LogicalPlanBuilder::scan( + object_name_to_string(&table_name), + table_source, + None, + )? + .build()?; let mut planner_context = PlannerContext::new(); let source = match predicate_expr { @@ -900,53 +996,39 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { from: Option, predicate_expr: Option, ) -> Result { - let table_name = match &table.relation { - TableFactor::Table { name, .. } => name.clone(), + let (table_name, table_alias) = match &table.relation { + TableFactor::Table { name, alias, .. } => (name.clone(), alias.clone()), _ => plan_err!("Cannot update non-table relation!")?, }; // Do a table lookup to verify the table exists let table_name = self.object_name_to_table_reference(table_name)?; - let provider = self - .schema_provider - .get_table_provider(table_name.clone())?; - let arrow_schema = (*provider.schema()).clone(); + let table_source = self.context_provider.get_table_source(table_name.clone())?; let table_schema = Arc::new(DFSchema::try_from_qualified_schema( table_name.clone(), - &arrow_schema, + &table_source.schema(), )?); - let values = table_schema.fields().iter().map(|f| { - ( - f.name().clone(), - ast::Expr::Identifier(ast::Ident::from(f.name().as_str())), - ) - }); // Overwrite with assignment expressions let mut planner_context = PlannerContext::new(); let mut assign_map = assignments .iter() .map(|assign| { - let col_name: &Ident = assign.id.iter().last().ok_or_else(|| { - DataFusionError::Plan("Empty column id".to_string()) - })?; + let col_name: &Ident = assign + .id + .iter() + .last() + .ok_or_else(|| plan_datafusion_err!("Empty column id"))?; // Validate that the assignment target column exists table_schema.field_with_unqualified_name(&col_name.value)?; Ok((col_name.value.clone(), assign.value.clone())) }) .collect::>>()?; - let values = values - .into_iter() - .map(|(k, v)| { - let val = assign_map.remove(&k).unwrap_or(v); - (k, val) - }) - .collect::>(); - - // Build scan - let from = from.unwrap_or(table); - let scan = self.plan_from_tables(vec![from], &mut planner_context)?; + // Build scan, join with from table if it exists. + let mut input_tables = vec![table]; + input_tables.extend(from); + let scan = self.plan_from_tables(input_tables, &mut planner_context)?; // Filter let source = match predicate_expr { @@ -954,43 +1036,59 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Some(predicate_expr) => { let filter_expr = self.sql_to_expr( predicate_expr, - &table_schema, + scan.schema(), &mut planner_context, )?; let mut using_columns = HashSet::new(); expr_to_columns(&filter_expr, &mut using_columns)?; let filter_expr = normalize_col_with_schemas_and_ambiguity_check( filter_expr, - &[&[&table_schema]], + &[&[scan.schema()]], &[using_columns], )?; LogicalPlan::Filter(Filter::try_new(filter_expr, Arc::new(scan))?) } }; - // Projection - let mut exprs = vec![]; - for (col_name, expr) in values.into_iter() { - let expr = self.sql_to_expr(expr, &table_schema, &mut planner_context)?; - let expr = match expr { - datafusion_expr::Expr::Placeholder(Placeholder { - ref id, - ref data_type, - }) => match data_type { + // Build updated values for each column, using the previous value if not modified + let exprs = table_schema + .fields() + .iter() + .map(|field| { + let expr = match assign_map.remove(field.name()) { + Some(new_value) => { + let mut expr = self.sql_to_expr( + new_value, + source.schema(), + &mut planner_context, + )?; + // Update placeholder's datatype to the type of the target column + if let datafusion_expr::Expr::Placeholder(placeholder) = &mut expr + { + placeholder.data_type = placeholder + .data_type + .take() + .or_else(|| Some(field.data_type().clone())); + } + // Cast to target column type, if necessary + expr.cast_to(field.data_type(), source.schema())? + } None => { - let dt = table_schema.data_type(&Column::from_name(&col_name))?; - datafusion_expr::Expr::Placeholder(Placeholder::new( - id.clone(), - Some(dt.clone()), - )) + // If the target table has an alias, use it to qualify the column name + if let Some(alias) = &table_alias { + datafusion_expr::Expr::Column(Column::new( + Some(self.normalizer.normalize(alias.name.clone())), + field.name(), + )) + } else { + datafusion_expr::Expr::Column(field.qualified_column()) + } } - Some(_) => expr, - }, - _ => expr, - }; - let expr = expr.alias(col_name); - exprs.push(expr); - } + }; + Ok(expr.alias(field.name())) + }) + .collect::>>()?; + let source = project(source, exprs)?; let plan = LogicalPlan::Dml(DmlStatement { @@ -1011,15 +1109,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result { // Do a table lookup to verify the table exists let table_name = self.object_name_to_table_reference(table_name)?; - let provider = self - .schema_provider - .get_table_provider(table_name.clone())?; - let arrow_schema = (*provider.schema()).clone(); + let table_source = self.context_provider.get_table_source(table_name.clone())?; + let arrow_schema = (*table_source.schema()).clone(); let table_schema = DFSchema::try_from(arrow_schema)?; - // Get insert fields and index_mapping - // The i-th field of the table is `fields[index_mapping[i]]` - let (fields, index_mapping) = if columns.is_empty() { + // Get insert fields and target table's value indices + // + // if value_indices[i] = Some(j), it means that the value of the i-th target table's column is + // derived from the j-th output of the source. + // + // if value_indices[i] = None, it means that the value of the i-th target table's column is + // not provided, and should be filled with a default value later. + let (fields, value_indices) = if columns.is_empty() { // Empty means we're inserting into all columns of the table ( table_schema.fields().clone(), @@ -1028,7 +1129,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .collect::>(), ) } else { - let mut mapping = vec![None; table_schema.fields().len()]; + let mut value_indices = vec![None; table_schema.fields().len()]; let fields = columns .into_iter() .map(|c| self.normalizer.normalize(c)) @@ -1037,19 +1138,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let column_index = table_schema .index_of_column_by_name(None, &c)? .ok_or_else(|| unqualified_field_not_found(&c, &table_schema))?; - if mapping[column_index].is_some() { - return Err(DataFusionError::SchemaError( - datafusion_common::SchemaError::DuplicateUnqualifiedField { - name: c, - }, - )); + if value_indices[column_index].is_some() { + return schema_err!(SchemaError::DuplicateUnqualifiedField { + name: c, + }); } else { - mapping[column_index] = Some(i); + value_indices[column_index] = Some(i); } Ok(table_schema.field(column_index).clone()) }) .collect::>>()?; - (fields, mapping) + (fields, value_indices) }; // infer types for Values clause... other types should be resolvable the regular way @@ -1060,15 +1159,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if let ast::Expr::Value(Value::Placeholder(name)) = val { let name = name.replace('$', "").parse::().map_err(|_| { - DataFusionError::Plan(format!( - "Can't parse placeholder: {name}" - )) + plan_datafusion_err!("Can't parse placeholder: {name}") })? - 1; let field = fields.get(idx).ok_or_else(|| { - DataFusionError::Plan(format!( + plan_datafusion_err!( "Placeholder ${} refers to a non existent column", idx + 1 - )) + ) })?; let dt = field.field().data_type().clone(); let _ = prepare_param_data_types.insert(name, dt); @@ -1086,17 +1183,28 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { plan_err!("Column count doesn't match insert query!")?; } - let exprs = index_mapping + let exprs = value_indices .into_iter() - .flatten() - .map(|i| { - let target_field = &fields[i]; - let source_field = source.schema().field(i); - let expr = - datafusion_expr::Expr::Column(source_field.unqualified_column()) - .cast_to(target_field.data_type(), source.schema())? - .alias(target_field.name()); - Ok(expr) + .enumerate() + .map(|(i, value_index)| { + let target_field = table_schema.field(i); + let expr = match value_index { + Some(v) => { + let source_field = source.schema().field(v); + datafusion_expr::Expr::Column(source_field.qualified_column()) + .cast_to(target_field.data_type(), source.schema())? + } + // The value is not specified. Fill in the default value for the column. + None => table_source + .get_column_default(target_field.name()) + .cloned() + .unwrap_or_else(|| { + // If there is no default for the column, then the default is NULL + datafusion_expr::Expr::Literal(ScalarValue::Null) + }) + .cast_to(target_field.data_type(), &DFSchema::empty())?, + }; + Ok(expr.alias(target_field.name())) }) .collect::>>()?; let source = project(source, exprs)?; @@ -1140,7 +1248,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Do a table lookup to verify the table exists let table_ref = self.object_name_to_table_reference(sql_table_name)?; - let _ = self.schema_provider.get_table_provider(table_ref)?; + let _ = self.context_provider.get_table_source(table_ref)?; // treat both FULL and EXTENDED as the same let select_list = if full || extended { @@ -1175,7 +1283,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Do a table lookup to verify the table exists let table_ref = self.object_name_to_table_reference(sql_table_name)?; - let _ = self.schema_provider.get_table_provider(table_ref)?; + let _ = self.context_provider.get_table_source(table_ref)?; let query = format!( "SELECT table_catalog, table_schema, table_name, definition FROM information_schema.views WHERE {where_clause}" @@ -1192,8 +1300,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema: schema.into(), table: table.into(), }; - self.schema_provider - .get_table_provider(tables_reference) + self.context_provider + .get_table_source(tables_reference) .is_ok() } } diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 28eaf241fa6f..616a2fc74932 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -17,7 +17,9 @@ //! SQL Utility Functions -use arrow_schema::{DataType, DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE}; +use arrow_schema::{ + DataType, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE, +}; use datafusion_common::tree_node::{Transformed, TreeNode}; use sqlparser::ast::Ident; @@ -221,14 +223,17 @@ pub(crate) fn make_decimal_type( (None, None) => (DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE), }; - // Arrow decimal is i128 meaning 38 maximum decimal digits if precision == 0 - || precision > DECIMAL128_MAX_PRECISION + || precision > DECIMAL256_MAX_PRECISION || scale.unsigned_abs() > precision { plan_err!( - "Decimal(precision = {precision}, scale = {scale}) should satisfy `0 < precision <= 38`, and `scale <= precision`." + "Decimal(precision = {precision}, scale = {scale}) should satisfy `0 < precision <= 76`, and `scale <= precision`." ) + } else if precision > DECIMAL128_MAX_PRECISION + && precision <= DECIMAL256_MAX_PRECISION + { + Ok(DataType::Decimal256(precision, scale)) } else { Ok(DataType::Decimal128(precision, scale)) } diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 88a041a66145..4de08a7124cf 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -22,11 +22,11 @@ use std::{sync::Arc, vec}; use arrow_schema::*; use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; -use datafusion_common::plan_err; use datafusion_common::{ assert_contains, config::ConfigOptions, DataFusionError, Result, ScalarValue, TableReference, }; +use datafusion_common::{plan_err, ParamValues}; use datafusion_expr::{ logical_plan::{LogicalPlan, Prepare}, AggregateUDF, ScalarUDF, TableSource, WindowUDF, @@ -201,7 +201,7 @@ fn cast_to_invalid_decimal_type_precision_0() { let sql = "SELECT CAST(10 AS DECIMAL(0))"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "Error during planning: Decimal(precision = 0, scale = 0) should satisfy `0 < precision <= 38`, and `scale <= precision`.", + "Error during planning: Decimal(precision = 0, scale = 0) should satisfy `0 < precision <= 76`, and `scale <= precision`.", err.strip_backtrace() ); } @@ -212,9 +212,19 @@ fn cast_to_invalid_decimal_type_precision_gt_38() { // precision > 38 { let sql = "SELECT CAST(10 AS DECIMAL(39))"; + let plan = "Projection: CAST(Int64(10) AS Decimal256(39, 0))\n EmptyRelation"; + quick_test(sql, plan); + } +} + +#[test] +fn cast_to_invalid_decimal_type_precision_gt_76() { + // precision > 76 + { + let sql = "SELECT CAST(10 AS DECIMAL(79))"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "Error during planning: Decimal(precision = 39, scale = 0) should satisfy `0 < precision <= 38`, and `scale <= precision`.", + "Error during planning: Decimal(precision = 79, scale = 0) should satisfy `0 < precision <= 76`, and `scale <= precision`.", err.strip_backtrace() ); } @@ -227,7 +237,7 @@ fn cast_to_invalid_decimal_type_precision_lt_scale() { let sql = "SELECT CAST(10 AS DECIMAL(5, 10))"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "Error during planning: Decimal(precision = 5, scale = 10) should satisfy `0 < precision <= 38`, and `scale <= precision`.", + "Error during planning: Decimal(precision = 5, scale = 10) should satisfy `0 < precision <= 76`, and `scale <= precision`.", err.strip_backtrace() ); } @@ -412,12 +422,11 @@ CopyTo: format=csv output_url=output.csv single_file_output=true options: () fn plan_insert() { let sql = "insert into person (id, first_name, last_name) values (1, 'Alan', 'Turing')"; - let plan = r#" -Dml: op=[Insert Into] table=[person] - Projection: CAST(column1 AS UInt32) AS id, column2 AS first_name, column3 AS last_name - Values: (Int64(1), Utf8("Alan"), Utf8("Turing")) - "# - .trim(); + let plan = "Dml: op=[Insert Into] table=[person]\ + \n Projection: CAST(column1 AS UInt32) AS id, column2 AS first_name, column3 AS last_name, \ + CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, \ + CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀\ + \n Values: (Int64(1), Utf8(\"Alan\"), Utf8(\"Turing\"))"; quick_test(sql, plan); } @@ -462,6 +471,10 @@ Dml: op=[Insert Into] table=[test_decimal] "INSERT INTO person (id, first_name, last_name) VALUES ($2, $4, $6)", "Error during planning: Placeholder type could not be resolved" )] +#[case::placeholder_type_unresolved( + "INSERT INTO person (id, first_name, last_name) VALUES ($id, $first_name, $last_name)", + "Error during planning: Can't parse placeholder: $id" +)] #[test] fn test_insert_schema_errors(#[case] sql: &str, #[case] error: &str) { let err = logical_plan(sql).unwrap_err(); @@ -597,11 +610,9 @@ fn select_compound_filter() { #[test] fn test_timestamp_filter() { let sql = "SELECT state FROM person WHERE birth_date < CAST (158412331400600000 as timestamp)"; - let expected = "Projection: person.state\ - \n Filter: person.birth_date < CAST(Int64(158412331400600000) AS Timestamp(Nanosecond, None))\ + \n Filter: person.birth_date < CAST(CAST(Int64(158412331400600000) AS Timestamp(Second, None)) AS Timestamp(Nanosecond, None))\ \n TableScan: person"; - quick_test(sql, expected); } @@ -745,9 +756,11 @@ fn join_with_ambiguous_column() { #[test] fn where_selection_with_ambiguous_column() { let sql = "SELECT * FROM person a, person b WHERE id = id + 1"; - let err = logical_plan(sql).expect_err("query should have failed"); + let err = logical_plan(sql) + .expect_err("query should have failed") + .strip_backtrace(); assert_eq!( - "SchemaError(AmbiguousReference { field: Column { relation: None, name: \"id\" } })", + "\"Schema error: Ambiguous reference to unqualified field id\"", format!("{err:?}") ); } @@ -1277,6 +1290,16 @@ fn select_simple_aggregate_repeated_aggregate_with_unique_aliases() { ); } +#[test] +fn select_simple_aggregate_respect_nulls() { + let sql = "SELECT MIN(age) RESPECT NULLS FROM person"; + let err = logical_plan(sql).expect_err("query should have failed"); + + assert_contains!( + err.strip_backtrace(), + "This feature is not implemented: Null treatment in aggregate functions is not supported: RESPECT NULLS" + ); +} #[test] fn select_from_typed_string_values() { quick_test( @@ -1364,18 +1387,6 @@ fn select_interval_out_of_range() { ); } -#[test] -fn select_array_no_common_type() { - let sql = "SELECT [1, true, null]"; - let err = logical_plan(sql).expect_err("query should have failed"); - - // HashSet doesn't guarantee order - assert_contains!( - err.strip_backtrace(), - "This feature is not implemented: Arrays with different types are not supported: " - ); -} - #[test] fn recursive_ctes() { let sql = " @@ -1392,16 +1403,6 @@ fn recursive_ctes() { ); } -#[test] -fn select_array_non_literal_type() { - let sql = "SELECT [now()]"; - let err = logical_plan(sql).expect_err("query should have failed"); - assert_eq!( - "This feature is not implemented: Arrays with elements other than literal are not supported: now()", - err.strip_backtrace() - ); -} - #[test] fn select_simple_aggregate_with_groupby_and_column_is_in_aggregate_and_groupby() { quick_test( @@ -1687,20 +1688,24 @@ fn select_order_by_multiple_index() { #[test] fn select_order_by_index_of_0() { let sql = "SELECT id FROM person ORDER BY 0"; - let err = logical_plan(sql).expect_err("query should have failed"); + let err = logical_plan(sql) + .expect_err("query should have failed") + .strip_backtrace(); assert_eq!( - "Plan(\"Order by index starts at 1 for column indexes\")", - format!("{err:?}") + "Error during planning: Order by index starts at 1 for column indexes", + err ); } #[test] fn select_order_by_index_oob() { let sql = "SELECT id FROM person ORDER BY 2"; - let err = logical_plan(sql).expect_err("query should have failed"); + let err = logical_plan(sql) + .expect_err("query should have failed") + .strip_backtrace(); assert_eq!( - "Plan(\"Order by column out of bounds, specified: 2, max: 1\")", - format!("{err:?}") + "Error during planning: Order by column out of bounds, specified: 2, max: 1", + err ); } @@ -1806,6 +1811,14 @@ fn create_external_table_csv() { quick_test(sql, expected); } +#[test] +fn create_external_table_with_pk() { + let sql = "CREATE EXTERNAL TABLE t(c1 int, primary key(c1)) STORED AS CSV LOCATION 'foo.csv'"; + let expected = + "CreateExternalTable: Bare { table: \"t\" } constraints=[PrimaryKey([0])]"; + quick_test(sql, expected); +} + #[test] fn create_schema_with_quoted_name() { let sql = "CREATE SCHEMA \"quoted_schema_name\""; @@ -2053,24 +2066,6 @@ fn union_all() { quick_test(sql, expected); } -#[test] -fn union_4_combined_in_one() { - let sql = "SELECT order_id from orders - UNION ALL SELECT order_id FROM orders - UNION ALL SELECT order_id FROM orders - UNION ALL SELECT order_id FROM orders"; - let expected = "Union\ - \n Projection: orders.order_id\ - \n TableScan: orders\ - \n Projection: orders.order_id\ - \n TableScan: orders\ - \n Projection: orders.order_id\ - \n TableScan: orders\ - \n Projection: orders.order_id\ - \n TableScan: orders"; - quick_test(sql, expected); -} - #[test] fn union_with_different_column_names() { let sql = "SELECT order_id from orders UNION ALL SELECT customer_id FROM orders"; @@ -2096,13 +2091,12 @@ fn union_values_with_no_alias() { #[test] fn union_with_incompatible_data_type() { let sql = "SELECT interval '1 year 1 day' UNION ALL SELECT 1"; - let err = logical_plan(sql).expect_err("query should have failed"); + let err = logical_plan(sql) + .expect_err("query should have failed") + .strip_backtrace(); assert_eq!( - "Plan(\"UNION Column Int64(1) (type: Int64) is \ - not compatible with column IntervalMonthDayNano\ - (\\\"950737950189618795196236955648\\\") \ - (type: Interval(MonthDayNano))\")", - format!("{err:?}") + "Error during planning: UNION Column Int64(1) (type: Int64) is not compatible with column IntervalMonthDayNano(\"950737950189618795196236955648\") (type: Interval(MonthDayNano))", + err ); } @@ -2205,10 +2199,12 @@ fn union_with_aliases() { #[test] fn union_with_incompatible_data_types() { let sql = "SELECT 'a' a UNION ALL SELECT true a"; - let err = logical_plan(sql).expect_err("query should have failed"); + let err = logical_plan(sql) + .expect_err("query should have failed") + .strip_backtrace(); assert_eq!( - "Plan(\"UNION Column a (type: Boolean) is not compatible with column a (type: Utf8)\")", - format!("{err:?}") + "Error during planning: UNION Column a (type: Boolean) is not compatible with column a (type: Utf8)", + err ); } @@ -2684,7 +2680,7 @@ fn prepare_stmt_quick_test( fn prepare_stmt_replace_params_quick_test( plan: LogicalPlan, - param_values: Vec, + param_values: impl Into, expected_plan: &str, ) -> LogicalPlan { // replace params @@ -2701,7 +2697,7 @@ struct MockContextProvider { } impl ContextProvider for MockContextProvider { - fn get_table_provider(&self, name: TableReference) -> Result> { + fn get_table_source(&self, name: TableReference) -> Result> { let schema = match name.table() { "test" => Ok(Schema::new(vec![ Field::new("t_date32", DataType::Date32, false), @@ -3552,13 +3548,24 @@ fn test_select_unsupported_syntax_errors(#[case] sql: &str, #[case] error: &str) fn select_order_by_with_cast() { let sql = "SELECT first_name AS first_name FROM (SELECT first_name AS first_name FROM person) ORDER BY CAST(first_name as INT)"; - let expected = "Sort: CAST(first_name AS first_name AS Int32) ASC NULLS LAST\ - \n Projection: first_name AS first_name\ - \n Projection: person.first_name AS first_name\ + let expected = "Sort: CAST(person.first_name AS Int32) ASC NULLS LAST\ + \n Projection: person.first_name\ + \n Projection: person.first_name\ \n TableScan: person"; quick_test(sql, expected); } +#[test] +fn test_avoid_add_alias() { + // avoiding adding an alias if the column name is the same. + // plan1 = plan2 + let sql = "select person.id as id from person order by person.id"; + let plan1 = logical_plan(sql).unwrap(); + let sql = "select id from person order by id"; + let plan2 = logical_plan(sql).unwrap(); + assert_eq!(format!("{plan1:?}"), format!("{plan2:?}")); +} + #[test] fn test_duplicated_left_join_key_inner_join() { // person.id * 2 happen twice in left side. @@ -3676,6 +3683,19 @@ fn test_prepare_statement_should_infer_types() { assert_eq!(actual_types, expected_types); } +#[test] +fn test_non_prepare_statement_should_infer_types() { + // Non prepared statements (like SELECT) should also have their parameter types inferred + let sql = "SELECT 1 + $1"; + let plan = logical_plan(sql).unwrap(); + let actual_types = plan.get_parameter_types().unwrap(); + let expected_types = HashMap::from([ + // constant 1 is inferred to be int64 + ("$1".to_string(), Some(DataType::Int64)), + ]); + assert_eq!(actual_types, expected_types); +} + #[test] #[should_panic( expected = "value: SQL(ParserError(\"Expected [NOT] NULL or TRUE|FALSE or [NOT] DISTINCT FROM after IS, found: $1\"" @@ -3723,7 +3743,7 @@ fn test_prepare_statement_to_plan_no_param() { /////////////////// // replace params with values - let param_values = vec![]; + let param_values: Vec = vec![]; let expected_plan = "Projection: person.id, person.age\ \n Filter: person.age = Int64(10)\ \n TableScan: person"; @@ -3737,7 +3757,7 @@ fn test_prepare_statement_to_plan_one_param_no_value_panic() { let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = 10"; let plan = logical_plan(sql).unwrap(); // declare 1 param but provide 0 - let param_values = vec![]; + let param_values: Vec = vec![]; assert_eq!( plan.with_param_values(param_values) .unwrap_err() @@ -3850,7 +3870,7 @@ Projection: person.id, orders.order_id assert_eq!(actual_types, expected_types); // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10))]; + let param_values = vec![ScalarValue::Int32(Some(10))].into(); let expected_plan = r#" Projection: person.id, orders.order_id Inner Join: Filter: person.id = orders.customer_id AND person.age = Int32(10) @@ -3882,7 +3902,7 @@ Projection: person.id, person.age assert_eq!(actual_types, expected_types); // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10))]; + let param_values = vec![ScalarValue::Int32(Some(10))].into(); let expected_plan = r#" Projection: person.id, person.age Filter: person.age = Int32(10) @@ -3894,6 +3914,41 @@ Projection: person.id, person.age prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); } +#[test] +fn test_prepare_statement_infer_types_from_between_predicate() { + let sql = "SELECT id, age FROM person WHERE age BETWEEN $1 AND $2"; + + let expected_plan = r#" +Projection: person.id, person.age + Filter: person.age BETWEEN $1 AND $2 + TableScan: person + "# + .trim(); + + let expected_dt = "[Int32]"; + let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); + + let actual_types = plan.get_parameter_types().unwrap(); + let expected_types = HashMap::from([ + ("$1".to_string(), Some(DataType::Int32)), + ("$2".to_string(), Some(DataType::Int32)), + ]); + assert_eq!(actual_types, expected_types); + + // replace params with values + let param_values = + vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(30))].into(); + let expected_plan = r#" +Projection: person.id, person.age + Filter: person.age BETWEEN Int32(10) AND Int32(30) + TableScan: person + "# + .trim(); + let plan = plan.replace_params_with_values(¶m_values).unwrap(); + + prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); +} + #[test] fn test_prepare_statement_infer_types_subquery() { let sql = "SELECT id, age FROM person WHERE age = (select max(age) from person where id = $1)"; @@ -3918,7 +3973,7 @@ Projection: person.id, person.age assert_eq!(actual_types, expected_types); // replace params with values - let param_values = vec![ScalarValue::UInt32(Some(10))]; + let param_values = vec![ScalarValue::UInt32(Some(10))].into(); let expected_plan = r#" Projection: person.id, person.age Filter: person.age = () @@ -3958,7 +4013,8 @@ Dml: op=[Update] table=[person] assert_eq!(actual_types, expected_types); // replace params with values - let param_values = vec![ScalarValue::Int32(Some(42)), ScalarValue::UInt32(Some(1))]; + let param_values = + vec![ScalarValue::Int32(Some(42)), ScalarValue::UInt32(Some(1))].into(); let expected_plan = r#" Dml: op=[Update] table=[person] Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, Int32(42) AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 @@ -3975,12 +4031,11 @@ Dml: op=[Update] table=[person] fn test_prepare_statement_insert_infer() { let sql = "insert into person (id, first_name, last_name) values ($1, $2, $3)"; - let expected_plan = r#" -Dml: op=[Insert Into] table=[person] - Projection: column1 AS id, column2 AS first_name, column3 AS last_name - Values: ($1, $2, $3) - "# - .trim(); + let expected_plan = "Dml: op=[Insert Into] table=[person]\ + \n Projection: column1 AS id, column2 AS first_name, column3 AS last_name, \ + CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, \ + CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀\ + \n Values: ($1, $2, $3)"; let expected_dt = "[Int32]"; let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); @@ -3996,15 +4051,15 @@ Dml: op=[Insert Into] table=[person] // replace params with values let param_values = vec![ ScalarValue::UInt32(Some(1)), - ScalarValue::Utf8(Some("Alan".to_string())), - ScalarValue::Utf8(Some("Turing".to_string())), - ]; - let expected_plan = r#" -Dml: op=[Insert Into] table=[person] - Projection: column1 AS id, column2 AS first_name, column3 AS last_name - Values: (UInt32(1), Utf8("Alan"), Utf8("Turing")) - "# - .trim(); + ScalarValue::from("Alan"), + ScalarValue::from("Turing"), + ] + .into(); + let expected_plan = "Dml: op=[Insert Into] table=[person]\ + \n Projection: column1 AS id, column2 AS first_name, column3 AS last_name, \ + CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, \ + CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀\ + \n Values: (UInt32(1), Utf8(\"Alan\"), Utf8(\"Turing\"))"; let plan = plan.replace_params_with_values(¶m_values).unwrap(); prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); @@ -4078,11 +4133,11 @@ fn test_prepare_statement_to_plan_multi_params() { // replace params with values let param_values = vec![ ScalarValue::Int32(Some(10)), - ScalarValue::Utf8(Some("abc".to_string())), + ScalarValue::from("abc"), ScalarValue::Float64(Some(100.0)), ScalarValue::Int32(Some(20)), ScalarValue::Float64(Some(200.0)), - ScalarValue::Utf8(Some("xyz".to_string())), + ScalarValue::from("xyz"), ]; let expected_plan = "Projection: person.id, person.age, Utf8(\"xyz\")\ @@ -4148,8 +4203,8 @@ fn test_prepare_statement_to_plan_value_list() { /////////////////// // replace params with values let param_values = vec![ - ScalarValue::Utf8(Some("a".to_string())), - ScalarValue::Utf8(Some("b".to_string())), + ScalarValue::from("a".to_string()), + ScalarValue::from("b".to_string()), ]; let expected_plan = "Projection: t.num, t.letter\ \n SubqueryAlias: t\ diff --git a/datafusion/sqllogictest/Cargo.toml b/datafusion/sqllogictest/Cargo.toml index 2cd16927be2c..e333dc816f66 100644 --- a/datafusion/sqllogictest/Cargo.toml +++ b/datafusion/sqllogictest/Cargo.toml @@ -16,50 +16,50 @@ # under the License. [package] -authors.workspace = true -edition.workspace = true -homepage.workspace = true -license.workspace = true +authors = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +license = { workspace = true } name = "datafusion-sqllogictest" -readme.workspace = true -repository.workspace = true -rust-version.workspace = true -version.workspace = true +readme = "README.md" +repository = { workspace = true } +rust-version = { workspace = true } +version = { workspace = true } [lib] name = "datafusion_sqllogictest" path = "src/lib.rs" [dependencies] -arrow = {workspace = true} -async-trait = "0.1.41" -bigdecimal = "0.4.1" -datafusion = {path = "../core", version = "31.0.0"} -datafusion-common = {path = "../common", version = "31.0.0", default-features = false} -half = "2.2.1" -itertools = "0.11" -object_store = "0.7.0" -rust_decimal = {version = "1.27.0"} -log = "^0.4" -sqllogictest = "0.17.0" -sqlparser.workspace = true -tempfile = "3" -thiserror = "1.0.44" -tokio = {version = "1.0"} -bytes = {version = "1.4.0", optional = true} -futures = {version = "0.3.28"} +arrow = { workspace = true } +async-trait = { workspace = true } +bigdecimal = { workspace = true } +bytes = { version = "1.4.0", optional = true } chrono = { workspace = true, optional = true } -tokio-postgres = {version = "0.7.7", optional = true} -postgres-types = {version = "0.2.4", optional = true} -postgres-protocol = {version = "0.6.4", optional = true} +datafusion = { path = "../core", version = "34.0.0" } +datafusion-common = { workspace = true } +futures = { version = "0.3.28" } +half = { workspace = true } +itertools = { workspace = true } +log = { workspace = true } +object_store = { workspace = true } +postgres-protocol = { version = "0.6.4", optional = true } +postgres-types = { version = "0.2.4", optional = true } +rust_decimal = { version = "1.27.0" } +sqllogictest = "0.19.0" +sqlparser = { workspace = true } +tempfile = { workspace = true } +thiserror = { workspace = true } +tokio = { version = "1.0" } +tokio-postgres = { version = "0.7.7", optional = true } [features] -postgres = ["bytes", "chrono", "tokio-postgres", "postgres-types", "postgres-protocol"] avro = ["datafusion/avro"] +postgres = ["bytes", "chrono", "tokio-postgres", "postgres-types", "postgres-protocol"] [dev-dependencies] -env_logger = "0.10" -num_cpus = "1.13.0" +env_logger = { workspace = true } +num_cpus = { workspace = true } [[test]] harness = false diff --git a/datafusion/sqllogictest/README.md b/datafusion/sqllogictest/README.md index 3e94859d35a7..bda00a2dce0f 100644 --- a/datafusion/sqllogictest/README.md +++ b/datafusion/sqllogictest/README.md @@ -17,19 +17,26 @@ under the License. --> -#### Overview +# DataFusion sqllogictest -This is the Datafusion implementation of [sqllogictest](https://www.sqlite.org/sqllogictest/doc/trunk/about.wiki). We -use [sqllogictest-rs](https://github.com/risinglightdb/sqllogictest-rs) as a parser/runner of `.slt` files -in [`test_files`](test_files). +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. -#### Testing setup +This crate is a submodule of DataFusion that contains an implementation of [sqllogictest](https://www.sqlite.org/sqllogictest/doc/trunk/about.wiki). + +[df]: https://crates.io/crates/datafusion + +## Overview + +This crate uses [sqllogictest-rs](https://github.com/risinglightdb/sqllogictest-rs) to parse and run `.slt` files in the +[`test_files`](test_files) directory of this crate. + +## Testing setup 1. `rustup update stable` DataFusion uses the latest stable release of rust 2. `git submodule init` 3. `git submodule update` -#### Running tests: TLDR Examples +## Running tests: TLDR Examples ```shell # Run all tests @@ -56,7 +63,7 @@ cargo test --test sqllogictests -- ddl --complete RUST_LOG=debug cargo test --test sqllogictests -- ddl ``` -#### Cookbook: Adding Tests +## Cookbook: Adding Tests 1. Add queries @@ -95,11 +102,11 @@ SELECT * from foo; Assuming it looks good, check it in! -#### Reference +# Reference -#### Running tests: Validation Mode +## Running tests: Validation Mode -In this model, `sqllogictests` runs the statements and queries in a `.slt` file, comparing the expected output in the +In this mode, `sqllogictests` runs the statements and queries in a `.slt` file, comparing the expected output in the file to the output produced by that run. For example, to run all tests suites in validation mode @@ -115,10 +122,10 @@ sqllogictests also supports `cargo test` style substring matches on file names t cargo test --test sqllogictests -- information ``` -#### Running tests: Postgres compatibility +## Running tests: Postgres compatibility Test files that start with prefix `pg_compat_` verify compatibility -with Postgres by running the same script files both with DataFusion and with Posgres +with Postgres by running the same script files both with DataFusion and with Postgres In order to run the sqllogictests running against a previously running Postgres instance, do: @@ -145,7 +152,7 @@ docker run \ postgres ``` -#### Running Tests: `tpch` +## Running Tests: `tpch` Test files in `tpch` directory runs against the `TPCH` data set (SF = 0.1), which must be generated before running. You can use following @@ -165,7 +172,7 @@ Then you need to add `INCLUDE_TPCH=true` to run tpch tests: INCLUDE_TPCH=true cargo test --test sqllogictests ``` -#### Updating tests: Completion Mode +## Updating tests: Completion Mode In test script completion mode, `sqllogictests` reads a prototype script and runs the statements and queries against the database engine. The output is a full script that is a copy of the prototype script with result inserted. @@ -177,7 +184,7 @@ You can update the tests / generate expected output by passing the `--complete` cargo test --test sqllogictests -- ddl --complete ``` -#### Running tests: `scratchdir` +## Running tests: `scratchdir` The DataFusion sqllogictest runner automatically creates a directory named `test_files/scratch/`, creating it if needed and @@ -190,7 +197,7 @@ Tests that need to write temporary files should write (only) to this directory to ensure they do not interfere with others concurrently running tests. -#### `.slt` file format +## `.slt` file format [`sqllogictest`] was originally written for SQLite to verify the correctness of SQL queries against the SQLite engine. The format is designed @@ -233,7 +240,7 @@ query - NULL values are rendered as `NULL`, - empty strings are rendered as `(empty)`, - boolean values are rendered as `true`/`false`, - - this list can be not exhaustive, check the `datafusion/core/tests/sqllogictests/src/engines/conversion.rs` for + - this list can be not exhaustive, check the `datafusion/sqllogictest/src/engines/conversion.rs` for details. - `sort_mode`: If included, it must be one of `nosort` (**default**), `rowsort`, or `valuesort`. In `nosort` mode, the results appear in exactly the order in which they were received from the database engine. The `nosort` mode should @@ -247,7 +254,7 @@ query > :warning: It is encouraged to either apply `order by`, or use `rowsort` for queries without explicit `order by` > clauses. -##### Example +### Example ```sql # group_by_distinct diff --git a/datafusion/sqllogictest/bin/sqllogictests.rs b/datafusion/sqllogictest/bin/sqllogictests.rs index 618e3106c629..aeb1cc4ec919 100644 --- a/datafusion/sqllogictest/bin/sqllogictests.rs +++ b/datafusion/sqllogictest/bin/sqllogictests.rs @@ -26,7 +26,7 @@ use futures::stream::StreamExt; use log::info; use sqllogictest::strict_column_validator; -use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_common::{exec_datafusion_err, exec_err, DataFusionError, Result}; const TEST_DIRECTORY: &str = "test_files/"; const PG_COMPAT_FILE_PREFIX: &str = "pg_compat_"; @@ -84,7 +84,7 @@ async fn run_tests() -> Result<()> { // Doing so is safe because each slt file runs with its own // `SessionContext` and should not have side effects (like // modifying shared state like `/tmp/`) - let errors: Vec<_> = futures::stream::iter(read_test_files(&options)) + let errors: Vec<_> = futures::stream::iter(read_test_files(&options)?) .map(|test_file| { tokio::task::spawn(async move { println!("Running {:?}", test_file.relative_path); @@ -159,6 +159,7 @@ async fn run_test_file_with_postgres(test_file: TestFile) -> Result<()> { relative_path, } = test_file; info!("Running with Postgres runner: {}", path.display()); + setup_scratch_dir(&relative_path)?; let mut runner = sqllogictest::Runner::new(|| Postgres::connect(relative_path.clone())); runner.with_column_validator(strict_column_validator); @@ -188,6 +189,7 @@ async fn run_complete_file(test_file: TestFile) -> Result<()> { info!("Skipping: {}", path.display()); return Ok(()); }; + setup_scratch_dir(&relative_path)?; let mut runner = sqllogictest::Runner::new(|| async { Ok(DataFusion::new( test_ctx.session_ctx().clone(), @@ -245,30 +247,45 @@ impl TestFile { } } -fn read_test_files<'a>(options: &'a Options) -> Box + 'a> { - Box::new( - read_dir_recursive(TEST_DIRECTORY) +fn read_test_files<'a>( + options: &'a Options, +) -> Result + 'a>> { + Ok(Box::new( + read_dir_recursive(TEST_DIRECTORY)? + .into_iter() .map(TestFile::new) .filter(|f| options.check_test_file(&f.relative_path)) .filter(|f| f.is_slt_file()) .filter(|f| f.check_tpch(options)) .filter(|f| options.check_pg_compat_file(f.path.as_path())), - ) + )) } -fn read_dir_recursive>(path: P) -> Box> { - Box::new( - std::fs::read_dir(path) - .expect("Readable directory") - .map(|path| path.expect("Readable entry").path()) - .flat_map(|path| { - if path.is_dir() { - read_dir_recursive(path) - } else { - Box::new(std::iter::once(path)) - } - }), - ) +fn read_dir_recursive>(path: P) -> Result> { + let mut dst = vec![]; + read_dir_recursive_impl(&mut dst, path.as_ref())?; + Ok(dst) +} + +/// Append all paths recursively to dst +fn read_dir_recursive_impl(dst: &mut Vec, path: &Path) -> Result<()> { + let entries = std::fs::read_dir(path) + .map_err(|e| exec_datafusion_err!("Error reading directory {path:?}: {e}"))?; + for entry in entries { + let path = entry + .map_err(|e| { + exec_datafusion_err!("Error reading entry in directory {path:?}: {e}") + })? + .path(); + + if path.is_dir() { + read_dir_recursive_impl(dst, &path)?; + } else { + dst.push(path); + } + } + + Ok(()) } /// Parsed command line options diff --git a/datafusion/sqllogictest/src/engines/datafusion_engine/mod.rs b/datafusion/sqllogictest/src/engines/datafusion_engine/mod.rs index 663bbdd5a3c7..8e2bbbfe4f69 100644 --- a/datafusion/sqllogictest/src/engines/datafusion_engine/mod.rs +++ b/datafusion/sqllogictest/src/engines/datafusion_engine/mod.rs @@ -21,5 +21,4 @@ mod normalize; mod runner; pub use error::*; -pub use normalize::*; pub use runner::*; diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs index 9af2de1af49e..a5ce7ccb9fe0 100644 --- a/datafusion/sqllogictest/src/test_context.rs +++ b/datafusion/sqllogictest/src/test_context.rs @@ -15,30 +15,33 @@ // specific language governing permissions and limitations // under the License. -use async_trait::async_trait; +use std::collections::HashMap; +use std::fs::File; +use std::io::Write; +use std::path::Path; +use std::sync::Arc; + +use arrow::array::{ + ArrayRef, BinaryArray, Float64Array, Int32Array, LargeBinaryArray, LargeStringArray, + StringArray, TimestampNanosecondArray, +}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; +use arrow::record_batch::RecordBatch; use datafusion::execution::context::SessionState; -use datafusion::logical_expr::Expr; +use datafusion::logical_expr::{create_udf, Expr, ScalarUDF, Volatility}; +use datafusion::physical_expr::functions::make_scalar_function; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionConfig; use datafusion::{ - arrow::{ - array::{ - BinaryArray, Float64Array, Int32Array, LargeBinaryArray, LargeStringArray, - StringArray, TimestampNanosecondArray, - }, - datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}, - record_batch::RecordBatch, - }, catalog::{schema::MemorySchemaProvider, CatalogProvider, MemoryCatalogProvider}, datasource::{MemTable, TableProvider, TableType}, prelude::{CsvReadOptions, SessionContext}, }; +use datafusion_common::cast::as_float64_array; use datafusion_common::DataFusionError; + +use async_trait::async_trait; use log::info; -use std::fs::File; -use std::io::Write; -use std::path::Path; -use std::sync::Arc; use tempfile::TempDir; /// Context for running tests @@ -57,8 +60,8 @@ impl TestContext { } } - /// Create a SessionContext, configured for the specific test, if - /// possible. + /// Create a SessionContext, configured for the specific sqllogictest + /// test(.slt file) , if possible. /// /// If `None` is returned (e.g. because some needed feature is not /// enabled), the file should be skipped @@ -67,7 +70,7 @@ impl TestContext { // hardcode target partitions so plans are deterministic .with_target_partitions(4); - let test_ctx = TestContext::new(SessionContext::with_config(config)); + let mut test_ctx = TestContext::new(SessionContext::new_with_config(config)); let file_name = relative_path.file_name().unwrap().to_str().unwrap(); match file_name { @@ -83,13 +86,15 @@ impl TestContext { info!("Registering table with many types"); register_table_with_many_types(test_ctx.session_ctx()).await; } + "map.slt" => { + info!("Registering table with map"); + register_table_with_map(test_ctx.session_ctx()).await; + } "avro.slt" => { #[cfg(feature = "avro")] { - let mut test_ctx = test_ctx; info!("Registering avro tables"); register_avro_tables(&mut test_ctx).await; - return Some(test_ctx); } #[cfg(not(feature = "avro"))] { @@ -99,10 +104,13 @@ impl TestContext { } "joins.slt" => { info!("Registering partition table tables"); - - let mut test_ctx = test_ctx; + let example_udf = create_example_udf(); + test_ctx.ctx.register_udf(example_udf); register_partition_table(&mut test_ctx).await; - return Some(test_ctx); + } + "metadata.slt" => { + info!("Registering metadata table tables"); + register_metadata_tables(test_ctx.session_ctx()).await; } _ => { info!("Using default SessionContext"); @@ -268,6 +276,23 @@ pub async fn register_table_with_many_types(ctx: &SessionContext) { .unwrap(); } +pub async fn register_table_with_map(ctx: &SessionContext) { + let key = Field::new("key", DataType::Int64, false); + let value = Field::new("value", DataType::Int64, true); + let map_field = + Field::new("entries", DataType::Struct(vec![key, value].into()), false); + let fields = vec![ + Field::new("int_field", DataType::Int64, true), + Field::new("map_field", DataType::Map(map_field.into(), false), true), + ]; + let schema = Schema::new(fields); + + let memory_table = MemTable::try_new(schema.into(), vec![vec![]]).unwrap(); + + ctx.register_table("table_with_map", Arc::new(memory_table)) + .unwrap(); +} + fn table_with_many_types() -> Arc { let schema = Schema::new(vec![ Field::new("int32_col", DataType::Int32, false), @@ -299,3 +324,58 @@ fn table_with_many_types() -> Arc { let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]]).unwrap(); Arc::new(provider) } + +/// Registers a table_with_metadata that contains both field level and Table level metadata +pub async fn register_metadata_tables(ctx: &SessionContext) { + let id = Field::new("id", DataType::Int32, true).with_metadata(HashMap::from([( + String::from("metadata_key"), + String::from("the id field"), + )])); + let name = Field::new("name", DataType::Utf8, true).with_metadata(HashMap::from([( + String::from("metadata_key"), + String::from("the name field"), + )])); + + let schema = Schema::new(vec![id, name]).with_metadata(HashMap::from([( + String::from("metadata_key"), + String::from("the entire schema"), + )])); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])) as _, + Arc::new(StringArray::from(vec![None, Some("bar"), Some("baz")])) as _, + ], + ) + .unwrap(); + + ctx.register_batch("table_with_metadata", batch).unwrap(); +} + +/// Create a UDF function named "example". See the `sample_udf.rs` example +/// file for an explanation of the API. +fn create_example_udf() -> ScalarUDF { + let adder = make_scalar_function(|args: &[ArrayRef]| { + let lhs = as_float64_array(&args[0]).expect("cast failed"); + let rhs = as_float64_array(&args[1]).expect("cast failed"); + let array = lhs + .iter() + .zip(rhs.iter()) + .map(|(lhs, rhs)| match (lhs, rhs) { + (Some(lhs), Some(rhs)) => Some(lhs + rhs), + _ => None, + }) + .collect::(); + Ok(Arc::new(array) as ArrayRef) + }); + create_udf( + "example", + // Expects two f64 values: + vec![DataType::Float64, DataType::Float64], + // Returns an f64 value: + Arc::new(DataType::Float64), + Volatility::Immutable, + adder, + ) +} diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index d0e41b12b8c9..aa512f6e2600 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -106,6 +106,36 @@ FROM ---- [0VVIHzxWtNOFLtnhjHEKjXaJOSLJfm, 0keZ5G8BffGwgF2RwQD59TFzMStxCB, 0og6hSkhbX8AC1ktFS4kounvTzy8Vo, 1aOcrEGd0cOqZe2I5XBOm0nDcwtBZO, 2T3wSlHdEmASmO0xcXHnndkKEt6bz8] +statement ok +CREATE EXTERNAL TABLE agg_order ( +c1 INT NOT NULL, +c2 INT NOT NULL, +c3 INT NOT NULL +) +STORED AS CSV +WITH HEADER ROW +LOCATION '../core/tests/data/aggregate_agg_multi_order.csv'; + +# test array_agg with order by multiple columns +query ? +select array_agg(c1 order by c2 desc, c3) from agg_order; +---- +[5, 6, 7, 8, 9, 1, 2, 3, 4, 10] + +query TT +explain select array_agg(c1 order by c2 desc, c3) from agg_order; +---- +logical_plan +Aggregate: groupBy=[[]], aggr=[[ARRAY_AGG(agg_order.c1) ORDER BY [agg_order.c2 DESC NULLS FIRST, agg_order.c3 ASC NULLS LAST]]] +--TableScan: agg_order projection=[c1, c2, c3] +physical_plan +AggregateExec: mode=Final, gby=[], aggr=[ARRAY_AGG(agg_order.c1)] +--CoalescePartitionsExec +----AggregateExec: mode=Partial, gby=[], aggr=[ARRAY_AGG(agg_order.c1)] +------SortExec: expr=[c2@1 DESC,c3@2 ASC NULLS LAST] +--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/aggregate_agg_multi_order.csv]]}, projection=[c1, c2, c3], has_header=true + statement error This feature is not implemented: LIMIT not supported in ARRAY_AGG: 1 SELECT array_agg(c13 LIMIT 1) FROM aggregate_test_100 @@ -1327,36 +1357,128 @@ select avg(c1), arrow_typeof(avg(c1)) from d_table ---- 5 Decimal128(14, 7) -# FIX: different test table + # aggregate -# query I -# SELECT SUM(c1), SUM(c2) FROM test -# ---- -# 60 220 +query II +SELECT SUM(c1), SUM(c2) FROM test +---- +7 6 + +# aggregate_empty + +query II +SELECT SUM(c1), SUM(c2) FROM test where c1 > 100000 +---- +NULL NULL + +# aggregate_avg +query RR +SELECT AVG(c1), AVG(c2) FROM test +---- +1.75 1.5 + +# aggregate_max +query II +SELECT MAX(c1), MAX(c2) FROM test +---- +3 2 + +# aggregate_min +query II +SELECT MIN(c1), MIN(c2) FROM test +---- +0 1 -# TODO: aggregate_empty +# aggregate_grouped +query II +SELECT c1, SUM(c2) FROM test GROUP BY c1 order by c1 +---- +0 NULL +1 1 +3 4 +NULL 1 -# TODO: aggregate_avg +# aggregate_grouped_avg +query IR +SELECT c1, AVG(c2) FROM test GROUP BY c1 order by c1 +---- +0 NULL +1 1 +3 2 +NULL 1 -# TODO: aggregate_max +# aggregate_grouped_empty +query IR +SELECT c1, AVG(c2) FROM test WHERE c1 = 123 GROUP BY c1 +---- -# TODO: aggregate_min +# aggregate_grouped_max +query II +SELECT c1, MAX(c2) FROM test GROUP BY c1 order by c1 +---- +0 NULL +1 1 +3 2 +NULL 1 -# TODO: aggregate_grouped +# aggregate_grouped_min +query II +SELECT c1, MIN(c2) FROM test GROUP BY c1 order by c1 +---- +0 NULL +1 1 +3 2 +NULL 1 -# TODO: aggregate_grouped_avg +# aggregate_min_max_w_custom_window_frames +query RR +SELECT +MIN(c12) OVER (ORDER BY C12 RANGE BETWEEN 0.3 PRECEDING AND 0.2 FOLLOWING) as min1, +MAX(c12) OVER (ORDER BY C11 RANGE BETWEEN 0.1 PRECEDING AND 0.2 FOLLOWING) as max1 +FROM aggregate_test_100 +ORDER BY C9 +LIMIT 5 +---- +0.014793053078 0.996540038759 +0.014793053078 0.980019341044 +0.014793053078 0.970671228336 +0.266717779508 0.996540038759 +0.360076636233 0.970671228336 -# TODO: aggregate_grouped_empty +# aggregate_min_max_with_custom_window_frames_unbounded_start +query RR +SELECT +MIN(c12) OVER (ORDER BY C12 RANGE BETWEEN UNBOUNDED PRECEDING AND 0.2 FOLLOWING) as min1, +MAX(c12) OVER (ORDER BY C11 RANGE BETWEEN UNBOUNDED PRECEDING AND 0.2 FOLLOWING) as max1 +FROM aggregate_test_100 +ORDER BY C9 +LIMIT 5 +---- +0.014793053078 0.996540038759 +0.014793053078 0.980019341044 +0.014793053078 0.980019341044 +0.014793053078 0.996540038759 +0.014793053078 0.980019341044 -# TODO: aggregate_grouped_max +# aggregate_avg_add +query RRRR +SELECT AVG(c1), AVG(c1) + 1, AVG(c1) + 2, 1 + AVG(c1) FROM test +---- +1.75 2.75 3.75 2.75 -# TODO: aggregate_grouped_min +# case_sensitive_identifiers_aggregates +query I +SELECT max(c1) FROM test; +---- +3 -# TODO: aggregate_avg_add -# TODO: case_sensitive_identifiers_aggregates -# TODO: count_basic +# count_basic +query II +SELECT COUNT(c1), COUNT(c2) FROM test +---- +4 4 # TODO: count_partitioned @@ -1364,9 +1486,59 @@ select avg(c1), arrow_typeof(avg(c1)) from d_table # TODO: count_aggregated_cube -# TODO: simple_avg +# count_multi_expr +query I +SELECT count(c1, c2) FROM test +---- +3 + +# count_null +query III +SELECT count(null), count(null, null), count(distinct null) FROM test +---- +0 0 0 + +# count_multi_expr_group_by +query I +SELECT count(c1, c2) FROM test group by c1 order by c1 +---- +0 +1 +2 +0 + +# count_null_group_by +query III +SELECT count(null), count(null, null), count(distinct null) FROM test group by c1 order by c1 +---- +0 0 0 +0 0 0 +0 0 0 +0 0 0 + +# aggreggte_with_alias +query II +select c1, sum(c2) as `Total Salary` from test group by c1 order by c1 +---- +0 NULL +1 1 +3 4 +NULL 1 + +# simple_avg + +query R +select avg(c1) from test +---- +1.75 + +# simple_mean +query R +select mean(c1) from test +---- +1.75 + -# TODO: simple_mean # query_sum_distinct - 2 different aggregate functions: avg and sum(distinct) query RI @@ -1396,7 +1568,7 @@ SELECT COUNT(DISTINCT c1) FROM test query ? SELECT ARRAY_AGG([]) ---- -[] +[[]] # array_agg_one query ? @@ -1419,7 +1591,7 @@ e 4 query ? SELECT ARRAY_AGG([]); ---- -[] +[[]] # array_agg_one query ? @@ -2020,14 +2192,6 @@ statement ok drop table t; - - -statement error DataFusion error: Execution error: Table 't_source' doesn't exist\. -drop table t_source; - -statement error DataFusion error: Execution error: Table 't' doesn't exist\. -drop table t; - query I select median(a) from (select 1 as a where 1=0); ---- @@ -2199,6 +2363,26 @@ NULL 1 10.1 10.1 10.1 10.1 0 NULL statement ok set datafusion.sql_parser.dialect = 'Generic'; +## Multiple distinct aggregates and dictionaries +statement ok +create table dict_test as values (1, arrow_cast('foo', 'Dictionary(Int32, Utf8)')), (2, arrow_cast('bar', 'Dictionary(Int32, Utf8)')); + +query I? +select * from dict_test; +---- +1 foo +2 bar + +query II +select count(distinct column1), count(distinct column2) from dict_test group by column1; +---- +1 1 +1 1 + +statement ok +drop table dict_test; + + # Prepare the table with dictionary values for testing statement ok CREATE TABLE value(x bigint) AS VALUES (1), (2), (3), (1), (3), (4), (5), (2); @@ -2282,6 +2466,22 @@ select max(x_dict) from value_dict group by x_dict % 2 order by max(x_dict); 4 5 +query T +select arrow_typeof(x_dict) from value_dict group by x_dict; +---- +Dictionary(Int64, Int32) +Dictionary(Int64, Int32) +Dictionary(Int64, Int32) +Dictionary(Int64, Int32) +Dictionary(Int64, Int32) + +statement ok +drop table value + +statement ok +drop table value_dict + + # bool aggregation statement ok CREATE TABLE value_bool(x boolean, g int) AS VALUES (NULL, 0), (false, 0), (true, 0), (false, 1), (true, 2), (NULL, 3); @@ -2318,6 +2518,7 @@ CREATE TABLE traces(trace_id varchar, timestamp bigint, other bigint) AS VALUES (NULL, 0, 0), ('a', NULL, NULL), ('a', 1, 1), +('a', -1, -1), ('b', 0, 0), ('c', 1, 1), ('c', 2, 2), @@ -2337,12 +2538,12 @@ Limit: skip=0, fetch=4 physical_plan GlobalLimitExec: skip=0, fetch=4 --SortPreservingMergeExec: [MAX(traces.timestamp)@1 DESC], fetch=4 -----SortExec: fetch=4, expr=[MAX(traces.timestamp)@1 DESC] +----SortExec: TopK(fetch=4), expr=[MAX(traces.timestamp)@1 DESC] ------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] --------CoalesceBatchesExec: target_batch_size=8192 ----------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4 -------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] ---------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] ----------------MemoryExec: partitions=1, partition_sizes=[1] @@ -2357,26 +2558,26 @@ NULL 0 query TI select trace_id, MIN(timestamp) from traces group by trace_id order by MIN(timestamp) asc limit 4; ---- +a -1 NULL 0 b 0 c 1 -a 1 query TII select trace_id, other, MIN(timestamp) from traces group by trace_id, other order by MIN(timestamp) asc limit 4; ---- +a -1 -1 b 0 0 NULL 0 0 c 1 1 -a 1 1 query TII select trace_id, MIN(other), MIN(timestamp) from traces group by trace_id order by MIN(timestamp), MIN(other) limit 4; ---- +a -1 -1 NULL 0 0 b 0 0 c 1 1 -a 1 1 statement ok set datafusion.optimizer.enable_topk_aggregation = true; @@ -2392,12 +2593,12 @@ Limit: skip=0, fetch=4 physical_plan GlobalLimitExec: skip=0, fetch=4 --SortPreservingMergeExec: [MAX(traces.timestamp)@1 DESC], fetch=4 -----SortExec: fetch=4, expr=[MAX(traces.timestamp)@1 DESC] +----SortExec: TopK(fetch=4), expr=[MAX(traces.timestamp)@1 DESC] ------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)], lim=[4] --------CoalesceBatchesExec: target_batch_size=8192 ----------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4 -------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)], lim=[4] ---------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)], lim=[4] ----------------MemoryExec: partitions=1, partition_sizes=[1] query TT @@ -2411,12 +2612,12 @@ Limit: skip=0, fetch=4 physical_plan GlobalLimitExec: skip=0, fetch=4 --SortPreservingMergeExec: [MIN(traces.timestamp)@1 DESC], fetch=4 -----SortExec: fetch=4, expr=[MIN(traces.timestamp)@1 DESC] +----SortExec: TopK(fetch=4), expr=[MIN(traces.timestamp)@1 DESC] ------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[MIN(traces.timestamp)] --------CoalesceBatchesExec: target_batch_size=8192 ----------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4 -------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MIN(traces.timestamp)] ---------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MIN(traces.timestamp)] ----------------MemoryExec: partitions=1, partition_sizes=[1] query TT @@ -2430,12 +2631,12 @@ Limit: skip=0, fetch=4 physical_plan GlobalLimitExec: skip=0, fetch=4 --SortPreservingMergeExec: [MAX(traces.timestamp)@1 ASC NULLS LAST], fetch=4 -----SortExec: fetch=4, expr=[MAX(traces.timestamp)@1 ASC NULLS LAST] +----SortExec: TopK(fetch=4), expr=[MAX(traces.timestamp)@1 ASC NULLS LAST] ------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] --------CoalesceBatchesExec: target_batch_size=8192 ----------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4 -------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] ---------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] ----------------MemoryExec: partitions=1, partition_sizes=[1] query TT @@ -2449,12 +2650,12 @@ Limit: skip=0, fetch=4 physical_plan GlobalLimitExec: skip=0, fetch=4 --SortPreservingMergeExec: [trace_id@0 ASC NULLS LAST], fetch=4 -----SortExec: fetch=4, expr=[trace_id@0 ASC NULLS LAST] +----SortExec: TopK(fetch=4), expr=[trace_id@0 ASC NULLS LAST] ------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] --------CoalesceBatchesExec: target_batch_size=8192 ----------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4 -------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] ---------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] ----------------MemoryExec: partitions=1, partition_sizes=[1] query TI @@ -2468,10 +2669,10 @@ NULL 0 query TI select trace_id, MIN(timestamp) from traces group by trace_id order by MIN(timestamp) asc limit 4; ---- +a -1 NULL 0 b 0 c 1 -a 1 query TI select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 3; @@ -2483,25 +2684,223 @@ a 1 query TI select trace_id, MIN(timestamp) from traces group by trace_id order by MIN(timestamp) asc limit 3; ---- +a -1 NULL 0 b 0 -c 1 query TII select trace_id, other, MIN(timestamp) from traces group by trace_id, other order by MIN(timestamp) asc limit 4; ---- +a -1 -1 b 0 0 NULL 0 0 c 1 1 -a 1 1 query TII select trace_id, MIN(other), MIN(timestamp) from traces group by trace_id order by MIN(timestamp), MIN(other) limit 4; ---- +a -1 -1 NULL 0 0 b 0 0 c 1 1 -a 1 1 + +# +# Push limit into distinct group-by aggregation tests +# + +# Make results deterministic +statement ok +set datafusion.optimizer.repartition_aggregations = false; + +# +query TT +EXPLAIN SELECT DISTINCT c3 FROM aggregate_test_100 group by c3 limit 5; +---- +logical_plan +Limit: skip=0, fetch=5 +--Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]] +----Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]] +------TableScan: aggregate_test_100 projection=[c3] +physical_plan +GlobalLimitExec: skip=0, fetch=5 +--AggregateExec: mode=Final, gby=[c3@0 as c3], aggr=[], lim=[5] +----CoalescePartitionsExec +------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[], lim=[5] +--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------AggregateExec: mode=Final, gby=[c3@0 as c3], aggr=[], lim=[5] +------------CoalescePartitionsExec +--------------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[], lim=[5] +----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3], has_header=true + +query I +SELECT DISTINCT c3 FROM aggregate_test_100 group by c3 limit 5; +---- +1 +-40 +29 +-85 +-82 + +query TT +EXPLAIN SELECT c2, c3 FROM aggregate_test_100 group by c2, c3 limit 5 offset 4; +---- +logical_plan +Limit: skip=4, fetch=5 +--Aggregate: groupBy=[[aggregate_test_100.c2, aggregate_test_100.c3]], aggr=[[]] +----TableScan: aggregate_test_100 projection=[c2, c3] +physical_plan +GlobalLimitExec: skip=4, fetch=5 +--AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[], lim=[9] +----CoalescePartitionsExec +------AggregateExec: mode=Partial, gby=[c2@0 as c2, c3@1 as c3], aggr=[], lim=[9] +--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], has_header=true + +query II +SELECT c2, c3 FROM aggregate_test_100 group by c2, c3 limit 5 offset 4; +---- +5 -82 +4 -111 +3 104 +3 13 +1 38 + +# The limit should only apply to the aggregations which group by c3 +query TT +EXPLAIN SELECT DISTINCT c3 FROM aggregate_test_100 WHERE c3 between 10 and 20 group by c2, c3 limit 4; +---- +logical_plan +Limit: skip=0, fetch=4 +--Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]] +----Projection: aggregate_test_100.c3 +------Aggregate: groupBy=[[aggregate_test_100.c2, aggregate_test_100.c3]], aggr=[[]] +--------Filter: aggregate_test_100.c3 >= Int16(10) AND aggregate_test_100.c3 <= Int16(20) +----------TableScan: aggregate_test_100 projection=[c2, c3], partial_filters=[aggregate_test_100.c3 >= Int16(10), aggregate_test_100.c3 <= Int16(20)] +physical_plan +GlobalLimitExec: skip=0, fetch=4 +--AggregateExec: mode=Final, gby=[c3@0 as c3], aggr=[], lim=[4] +----CoalescePartitionsExec +------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[], lim=[4] +--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------ProjectionExec: expr=[c3@1 as c3] +------------AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[] +--------------CoalescePartitionsExec +----------------AggregateExec: mode=Partial, gby=[c2@0 as c2, c3@1 as c3], aggr=[] +------------------CoalesceBatchesExec: target_batch_size=8192 +--------------------FilterExec: c3@1 >= 10 AND c3@1 <= 20 +----------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], has_header=true + +query I +SELECT DISTINCT c3 FROM aggregate_test_100 WHERE c3 between 10 and 20 group by c2, c3 limit 4; +---- +13 +17 +12 +14 + +# An aggregate expression causes the limit to not be pushed to the aggregation +query TT +EXPLAIN SELECT max(c1), c2, c3 FROM aggregate_test_100 group by c2, c3 limit 5; +---- +logical_plan +Projection: MAX(aggregate_test_100.c1), aggregate_test_100.c2, aggregate_test_100.c3 +--Limit: skip=0, fetch=5 +----Aggregate: groupBy=[[aggregate_test_100.c2, aggregate_test_100.c3]], aggr=[[MAX(aggregate_test_100.c1)]] +------TableScan: aggregate_test_100 projection=[c1, c2, c3] +physical_plan +ProjectionExec: expr=[MAX(aggregate_test_100.c1)@2 as MAX(aggregate_test_100.c1), c2@0 as c2, c3@1 as c3] +--GlobalLimitExec: skip=0, fetch=5 +----AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[MAX(aggregate_test_100.c1)] +------CoalescePartitionsExec +--------AggregateExec: mode=Partial, gby=[c2@1 as c2, c3@2 as c3], aggr=[MAX(aggregate_test_100.c1)] +----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3], has_header=true + +# TODO(msirek): Extend checking in LimitedDistinctAggregation equal groupings to ignore the order of columns +# in the group-by column lists, so the limit could be pushed to the lowest AggregateExec in this case +query TT +EXPLAIN SELECT DISTINCT c3, c2 FROM aggregate_test_100 group by c2, c3 limit 3 offset 10; +---- +logical_plan +Limit: skip=10, fetch=3 +--Aggregate: groupBy=[[aggregate_test_100.c3, aggregate_test_100.c2]], aggr=[[]] +----Projection: aggregate_test_100.c3, aggregate_test_100.c2 +------Aggregate: groupBy=[[aggregate_test_100.c2, aggregate_test_100.c3]], aggr=[[]] +--------TableScan: aggregate_test_100 projection=[c2, c3] +physical_plan +GlobalLimitExec: skip=10, fetch=3 +--AggregateExec: mode=Final, gby=[c3@0 as c3, c2@1 as c2], aggr=[], lim=[13] +----CoalescePartitionsExec +------AggregateExec: mode=Partial, gby=[c3@0 as c3, c2@1 as c2], aggr=[], lim=[13] +--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------ProjectionExec: expr=[c3@1 as c3, c2@0 as c2] +------------AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[] +--------------CoalescePartitionsExec +----------------AggregateExec: mode=Partial, gby=[c2@0 as c2, c3@1 as c3], aggr=[] +------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], has_header=true + +query II +SELECT DISTINCT c3, c2 FROM aggregate_test_100 group by c2, c3 limit 3 offset 10; +---- +57 1 +-54 4 +112 3 + +query TT +EXPLAIN SELECT c2, c3 FROM aggregate_test_100 group by rollup(c2, c3) limit 3; +---- +logical_plan +Limit: skip=0, fetch=3 +--Aggregate: groupBy=[[ROLLUP (aggregate_test_100.c2, aggregate_test_100.c3)]], aggr=[[]] +----TableScan: aggregate_test_100 projection=[c2, c3] +physical_plan +GlobalLimitExec: skip=0, fetch=3 +--AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[], lim=[3] +----CoalescePartitionsExec +------AggregateExec: mode=Partial, gby=[(NULL as c2, NULL as c3), (c2@0 as c2, NULL as c3), (c2@0 as c2, c3@1 as c3)], aggr=[] +--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], has_header=true + +query II +SELECT c2, c3 FROM aggregate_test_100 group by rollup(c2, c3) limit 3; +---- +NULL NULL +2 NULL +5 NULL + + +statement ok +set datafusion.optimizer.enable_distinct_aggregation_soft_limit = false; + +# The limit should not be pushed into the aggregations +query TT +EXPLAIN SELECT DISTINCT c3 FROM aggregate_test_100 group by c3 limit 5; +---- +logical_plan +Limit: skip=0, fetch=5 +--Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]] +----Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]] +------TableScan: aggregate_test_100 projection=[c3] +physical_plan +GlobalLimitExec: skip=0, fetch=5 +--AggregateExec: mode=Final, gby=[c3@0 as c3], aggr=[] +----CoalescePartitionsExec +------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[] +--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------AggregateExec: mode=Final, gby=[c3@0 as c3], aggr=[] +------------CoalescePartitionsExec +--------------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[] +----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3], has_header=true + +statement ok +set datafusion.optimizer.enable_distinct_aggregation_soft_limit = true; + +statement ok +set datafusion.optimizer.repartition_aggregations = true; # # regr_*() tests @@ -2769,3 +3168,92 @@ NULL NULL 1 NULL 3 6 0 0 0 NULL NULL 1 NULL 5 15 0 0 0 3 0 2 1 5.5 16.5 0.5 4.5 1.5 3 0 3 1 6 18 2 18 6 + +statement error +SELECT STRING_AGG() + +statement error +SELECT STRING_AGG(1,2,3) + +statement error +SELECT STRING_AGG(STRING_AGG('a', ',')) + +query T +SELECT STRING_AGG('a', ',') +---- +a + +query TTTT +SELECT STRING_AGG('a',','), STRING_AGG('a', NULL), STRING_AGG(NULL, ','), STRING_AGG(NULL, NULL) +---- +a a NULL NULL + +query TT +select string_agg('', '|'), string_agg('a', ''); +---- +(empty) a + +query T +SELECT STRING_AGG(column1, '|') FROM (values (''), (null), ('')); +---- +| + +statement ok +CREATE TABLE strings(g INTEGER, x VARCHAR, y VARCHAR) + +query ITT +INSERT INTO strings VALUES (1,'a','/'), (1,'b','-'), (2,'i','/'), (2,NULL,'-'), (2,'j','+'), (3,'p','/'), (4,'x','/'), (4,'y','-'), (4,'z','+') +---- +9 + +query IT +SELECT g, STRING_AGG(x,'|') FROM strings GROUP BY g ORDER BY g +---- +1 a|b +2 i|j +3 p +4 x|y|z + +query T +SELECT STRING_AGG(x,',') FROM strings WHERE g > 100 +---- +NULL + +statement ok +drop table strings + +query T +WITH my_data as ( +SELECT 'text1'::varchar(1000) as my_column union all +SELECT 'text1'::varchar(1000) as my_column union all +SELECT 'text1'::varchar(1000) as my_column +) +SELECT string_agg(my_column,', ') as my_string_agg +FROM my_data +---- +text1, text1, text1 + +query T +WITH my_data as ( +SELECT 1 as dummy, 'text1'::varchar(1000) as my_column union all +SELECT 1 as dummy, 'text1'::varchar(1000) as my_column union all +SELECT 1 as dummy, 'text1'::varchar(1000) as my_column +) +SELECT string_agg(my_column,', ') as my_string_agg +FROM my_data +GROUP BY dummy +---- +text1, text1, text1 + + +# Queries with nested count(*) + +query I +select count(*) from (select count(*) from (select 1)); +---- +1 + +query I +select count(*) from (select count(*) a, count(*) b from (select 1)); +---- +1 diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index f11bc5206eb4..d864091a8588 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -67,6 +67,16 @@ AS VALUES (make_array(make_array(15, 16),make_array(NULL, 18)), make_array(16.6, 17.7, 18.8), NULL) ; +statement ok +CREATE TABLE large_arrays +AS + SELECT + arrow_cast(column1, 'LargeList(List(Int64))') AS column1, + arrow_cast(column2, 'LargeList(Float64)') AS column2, + arrow_cast(column3, 'LargeList(Utf8)') AS column3 + FROM arrays +; + statement ok CREATE TABLE slices AS VALUES @@ -97,6 +107,19 @@ AS VALUES (make_array(make_array(4, 5, 6), make_array(10, 11, 12), make_array(4, 9, 8), make_array(7, 8, 9), make_array(10, 11, 12), make_array(1, 8, 7)), make_array(10, 11, 12), 3, make_array([[11, 12, 13], [14, 15, 16]], [[17, 18, 19], [20, 21, 22]]), make_array(121, 131, 141)) ; +# TODO: add this when #8305 is fixed +# statement ok +# CREATE TABLE large_nested_arrays +# AS +# SELECT +# arrow_cast(column1, 'LargeList(LargeList(Int64))') AS column1, +# arrow_cast(column2, 'LargeList(Int64)') AS column2, +# column3, +# arrow_cast(column4, 'LargeList(LargeList(List(Int64)))') AS column4, +# arrow_cast(column5, 'LargeList(Int64)') AS column5 +# FROM nested_arrays +# ; + statement ok CREATE TABLE arrays_values AS VALUES @@ -110,6 +133,17 @@ AS VALUES (make_array(61, 62, 63, 64, 65, 66, 67, 68, 69, 70), 66, 7, NULL) ; +statement ok +CREATE TABLE large_arrays_values +AS SELECT + arrow_cast(column1, 'LargeList(Int64)') AS column1, + column2, + column3, + column4 +FROM arrays_values +; + + statement ok CREATE TABLE arrays_values_v2 AS VALUES @@ -121,6 +155,17 @@ AS VALUES (NULL, NULL, NULL, NULL) ; +# TODO: add this when #8305 is fixed +# statement ok +# CREATE TABLE large_arrays_values_v2 +# AS SELECT +# arrow_cast(column1, 'LargeList(Int64)') AS column1, +# arrow_cast(column2, 'LargeList(Int64)') AS column2, +# column3, +# arrow_cast(column4, 'LargeList(LargeList(Int64))') AS column4 +# FROM arrays_values_v2 +# ; + statement ok CREATE TABLE flatten_table AS VALUES @@ -182,6 +227,168 @@ AS VALUES (make_array([[1], [2]], [[2], [3]]), make_array([1], [2])) ; +statement ok +CREATE TABLE array_distinct_table_1D +AS VALUES + (make_array(1, 1, 2, 2, 3)), + (make_array(1, 2, 3, 4, 5)), + (make_array(3, 5, 3, 3, 3)) +; + +statement ok +CREATE TABLE array_distinct_table_1D_UTF8 +AS VALUES + (make_array('a', 'a', 'bc', 'bc', 'def')), + (make_array('a', 'bc', 'def', 'defg', 'defg')), + (make_array('defg', 'defg', 'defg', 'defg', 'defg')) +; + +statement ok +CREATE TABLE array_distinct_table_2D +AS VALUES + (make_array([1,2], [1,2], [3,4], [3,4], [5,6])), + (make_array([1,2], [3,4], [5,6], [7,8], [9,10])), + (make_array([5,6], [5,6], NULL)) +; + +statement ok +CREATE TABLE array_distinct_table_1D_large +AS VALUES + (arrow_cast(make_array(1, 1, 2, 2, 3), 'LargeList(Int64)')), + (arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)')), + (arrow_cast(make_array(3, 5, 3, 3, 3), 'LargeList(Int64)')) +; + +statement ok +CREATE TABLE array_intersect_table_1D +AS VALUES + (make_array(1, 2), make_array(1), make_array(1,2,3), make_array(1,3), make_array(1,3,5), make_array(2,4,6,8,1,3)), + (make_array(11, 22), make_array(11), make_array(11,22,33), make_array(11,33), make_array(11,33,55), make_array(22,44,66,88,11,33)) +; + +statement ok +CREATE TABLE large_array_intersect_table_1D +AS + SELECT + arrow_cast(column1, 'LargeList(Int64)') as column1, + arrow_cast(column2, 'LargeList(Int64)') as column2, + arrow_cast(column3, 'LargeList(Int64)') as column3, + arrow_cast(column4, 'LargeList(Int64)') as column4, + arrow_cast(column5, 'LargeList(Int64)') as column5, + arrow_cast(column6, 'LargeList(Int64)') as column6 +FROM array_intersect_table_1D +; + +statement ok +CREATE TABLE array_intersect_table_1D_Float +AS VALUES + (make_array(1.0, 2.0), make_array(1.0), make_array(1.0,2.0,3.0), make_array(1.0,3.0), make_array(1.11), make_array(2.22, 3.33)), + (make_array(3.0, 4.0, 5.0), make_array(2.0), make_array(1.0,2.0,3.0,4.0), make_array(2.0,5.0), make_array(2.22, 1.11), make_array(1.11, 3.33)) +; + +statement ok +CREATE TABLE large_array_intersect_table_1D_Float +AS + SELECT + arrow_cast(column1, 'LargeList(Float64)') as column1, + arrow_cast(column2, 'LargeList(Float64)') as column2, + arrow_cast(column3, 'LargeList(Float64)') as column3, + arrow_cast(column4, 'LargeList(Float64)') as column4, + arrow_cast(column5, 'LargeList(Float64)') as column5, + arrow_cast(column6, 'LargeList(Float64)') as column6 +FROM array_intersect_table_1D_Float +; + +statement ok +CREATE TABLE array_intersect_table_1D_Boolean +AS VALUES + (make_array(true, true, true), make_array(false), make_array(true, true, false, true, false), make_array(true, false, true), make_array(false), make_array(true, false)), + (make_array(false, false, false), make_array(false), make_array(true, false, true), make_array(true, true), make_array(true, true), make_array(false,false,true)) +; + +statement ok +CREATE TABLE large_array_intersect_table_1D_Boolean +AS + SELECT + arrow_cast(column1, 'LargeList(Boolean)') as column1, + arrow_cast(column2, 'LargeList(Boolean)') as column2, + arrow_cast(column3, 'LargeList(Boolean)') as column3, + arrow_cast(column4, 'LargeList(Boolean)') as column4, + arrow_cast(column5, 'LargeList(Boolean)') as column5, + arrow_cast(column6, 'LargeList(Boolean)') as column6 +FROM array_intersect_table_1D_Boolean +; + +statement ok +CREATE TABLE array_intersect_table_1D_UTF8 +AS VALUES + (make_array('a', 'bc', 'def'), make_array('bc'), make_array('datafusion', 'rust', 'arrow'), make_array('rust', 'arrow'), make_array('rust', 'arrow', 'python'), make_array('data')), + (make_array('a', 'bc', 'def'), make_array('defg'), make_array('datafusion', 'rust', 'arrow'), make_array('datafusion', 'rust', 'arrow', 'python'), make_array('rust', 'arrow'), make_array('datafusion', 'rust', 'arrow')) +; + +statement ok +CREATE TABLE large_array_intersect_table_1D_UTF8 +AS + SELECT + arrow_cast(column1, 'LargeList(Utf8)') as column1, + arrow_cast(column2, 'LargeList(Utf8)') as column2, + arrow_cast(column3, 'LargeList(Utf8)') as column3, + arrow_cast(column4, 'LargeList(Utf8)') as column4, + arrow_cast(column5, 'LargeList(Utf8)') as column5, + arrow_cast(column6, 'LargeList(Utf8)') as column6 +FROM array_intersect_table_1D_UTF8 +; + +statement ok +CREATE TABLE array_intersect_table_2D +AS VALUES + (make_array([1,2]), make_array([1,3]), make_array([1,2,3], [4,5], [6,7]), make_array([4,5], [6,7])), + (make_array([3,4], [5]), make_array([3,4]), make_array([1,2,3,4], [5,6,7], [8,9,10]), make_array([1,2,3], [5,6,7], [8,9,10])) +; + +statement ok +CREATE TABLE large_array_intersect_table_2D +AS + SELECT + arrow_cast(column1, 'LargeList(List(Int64))') as column1, + arrow_cast(column2, 'LargeList(List(Int64))') as column2, + arrow_cast(column3, 'LargeList(List(Int64))') as column3, + arrow_cast(column4, 'LargeList(List(Int64))') as column4 +FROM array_intersect_table_2D +; + +statement ok +CREATE TABLE array_intersect_table_2D_float +AS VALUES + (make_array([1.0, 2.0, 3.0], [1.1, 2.2], [3.3]), make_array([1.1, 2.2], [3.3])), + (make_array([1.0, 2.0, 3.0], [1.1, 2.2], [3.3]), make_array([1.0], [1.1, 2.2], [3.3])) +; + +statement ok +CREATE TABLE large_array_intersect_table_2D_Float +AS + SELECT + arrow_cast(column1, 'LargeList(List(Float64))') as column1, + arrow_cast(column2, 'LargeList(List(Float64))') as column2 +FROM array_intersect_table_2D_Float +; + +statement ok +CREATE TABLE array_intersect_table_3D +AS VALUES + (make_array([[1,2]]), make_array([[1]])), + (make_array([[1,2]]), make_array([[1,2]])) +; + +statement ok +CREATE TABLE large_array_intersect_table_3D +AS + SELECT + arrow_cast(column1, 'LargeList(List(List(Int64)))') as column1, + arrow_cast(column2, 'LargeList(List(List(Int64)))') as column2 +FROM array_intersect_table_3D +; + statement ok CREATE TABLE arrays_values_without_nulls AS VALUES @@ -191,6 +398,24 @@ AS VALUES (make_array(31, 32, 33, 34, 35, 26, 37, 38, 39, 40), 34, 4, 'ok', [8,9]) ; +statement ok +CREATE TABLE large_arrays_values_without_nulls +AS SELECT + arrow_cast(column1, 'LargeList(Int64)') AS column1, + column2, + column3, + column4, + arrow_cast(column5, 'LargeList(Int64)') AS column5 +FROM arrays_values_without_nulls +; + +statement ok +CREATE TABLE arrays_range +AS VALUES + (3, 10, 2), + (4, 13, 3) +; + statement ok CREATE TABLE arrays_with_repeating_elements AS VALUES @@ -200,6 +425,17 @@ AS VALUES (make_array(10, 11, 12, 10, 11, 12, 10, 11, 12, 10), 10, 13, 10) ; +statement ok +CREATE TABLE large_arrays_with_repeating_elements +AS + SELECT + arrow_cast(column1, 'LargeList(Int64)') AS column1, + column2, + column3, + column4 + FROM arrays_with_repeating_elements +; + statement ok CREATE TABLE nested_arrays_with_repeating_elements AS VALUES @@ -209,6 +445,34 @@ AS VALUES (make_array([28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]), [28, 29, 30], [37, 38, 39], 10) ; +statement ok +CREATE TABLE large_nested_arrays_with_repeating_elements +AS + SELECT + arrow_cast(column1, 'LargeList(List(Int64))') AS column1, + column2, + column3, + column4 + FROM nested_arrays_with_repeating_elements +; + +query error +select [1, true, null] + +query error DataFusion error: This feature is not implemented: ScalarFunctions without MakeArray are not supported: now() +SELECT [now()] + +query TTT +select arrow_typeof(column1), arrow_typeof(column2), arrow_typeof(column3) from arrays; +---- +List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) +List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) +List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) +List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) +List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) +List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) +List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) + # arrays table query ??? select column1, column2, column3 from arrays; @@ -612,7 +876,7 @@ from arrays_values_without_nulls; ## array_element (aliases: array_extract, list_extract, list_element) # array_element error -query error DataFusion error: Error during planning: The array_element function can only accept list as the first argument +query error DataFusion error: Error during planning: The array_element function can only accept list or largelist as the first argument select array_element(1, 2); @@ -622,58 +886,106 @@ select array_element(make_array(1, 2, 3, 4, 5), 2), array_element(make_array('h' ---- 2 l +query IT +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3); +---- +2 l + # array_element scalar function #2 (with positive index; out of bounds) query IT select array_element(make_array(1, 2, 3, 4, 5), 7), array_element(make_array('h', 'e', 'l', 'l', 'o'), 11); ---- NULL NULL +query IT +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 7), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 11); +---- +NULL NULL + # array_element scalar function #3 (with zero) query IT select array_element(make_array(1, 2, 3, 4, 5), 0), array_element(make_array('h', 'e', 'l', 'l', 'o'), 0); ---- NULL NULL +query IT +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 0); +---- +NULL NULL + # array_element scalar function #4 (with NULL) -query error +query error select array_element(make_array(1, 2, 3, 4, 5), NULL), array_element(make_array('h', 'e', 'l', 'l', 'o'), NULL); +query error +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), NULL); + # array_element scalar function #5 (with negative index) query IT select array_element(make_array(1, 2, 3, 4, 5), -2), array_element(make_array('h', 'e', 'l', 'l', 'o'), -3); ---- 4 l +query IT +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -2), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3); +---- +4 l + # array_element scalar function #6 (with negative index; out of bounds) query IT select array_element(make_array(1, 2, 3, 4, 5), -11), array_element(make_array('h', 'e', 'l', 'l', 'o'), -7); ---- NULL NULL +query IT +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -11), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -7); +---- +NULL NULL + # array_element scalar function #7 (nested array) query ? select array_element(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 1); ---- [1, 2, 3, 4, 5] +query ? +select array_element(arrow_cast(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 'LargeList(List(Int64))'), 1); +---- +[1, 2, 3, 4, 5] + # array_extract scalar function #8 (function alias `array_slice`) query IT select array_extract(make_array(1, 2, 3, 4, 5), 2), array_extract(make_array('h', 'e', 'l', 'l', 'o'), 3); ---- 2 l +query IT +select array_extract(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_extract(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3); +---- +2 l + # list_element scalar function #9 (function alias `array_slice`) query IT select list_element(make_array(1, 2, 3, 4, 5), 2), list_element(make_array('h', 'e', 'l', 'l', 'o'), 3); ---- 2 l +query IT +select list_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_extract(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3); +---- +2 l + # list_extract scalar function #10 (function alias `array_slice`) query IT select list_extract(make_array(1, 2, 3, 4, 5), 2), list_extract(make_array('h', 'e', 'l', 'l', 'o'), 3); ---- 2 l +query IT +select list_extract(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_extract(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3); +---- +2 l + # array_element with columns query I select array_element(column1, column2) from slices; @@ -686,6 +998,17 @@ NULL NULL 55 +query I +select array_element(arrow_cast(column1, 'LargeList(Int64)'), column2) from slices; +---- +NULL +12 +NULL +37 +NULL +NULL +55 + # array_element with columns and scalars query II select array_element(make_array(1, 2, 3, 4, 5), column2), array_element(column1, 3) from slices; @@ -698,6 +1021,17 @@ NULL 23 NULL 43 5 NULL +query II +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), column2), array_element(arrow_cast(column1, 'LargeList(Int64)'), 3) from slices; +---- +1 3 +2 13 +NULL 23 +2 33 +4 NULL +NULL 43 +5 NULL + ## array_pop_back (aliases: `list_pop_back`) # array_pop_back scalar function #1 @@ -706,18 +1040,33 @@ select array_pop_back(make_array(1, 2, 3, 4, 5)), array_pop_back(make_array('h', ---- [1, 2, 3, 4] [h, e, l, l] +query ?? +select array_pop_back(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)')), array_pop_back(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)')); +---- +[1, 2, 3, 4] [h, e, l, l] + # array_pop_back scalar function #2 (after array_pop_back, array is empty) query ? select array_pop_back(make_array(1)); ---- [] +query ? +select array_pop_back(arrow_cast(make_array(1), 'LargeList(Int64)')); +---- +[] + # array_pop_back scalar function #3 (array_pop_back the empty array) query ? select array_pop_back(array_pop_back(make_array(1))); ---- [] +query ? +select array_pop_back(array_pop_back(arrow_cast(make_array(1), 'LargeList(Int64)'))); +---- +[] + # array_pop_back scalar function #4 (array_pop_back the arrays which have NULL) query ?? select array_pop_back(make_array(1, 2, 3, 4, NULL)), array_pop_back(make_array(NULL, 'e', 'l', NULL, 'o')); @@ -730,24 +1079,44 @@ select array_pop_back(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_ ---- [[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4]] +query ? +select array_pop_back(arrow_cast(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4), make_array(4, 5, 6)), 'LargeList(List(Int64))')); +---- +[[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4]] + # array_pop_back scalar function #6 (array_pop_back the nested arrays with NULL) query ? select array_pop_back(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4), NULL)); ---- [[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4]] +query ? +select array_pop_back(arrow_cast(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4), NULL), 'LargeList(List(Int64))')); +---- +[[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4]] + # array_pop_back scalar function #7 (array_pop_back the nested arrays with NULL) query ? select array_pop_back(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), NULL, make_array(1, 7, 4))); ---- [[1, 2, 3], [2, 9, 1], [7, 8, 9], ] +query ? +select array_pop_back(arrow_cast(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), NULL, make_array(1, 7, 4)), 'LargeList(List(Int64))')); +---- +[[1, 2, 3], [2, 9, 1], [7, 8, 9], ] + # array_pop_back scalar function #8 (after array_pop_back, nested array is empty) query ? select array_pop_back(make_array(make_array(1, 2, 3))); ---- [] +query ? +select array_pop_back(arrow_cast(make_array(make_array(1, 2, 3)), 'LargeList(List(Int64))')); +---- +[] + # array_pop_back with columns query ? select array_pop_back(column1) from arrayspop; @@ -759,6 +1128,84 @@ select array_pop_back(column1) from arrayspop; [] [, 10, 11] +query ? +select array_pop_back(arrow_cast(column1, 'LargeList(Int64)')) from arrayspop; +---- +[1, 2] +[3, 4, 5] +[6, 7, 8, ] +[, ] +[] +[, 10, 11] + +## array_pop_front (aliases: `list_pop_front`) + +# array_pop_front scalar function #1 +query ?? +select array_pop_front(make_array(1, 2, 3, 4, 5)), array_pop_front(make_array('h', 'e', 'l', 'l', 'o')); +---- +[2, 3, 4, 5] [e, l, l, o] + +query ?? +select array_pop_front(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)')), array_pop_front(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)')); +---- +[2, 3, 4, 5] [e, l, l, o] + +# array_pop_front scalar function #2 (after array_pop_front, array is empty) +query ? +select array_pop_front(make_array(1)); +---- +[] + +query ? +select array_pop_front(arrow_cast(make_array(1), 'LargeList(Int64)')); +---- +[] + +# array_pop_front scalar function #3 (array_pop_front the empty array) +query ? +select array_pop_front(array_pop_front(make_array(1))); +---- +[] + +query ? +select array_pop_front(array_pop_front(arrow_cast(make_array(1), 'LargeList(Int64)'))); +---- +[] + +# array_pop_front scalar function #5 (array_pop_front the nested arrays) +query ? +select array_pop_front(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4), make_array(4, 5, 6))); +---- +[[2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] + +query ? +select array_pop_front(arrow_cast(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4), make_array(4, 5, 6)), 'LargeList(List(Int64))')); +---- +[[2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] + +# array_pop_front scalar function #6 (array_pop_front the nested arrays with NULL) +query ? +select array_pop_front(make_array(NULL, make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4))); +---- +[[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4]] + +query ? +select array_pop_front(arrow_cast(make_array(NULL, make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4)), 'LargeList(List(Int64))')); +---- +[[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4]] + +# array_pop_front scalar function #8 (after array_pop_front, nested array is empty) +query ? +select array_pop_front(make_array(make_array(1, 2, 3))); +---- +[] + +query ? +select array_pop_front(arrow_cast(make_array(make_array(1, 2, 3)), 'LargeList(List(Int64))')); +---- +[] + ## array_slice (aliases: list_slice) # array_slice scalar function #1 (with positive indexes) @@ -767,109 +1214,201 @@ select array_slice(make_array(1, 2, 3, 4, 5), 2, 4), array_slice(make_array('h', ---- [2, 3, 4] [h, e] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, 4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 1, 2); +---- +[2, 3, 4] [h, e] + # array_slice scalar function #2 (with positive indexes; full array) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 0, 6), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, 5); ---- [1, 2, 3, 4, 5] [h, e, l, l, o] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0, 6), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 0, 5); +---- +[1, 2, 3, 4, 5] [h, e, l, l, o] + # array_slice scalar function #3 (with positive indexes; first index = second index) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 4, 4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, 3); ---- [4] [l] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 4, 4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3, 3); +---- +[4] [l] + # array_slice scalar function #4 (with positive indexes; first index > second_index) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 2, 1), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 4, 1); ---- [] [] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, 1), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 4, 1); +---- +[] [] + # array_slice scalar function #5 (with positive indexes; out of bounds) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 2, 6), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, 7); ---- [2, 3, 4, 5] [l, l, o] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, 6), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3, 7); +---- +[2, 3, 4, 5] [l, l, o] + # array_slice scalar function #6 (with positive indexes; nested array) query ? select array_slice(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 1, 1); ---- [[1, 2, 3, 4, 5]] -# array_slice scalar function #7 (with zero and positive number) -query ?? -select array_slice(make_array(1, 2, 3, 4, 5), 0, 4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, 3); +query ? +select array_slice(arrow_cast(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 'LargeList(List(Int64))'), 1, 1); +---- +[[1, 2, 3, 4, 5]] + +# array_slice scalar function #7 (with zero and positive number) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), 0, 4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, 3); +---- +[1, 2, 3, 4] [h, e, l] + +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0, 4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 0, 3); ---- [1, 2, 3, 4] [h, e, l] # array_slice scalar function #8 (with NULL and positive number) -query error +query error select array_slice(make_array(1, 2, 3, 4, 5), NULL, 4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL, 3); +query error +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL, 4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), NULL, 3); + # array_slice scalar function #9 (with positive number and NULL) -query error +query error select array_slice(make_array(1, 2, 3, 4, 5), 2, NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, NULL); +query error +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, NULL), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3, NULL); + # array_slice scalar function #10 (with zero-zero) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 0, 0), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, 0); ---- [] [] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0, 0), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 0, 0); +---- +[] [] + # array_slice scalar function #11 (with NULL-NULL) -query error +query error select array_slice(make_array(1, 2, 3, 4, 5), NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL); +query error +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), NULL); + + # array_slice scalar function #12 (with zero and negative number) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 0, -4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, -3); ---- [1] [h, e] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0, -4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 0, -3); +---- +[1] [h, e] + # array_slice scalar function #13 (with negative number and NULL) -query error -select array_slice(make_array(1, 2, 3, 4, 5), 2, NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, NULL); +query error +select array_slice(make_array(1, 2, 3, 4, 5), -2, NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, NULL); + +query error +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -2, NULL), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3, NULL); # array_slice scalar function #14 (with NULL and negative number) -query error +query error select array_slice(make_array(1, 2, 3, 4, 5), NULL, -4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL, -3); +query error +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL, -4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), NULL, -3); + # array_slice scalar function #15 (with negative indexes) query ?? select array_slice(make_array(1, 2, 3, 4, 5), -4, -1), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, -1); ---- [2, 3, 4] [l, l] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -4, -1), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3, -1); +---- +[2, 3, 4] [l, l] + # array_slice scalar function #16 (with negative indexes; almost full array (only with negative indices cannot return full array)) query ?? select array_slice(make_array(1, 2, 3, 4, 5), -5, -1), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -5, -1); ---- [1, 2, 3, 4] [h, e, l, l] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -5, -1), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -5, -1); +---- +[1, 2, 3, 4] [h, e, l, l] + # array_slice scalar function #17 (with negative indexes; first index = second index) query ?? select array_slice(make_array(1, 2, 3, 4, 5), -4, -4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, -3); ---- [] [] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -4, -4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3, -3); +---- +[] [] + # array_slice scalar function #18 (with negative indexes; first index > second_index) query ?? select array_slice(make_array(1, 2, 3, 4, 5), -4, -6), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, -6); ---- [] [] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -4, -6), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3, -6); +---- +[] [] + # array_slice scalar function #19 (with negative indexes; out of bounds) query ?? select array_slice(make_array(1, 2, 3, 4, 5), -7, -2), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -7, -3); ---- [] [] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -7, -2), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -7, -3); +---- +[] [] + # array_slice scalar function #20 (with negative indexes; nested array) -query ? -select array_slice(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), -2, -1); +query ?? +select array_slice(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), -2, -1), array_slice(make_array(make_array(1, 2, 3), make_array(6, 7, 8)), -1, -1); ---- -[[1, 2, 3, 4, 5]] +[[1, 2, 3, 4, 5]] [] + +query ?? +select array_slice(arrow_cast(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 'LargeList(List(Int64))'), -2, -1), array_slice(arrow_cast(make_array(make_array(1, 2, 3), make_array(6, 7, 8)), 'LargeList(List(Int64))'), -1, -1); +---- +[[1, 2, 3, 4, 5]] [] + # array_slice scalar function #21 (with first positive index and last negative index) query ?? @@ -877,18 +1416,33 @@ select array_slice(make_array(1, 2, 3, 4, 5), 2, -3), array_slice(make_array('h' ---- [2] [e, l] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, -3), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 2, -2); +---- +[2] [e, l] + # array_slice scalar function #22 (with first negative index and last positive index) query ?? select array_slice(make_array(1, 2, 3, 4, 5), -2, 5), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, 4); ---- [4, 5] [l, l] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -2, 5), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3, 4); +---- +[4, 5] [l, l] + # list_slice scalar function #23 (function alias `array_slice`) query ?? select list_slice(make_array(1, 2, 3, 4, 5), 2, 4), list_slice(make_array('h', 'e', 'l', 'l', 'o'), 1, 2); ---- [2, 3, 4] [h, e] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, 4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 1, 2); +---- +[2, 3, 4] [h, e] + # array_slice with columns query ? select array_slice(column1, column2, column3) from slices; @@ -901,6 +1455,17 @@ select array_slice(column1, column2, column3) from slices; [41, 42, 43, 44, 45, 46] [55, 56, 57, 58, 59, 60] +query ? +select array_slice(arrow_cast(column1, 'LargeList(Int64)'), column2, column3) from slices; +---- +[] +[12, 13, 14, 15, 16] +[] +[] +[] +[41, 42, 43, 44, 45, 46] +[55, 56, 57, 58, 59, 60] + # TODO: support NULLS in output instead of `[]` # array_slice with columns and scalars query ??? @@ -914,6 +1479,17 @@ select array_slice(make_array(1, 2, 3, 4, 5), column2, column3), array_slice(col [1, 2, 3, 4, 5] [43, 44, 45, 46] [41, 42, 43, 44, 45] [5] [, 54, 55, 56, 57, 58, 59, 60] [55] +query ??? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), column2, column3), array_slice(arrow_cast(column1, 'LargeList(Int64)'), 3, column3), array_slice(arrow_cast(column1, 'LargeList(Int64)'), column2, 5) from slices; +---- +[1] [] [, 2, 3, 4, 5] +[] [13, 14, 15, 16] [12, 13, 14, 15] +[] [] [21, 22, 23, , 25] +[] [33] [] +[4, 5] [] [] +[1, 2, 3, 4, 5] [43, 44, 45, 46] [41, 42, 43, 44, 45] +[5] [, 54, 55, 56, 57, 58, 59, 60] [55] + # make_array with nulls query ??????? select make_array(make_array('a','b'), null), @@ -939,20 +1515,102 @@ select make_array(['a','b'], null); ---- [[a, b], ] +## array_sort (aliases: `list_sort`) +query ??? +select array_sort(make_array(1, 3, null, 5, NULL, -5)), array_sort(make_array(1, 3, null, 2), 'ASC'), array_sort(make_array(1, 3, null, 2), 'desc', 'NULLS FIRST'); +---- +[, , -5, 1, 3, 5] [, 1, 2, 3] [, 3, 2, 1] + +query ? +select array_sort(column1, 'DESC', 'NULLS LAST') from arrays_values; +---- +[10, 9, 8, 7, 6, 5, 4, 3, 2, ] +[20, 18, 17, 16, 15, 14, 13, 12, 11, ] +[30, 29, 28, 27, 26, 25, 23, 22, 21, ] +[40, 39, 38, 37, 35, 34, 33, 32, 31, ] +NULL +[50, 49, 48, 47, 46, 45, 44, 43, 42, 41] +[60, 59, 58, 57, 56, 55, 54, 52, 51, ] +[70, 69, 68, 67, 66, 65, 64, 63, 62, 61] + +query ? +select array_sort(column1, 'ASC', 'NULLS FIRST') from arrays_values; +---- +[, 2, 3, 4, 5, 6, 7, 8, 9, 10] +[, 11, 12, 13, 14, 15, 16, 17, 18, 20] +[, 21, 22, 23, 25, 26, 27, 28, 29, 30] +[, 31, 32, 33, 34, 35, 37, 38, 39, 40] +NULL +[41, 42, 43, 44, 45, 46, 47, 48, 49, 50] +[, 51, 52, 54, 55, 56, 57, 58, 59, 60] +[61, 62, 63, 64, 65, 66, 67, 68, 69, 70] + + +## list_sort (aliases: `array_sort`) +query ??? +select list_sort(make_array(1, 3, null, 5, NULL, -5)), list_sort(make_array(1, 3, null, 2), 'ASC'), list_sort(make_array(1, 3, null, 2), 'desc', 'NULLS FIRST'); +---- +[, , -5, 1, 3, 5] [, 1, 2, 3] [, 3, 2, 1] + + ## array_append (aliases: `list_append`, `array_push_back`, `list_push_back`) -# TODO: array_append with NULLs -# array_append scalar function #1 -# query ? -# select array_append(make_array(), 4); +# array_append with NULLs + +query error +select array_append(null, 1); + +query error +select array_append(null, [2, 3]); + +query error +select array_append(null, [[4]]); + +query ???? +select + array_append(make_array(), 4), + array_append(make_array(), null), + array_append(make_array(1, null, 3), 4), + array_append(make_array(null, null), 1) +; +---- +[4] [] [1, , 3, 4] [, , 1] + +# TODO: add this when #8305 is fixed +# query ???? +# select +# array_append(arrow_cast(make_array(), 'LargeList(Null)'), 4), +# array_append(make_array(), null), +# array_append(make_array(1, null, 3), 4), +# array_append(make_array(null, null), 1) +# ; # ---- -# [4] +# [4] [] [1, , 3, 4] [, , 1] + +# test invalid (non-null) +query error +select array_append(1, 2); + +query error +select array_append(1, [2]); -# array_append scalar function #2 +query error +select array_append([1], [2]); + +query ?? +select + array_append(make_array(make_array(1, null, 3)), make_array(null)), + array_append(make_array(make_array(1, null, 3)), null); +---- +[[1, , 3], []] [[1, , 3], ] + +# TODO: add this when #8305 is fixed # query ?? -# select array_append(make_array(), make_array()), array_append(make_array(), make_array(4)); +# select +# array_append(arrow_cast(make_array(make_array(1, null, 3), 'LargeList(LargeList(Int64))')), arrow_cast(make_array(null), 'LargeList(Int64)')), +# array_append(arrow_cast(make_array(make_array(1, null, 3), 'LargeList(LargeList(Int64))')), null); # ---- -# [[]] [[4]] +# [[1, , 3], []] [[1, , 3], ] # array_append scalar function #3 query ??? @@ -960,30 +1618,56 @@ select array_append(make_array(1, 2, 3), 4), array_append(make_array(1.0, 2.0, 3 ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] +query ??? +select array_append(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4), array_append(arrow_cast(make_array(1.0, 2.0, 3.0), 'LargeList(Float64)'), 4.0), array_append(make_array('h', 'e', 'l', 'l'), 'o'); +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + # array_append scalar function #4 (element is list) query ??? select array_append(make_array([1], [2], [3]), make_array(4)), array_append(make_array([1.0], [2.0], [3.0]), make_array(4.0)), array_append(make_array(['h'], ['e'], ['l'], ['l']), make_array('o')); ---- [[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]] +# TODO: add this when #8305 is fixed +# query ??? +# select array_append(arrow_cast(make_array([1], [2], [3]), 'LargeList(LargeList(Int64))'), arrow_cast(make_array(4), 'LargeList(Int64)')), array_append(arrow_cast(make_array([1.0], [2.0], [3.0]), 'LargeList(LargeList(Float64))'), arrow_cast(make_array(4.0), 'LargeList(Float64)')), array_append(arrow_cast(make_array(['h'], ['e'], ['l'], ['l']), 'LargeList(LargeList(Utf8))'), arrow_cast(make_array('o'), 'LargeList(Utf8)')); +# ---- +# [[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]] + # list_append scalar function #5 (function alias `array_append`) query ??? select list_append(make_array(1, 2, 3), 4), list_append(make_array(1.0, 2.0, 3.0), 4.0), list_append(make_array('h', 'e', 'l', 'l'), 'o'); ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] +query ??? +select list_append(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4), list_append(arrow_cast(make_array(1.0, 2.0, 3.0), 'LargeList(Float64)'), 4.0), list_append(make_array('h', 'e', 'l', 'l'), 'o'); +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + # array_push_back scalar function #6 (function alias `array_append`) query ??? select array_push_back(make_array(1, 2, 3), 4), array_push_back(make_array(1.0, 2.0, 3.0), 4.0), array_push_back(make_array('h', 'e', 'l', 'l'), 'o'); ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] +query ??? +select array_push_back(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4), array_push_back(arrow_cast(make_array(1.0, 2.0, 3.0), 'LargeList(Float64)'), 4.0), array_push_back(make_array('h', 'e', 'l', 'l'), 'o'); +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + # list_push_back scalar function #7 (function alias `array_append`) query ??? select list_push_back(make_array(1, 2, 3), 4), list_push_back(make_array(1.0, 2.0, 3.0), 4.0), list_push_back(make_array('h', 'e', 'l', 'l'), 'o'); ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] +query ??? +select list_push_back(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4), list_push_back(arrow_cast(make_array(1.0, 2.0, 3.0), 'LargeList(Float64)'), 4.0), list_push_back(make_array('h', 'e', 'l', 'l'), 'o'); +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + # array_append with columns #1 query ? select array_append(column1, column2) from arrays_values; @@ -997,6 +1681,18 @@ select array_append(column1, column2) from arrays_values; [51, 52, , 54, 55, 56, 57, 58, 59, 60, 55] [61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 66] +query ? +select array_append(column1, column2) from large_arrays_values; +---- +[, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1] +[11, 12, 13, 14, 15, 16, 17, 18, , 20, 12] +[21, 22, 23, , 25, 26, 27, 28, 29, 30, 23] +[31, 32, 33, 34, 35, , 37, 38, 39, 40, 34] +[44] +[41, 42, 43, 44, 45, 46, 47, 48, 49, 50, ] +[51, 52, , 54, 55, 56, 57, 58, 59, 60, 55] +[61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 66] + # array_append with columns #2 (element is list) query ? select array_append(column1, column2) from nested_arrays; @@ -1004,6 +1700,13 @@ select array_append(column1, column2) from nested_arrays; [[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6], [7, 8, 9]] [[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7], [10, 11, 12]] +# TODO: add this when #8305 is fixed +# query ? +# select array_append(column1, column2) from large_nested_arrays; +# ---- +# [[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6], [7, 8, 9]] +# [[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7], [10, 11, 12]] + # array_append with columns and scalars #1 query ?? select array_append(column2, 100.1), array_append(column3, '.') from arrays; @@ -1016,6 +1719,17 @@ select array_append(column2, 100.1), array_append(column3, '.') from arrays; [100.1] [,, .] [16.6, 17.7, 18.8, 100.1] [.] +query ?? +select array_append(column2, 100.1), array_append(column3, '.') from large_arrays; +---- +[1.1, 2.2, 3.3, 100.1] [L, o, r, e, m, .] +[, 5.5, 6.6, 100.1] [i, p, , u, m, .] +[7.7, 8.8, 9.9, 100.1] [d, , l, o, r, .] +[10.1, , 12.2, 100.1] [s, i, t, .] +[13.3, 14.4, 15.5, 100.1] [a, m, e, t, .] +[100.1] [,, .] +[16.6, 17.7, 18.8, 100.1] [.] + # array_append with columns and scalars #2 query ?? select array_append(column1, make_array(1, 11, 111)), array_append(make_array(make_array(1, 2, 3), make_array(11, 12, 13)), column2) from nested_arrays; @@ -1023,20 +1737,67 @@ select array_append(column1, make_array(1, 11, 111)), array_append(make_array(ma [[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6], [1, 11, 111]] [[1, 2, 3], [11, 12, 13], [7, 8, 9]] [[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7], [1, 11, 111]] [[1, 2, 3], [11, 12, 13], [10, 11, 12]] +# TODO: add this when #8305 is fixed +# query ?? +# select array_append(column1, arrow_cast(make_array(1, 11, 111), 'LargeList(Int64)')), array_append(arrow_cast(make_array(make_array(1, 2, 3), make_array(11, 12, 13)), 'LargeList(LargeList(Int64))'), column2) from large_nested_arrays; +# ---- +# [[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6], [1, 11, 111]] [[1, 2, 3], [11, 12, 13], [7, 8, 9]] +# [[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7], [1, 11, 111]] [[1, 2, 3], [11, 12, 13], [10, 11, 12]] + ## array_prepend (aliases: `list_prepend`, `array_push_front`, `list_push_front`) -# TODO: array_prepend with NULLs -# array_prepend scalar function #1 -# query ? -# select array_prepend(4, make_array()); -# ---- -# [4] +# array_prepend with NULLs + +# DuckDB: [4] +# ClickHouse: Null +# Since they dont have the same result, we just follow Postgres, return error +query error +select array_prepend(4, NULL); + +query ? +select array_prepend(4, []); +---- +[4] + +query ? +select array_prepend(4, [null]); +---- +[4, ] + +# DuckDB: [null] +# ClickHouse: [null] +query ? +select array_prepend(null, []); +---- +[] + +query ? +select array_prepend(null, [1]); +---- +[, 1] + +query ? +select array_prepend(null, [[1,2,3]]); +---- +[, [1, 2, 3]] + +# DuckDB: [[]] +# ClickHouse: [[]] +# TODO: We may also return [[]] +query error +select array_prepend([], []); + +# DuckDB: [null] +# ClickHouse: [null] +# TODO: We may also return [null] +query error +select array_prepend(null, null); + +query ? +select array_append([], null); +---- +[] -# array_prepend scalar function #2 -# query ?? -# select array_prepend(make_array(), make_array()), array_prepend(make_array(4), make_array()); -# ---- -# [[]] [[4]] # array_prepend scalar function #3 query ??? @@ -1044,30 +1805,56 @@ select array_prepend(1, make_array(2, 3, 4)), array_prepend(1.0, make_array(2.0, ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] +query ??? +select array_prepend(1, arrow_cast(make_array(2, 3, 4), 'LargeList(Int64)')), array_prepend(1.0, arrow_cast(make_array(2.0, 3.0, 4.0), 'LargeList(Float64)')), array_prepend('h', arrow_cast(make_array('e', 'l', 'l', 'o'), 'LargeList(Utf8)')); +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + # array_prepend scalar function #4 (element is list) query ??? select array_prepend(make_array(1), make_array(make_array(2), make_array(3), make_array(4))), array_prepend(make_array(1.0), make_array([2.0], [3.0], [4.0])), array_prepend(make_array('h'), make_array(['e'], ['l'], ['l'], ['o'])); ---- [[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]] +# TODO: add this when #8305 is fixed +# query ??? +# select array_prepend(arrow_cast(make_array(1), 'LargeList(Int64)'), arrow_cast(make_array(make_array(2), make_array(3), make_array(4)), 'LargeList(LargeList(Int64))')), array_prepend(arrow_cast(make_array(1.0), 'LargeList(Float64)'), arrow_cast(make_array([2.0], [3.0], [4.0]), 'LargeList(LargeList(Float64))')), array_prepend(arrow_cast(make_array('h'), 'LargeList(Utf8)'), arrow_cast(make_array(['e'], ['l'], ['l'], ['o']), 'LargeList(LargeList(Utf8))'')); +# ---- +# [[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]] + # list_prepend scalar function #5 (function alias `array_prepend`) query ??? select list_prepend(1, make_array(2, 3, 4)), list_prepend(1.0, make_array(2.0, 3.0, 4.0)), list_prepend('h', make_array('e', 'l', 'l', 'o')); ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] +query ??? +select list_prepend(1, arrow_cast(make_array(2, 3, 4), 'LargeList(Int64)')), list_prepend(1.0, arrow_cast(make_array(2.0, 3.0, 4.0), 'LargeList(Float64)')), list_prepend('h', arrow_cast(make_array('e', 'l', 'l', 'o'), 'LargeList(Utf8)')); +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + # array_push_front scalar function #6 (function alias `array_prepend`) query ??? select array_push_front(1, make_array(2, 3, 4)), array_push_front(1.0, make_array(2.0, 3.0, 4.0)), array_push_front('h', make_array('e', 'l', 'l', 'o')); ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] +query ??? +select array_push_front(1, arrow_cast(make_array(2, 3, 4), 'LargeList(Int64)')), array_push_front(1.0, arrow_cast(make_array(2.0, 3.0, 4.0), 'LargeList(Float64)')), array_push_front('h', arrow_cast(make_array('e', 'l', 'l', 'o'), 'LargeList(Utf8)')); +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + # list_push_front scalar function #7 (function alias `array_prepend`) query ??? select list_push_front(1, make_array(2, 3, 4)), list_push_front(1.0, make_array(2.0, 3.0, 4.0)), list_push_front('h', make_array('e', 'l', 'l', 'o')); ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] +query ??? +select list_push_front(1, arrow_cast(make_array(2, 3, 4), 'LargeList(Int64)')), list_push_front(1.0, arrow_cast(make_array(2.0, 3.0, 4.0), 'LargeList(Float64)')), list_push_front('h', arrow_cast(make_array('e', 'l', 'l', 'o'), 'LargeList(Utf8)')); +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + # array_prepend with columns #1 query ? select array_prepend(column2, column1) from arrays_values; @@ -1081,6 +1868,18 @@ select array_prepend(column2, column1) from arrays_values; [55, 51, 52, , 54, 55, 56, 57, 58, 59, 60] [66, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70] +query ? +select array_prepend(column2, column1) from large_arrays_values; +---- +[1, , 2, 3, 4, 5, 6, 7, 8, 9, 10] +[12, 11, 12, 13, 14, 15, 16, 17, 18, , 20] +[23, 21, 22, 23, , 25, 26, 27, 28, 29, 30] +[34, 31, 32, 33, 34, 35, , 37, 38, 39, 40] +[44] +[, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50] +[55, 51, 52, , 54, 55, 56, 57, 58, 59, 60] +[66, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70] + # array_prepend with columns #2 (element is list) query ? select array_prepend(column2, column1) from nested_arrays; @@ -1088,6 +1887,13 @@ select array_prepend(column2, column1) from nested_arrays; [[7, 8, 9], [1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] [[10, 11, 12], [4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]] +# TODO: add this when #8305 is fixed +# query ? +# select array_prepend(column2, column1) from large_nested_arrays; +# ---- +# [[7, 8, 9], [1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] +# [[10, 11, 12], [4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]] + # array_prepend with columns and scalars #1 query ?? select array_prepend(100.1, column2), array_prepend('.', column3) from arrays; @@ -1100,6 +1906,17 @@ select array_prepend(100.1, column2), array_prepend('.', column3) from arrays; [100.1] [., ,] [100.1, 16.6, 17.7, 18.8] [.] +query ?? +select array_prepend(100.1, column2), array_prepend('.', column3) from large_arrays; +---- +[100.1, 1.1, 2.2, 3.3] [., L, o, r, e, m] +[100.1, , 5.5, 6.6] [., i, p, , u, m] +[100.1, 7.7, 8.8, 9.9] [., d, , l, o, r] +[100.1, 10.1, , 12.2] [., s, i, t] +[100.1, 13.3, 14.4, 15.5] [., a, m, e, t] +[100.1] [., ,] +[100.1, 16.6, 17.7, 18.8] [.] + # array_prepend with columns and scalars #2 (element is list) query ?? select array_prepend(make_array(1, 11, 111), column1), array_prepend(column2, make_array(make_array(1, 2, 3), make_array(11, 12, 13))) from nested_arrays; @@ -1107,71 +1924,103 @@ select array_prepend(make_array(1, 11, 111), column1), array_prepend(column2, ma [[1, 11, 111], [1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] [[7, 8, 9], [1, 2, 3], [11, 12, 13]] [[1, 11, 111], [4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]] [[10, 11, 12], [1, 2, 3], [11, 12, 13]] +# TODO: add this when #8305 is fixed +# query ?? +# select array_prepend(arrow_cast(make_array(1, 11, 111), 'LargeList(Int64)'), column1), array_prepend(column2, arrow_cast(make_array(make_array(1, 2, 3), make_array(11, 12, 13)), 'LargeList(LargeList(Int64))')) from large_nested_arrays; +# ---- +# [[1, 11, 111], [1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] [[7, 8, 9], [1, 2, 3], [11, 12, 13]] +# [[1, 11, 111], [4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]] [[10, 11, 12], [1, 2, 3], [11, 12, 13]] + ## array_repeat (aliases: `list_repeat`) # array_repeat scalar function #1 -query ??? -select array_repeat(1, 5), array_repeat(3.14, 3), array_repeat('l', 4); ----- -[1, 1, 1, 1, 1] [3.14, 3.14, 3.14] [l, l, l, l] +query ???????? +select + array_repeat(1, 5), + array_repeat(3.14, 3), + array_repeat('l', 4), + array_repeat(null, 2), + list_repeat(-1, 5), + list_repeat(-3.14, 0), + list_repeat('rust', 4), + list_repeat(null, 0); +---- +[1, 1, 1, 1, 1] [3.14, 3.14, 3.14] [l, l, l, l] [, ] [-1, -1, -1, -1, -1] [] [rust, rust, rust, rust] [] # array_repeat scalar function #2 (element as list) -query ??? -select array_repeat([1], 5), array_repeat([1.1, 2.2, 3.3], 3), array_repeat([[1, 2], [3, 4]], 2); +query ???? +select + array_repeat([1], 5), + array_repeat([1.1, 2.2, 3.3], 3), + array_repeat([null, null], 3), + array_repeat([[1, 2], [3, 4]], 2); ---- -[[1], [1], [1], [1], [1]] [[1.1, 2.2, 3.3], [1.1, 2.2, 3.3], [1.1, 2.2, 3.3]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]] +[[1], [1], [1], [1], [1]] [[1.1, 2.2, 3.3], [1.1, 2.2, 3.3], [1.1, 2.2, 3.3]] [[, ], [, ], [, ]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]] -# list_repeat scalar function #3 (function alias: `array_repeat`) -query ??? -select list_repeat(1, 5), list_repeat(3.14, 3), list_repeat('l', 4); +query ???? +select + array_repeat(arrow_cast([1], 'LargeList(Int64)'), 5), + array_repeat(arrow_cast([1.1, 2.2, 3.3], 'LargeList(Float64)'), 3), + array_repeat(arrow_cast([null, null], 'LargeList(Null)'), 3), + array_repeat(arrow_cast([[1, 2], [3, 4]], 'LargeList(List(Int64))'), 2); ---- -[1, 1, 1, 1, 1] [3.14, 3.14, 3.14] [l, l, l, l] +[[1], [1], [1], [1], [1]] [[1.1, 2.2, 3.3], [1.1, 2.2, 3.3], [1.1, 2.2, 3.3]] [[, ], [, ], [, ]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]] # array_repeat with columns #1 -query ? -select array_repeat(column4, column1) from values_without_nulls; ----- -[1.1] -[2.2, 2.2] -[3.3, 3.3, 3.3] -[4.4, 4.4, 4.4, 4.4] -[5.5, 5.5, 5.5, 5.5, 5.5] -[6.6, 6.6, 6.6, 6.6, 6.6, 6.6] -[7.7, 7.7, 7.7, 7.7, 7.7, 7.7, 7.7] -[8.8, 8.8, 8.8, 8.8, 8.8, 8.8, 8.8, 8.8] -[9.9, 9.9, 9.9, 9.9, 9.9, 9.9, 9.9, 9.9, 9.9] -# array_repeat with columns #2 (element as list) -query ? -select array_repeat(column1, column3) from arrays_values_without_nulls; ----- -[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]] -[[11, 12, 13, 14, 15, 16, 17, 18, 19, 20], [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]] -[[21, 22, 23, 24, 25, 26, 27, 28, 29, 30], [21, 22, 23, 24, 25, 26, 27, 28, 29, 30], [21, 22, 23, 24, 25, 26, 27, 28, 29, 30]] -[[31, 32, 33, 34, 35, 26, 37, 38, 39, 40], [31, 32, 33, 34, 35, 26, 37, 38, 39, 40], [31, 32, 33, 34, 35, 26, 37, 38, 39, 40], [31, 32, 33, 34, 35, 26, 37, 38, 39, 40]] +statement ok +CREATE TABLE array_repeat_table +AS VALUES + (1, 1, 1.1, 'a', make_array(4, 5, 6)), + (2, null, null, null, null), + (3, 2, 2.2, 'rust', make_array(7)), + (0, 3, 3.3, 'datafusion', make_array(8, 9)); -# array_repeat with columns and scalars #1 -query ?? -select array_repeat(1, column1), array_repeat(column4, 3) from values_without_nulls; ----- -[1] [1.1, 1.1, 1.1] -[1, 1] [2.2, 2.2, 2.2] -[1, 1, 1] [3.3, 3.3, 3.3] -[1, 1, 1, 1] [4.4, 4.4, 4.4] -[1, 1, 1, 1, 1] [5.5, 5.5, 5.5] -[1, 1, 1, 1, 1, 1] [6.6, 6.6, 6.6] -[1, 1, 1, 1, 1, 1, 1] [7.7, 7.7, 7.7] -[1, 1, 1, 1, 1, 1, 1, 1] [8.8, 8.8, 8.8] -[1, 1, 1, 1, 1, 1, 1, 1, 1] [9.9, 9.9, 9.9] +statement ok +CREATE TABLE large_array_repeat_table +AS SELECT + column1, + column2, + column3, + column4, + arrow_cast(column5, 'LargeList(Int64)') as column5 +FROM array_repeat_table; + +query ?????? +select + array_repeat(column2, column1), + array_repeat(column3, column1), + array_repeat(column4, column1), + array_repeat(column5, column1), + array_repeat(column2, 3), + array_repeat(make_array(1), column1) +from array_repeat_table; +---- +[1] [1.1] [a] [[4, 5, 6]] [1, 1, 1] [[1]] +[, ] [, ] [, ] [, ] [, , ] [[1], [1]] +[2, 2, 2] [2.2, 2.2, 2.2] [rust, rust, rust] [[7], [7], [7]] [2, 2, 2] [[1], [1], [1]] +[] [] [] [] [3, 3, 3] [] + +query ?????? +select + array_repeat(column2, column1), + array_repeat(column3, column1), + array_repeat(column4, column1), + array_repeat(column5, column1), + array_repeat(column2, 3), + array_repeat(make_array(1), column1) +from large_array_repeat_table; +---- +[1] [1.1] [a] [[4, 5, 6]] [1, 1, 1] [[1]] +[, ] [, ] [, ] [, ] [, , ] [[1], [1]] +[2, 2, 2] [2.2, 2.2, 2.2] [rust, rust, rust] [[7], [7], [7]] [2, 2, 2] [[1], [1], [1]] +[] [] [] [] [3, 3, 3] [] -# array_repeat with columns and scalars #2 (element as list) -query ?? -select array_repeat([1], column3), array_repeat(column1, 3) from arrays_values_without_nulls; ----- -[[1]] [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]] -[[1], [1]] [[11, 12, 13, 14, 15, 16, 17, 18, 19, 20], [11, 12, 13, 14, 15, 16, 17, 18, 19, 20], [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]] -[[1], [1], [1]] [[21, 22, 23, 24, 25, 26, 27, 28, 29, 30], [21, 22, 23, 24, 25, 26, 27, 28, 29, 30], [21, 22, 23, 24, 25, 26, 27, 28, 29, 30]] -[[1], [1], [1], [1]] [[31, 32, 33, 34, 35, 26, 37, 38, 39, 40], [31, 32, 33, 34, 35, 26, 37, 38, 39, 40], [31, 32, 33, 34, 35, 26, 37, 38, 39, 40]] +statement ok +drop table array_repeat_table; + +statement ok +drop table large_array_repeat_table; ## array_concat (aliases: `array_cat`, `list_concat`, `list_cat`) @@ -1433,9 +2282,19 @@ select array_position(['h', 'e', 'l', 'l', 'o'], 'l'), array_position([1, 2, 3, ---- 3 5 1 -# array_position scalar function #2 (with optional argument) query III -select array_position(['h', 'e', 'l', 'l', 'o'], 'l', 4), array_position([1, 2, 5, 4, 5], 5, 4), array_position([1, 1, 1], 1, 2); +select array_position(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l'), array_position(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5), array_position(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1); +---- +3 5 1 + +# array_position scalar function #2 (with optional argument) +query III +select array_position(['h', 'e', 'l', 'l', 'o'], 'l', 4), array_position([1, 2, 5, 4, 5], 5, 4), array_position([1, 1, 1], 1, 2); +---- +4 5 2 + +query III +select array_position(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l', 4), array_position(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5, 4), array_position(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1, 2); ---- 4 5 2 @@ -1451,24 +2310,44 @@ select array_position(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, ---- 4 3 +query II +select array_position(arrow_cast(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), 'LargeList(List(Int64))'), [4, 5, 6]), array_position(arrow_cast(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), 'LargeList(List(Int64))'), [2, 3, 4]); +---- +2 2 + # list_position scalar function #5 (function alias `array_position`) query III select list_position(['h', 'e', 'l', 'l', 'o'], 'l'), list_position([1, 2, 3, 4, 5], 5), list_position([1, 1, 1], 1); ---- 3 5 1 +query III +select list_position(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l'), list_position(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5), list_position(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1); +---- +3 5 1 + # array_indexof scalar function #6 (function alias `array_position`) query III select array_indexof(['h', 'e', 'l', 'l', 'o'], 'l'), array_indexof([1, 2, 3, 4, 5], 5), array_indexof([1, 1, 1], 1); ---- 3 5 1 +query III +select array_indexof(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l'), array_indexof(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5), array_indexof(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1); +---- +3 5 1 + # list_indexof scalar function #7 (function alias `array_position`) query III select list_indexof(['h', 'e', 'l', 'l', 'o'], 'l'), list_indexof([1, 2, 3, 4, 5], 5), list_indexof([1, 1, 1], 1); ---- 3 5 1 +query III +select list_indexof(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l'), list_indexof(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5), list_indexof(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1); +---- +3 5 1 + # array_position with columns #1 query II select array_position(column1, column2), array_position(column1, column2, column3) from arrays_values_without_nulls; @@ -1478,6 +2357,14 @@ select array_position(column1, column2), array_position(column1, column2, column 3 3 4 4 +query II +select array_position(column1, column2), array_position(column1, column2, column3) from large_arrays_values_without_nulls; +---- +1 1 +2 2 +3 3 +4 4 + # array_position with columns #2 (element is list) query II select array_position(column1, column2), array_position(column1, column2, column3) from nested_arrays; @@ -1485,6 +2372,13 @@ select array_position(column1, column2), array_position(column1, column2, column 3 3 2 5 +#TODO: add this test when #8305 is fixed +#query II +#select array_position(column1, column2), array_position(column1, column2, column3) from nested_arrays; +#---- +#3 3 +#2 5 + # array_position with columns and scalars #1 query III select array_position(make_array(1, 2, 3, 4, 5), column2), array_position(column1, 3), array_position(column1, 3, 5) from arrays_values_without_nulls; @@ -1494,6 +2388,14 @@ NULL NULL NULL NULL NULL NULL NULL NULL NULL +query III +select array_position(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), column2), array_position(column1, 3), array_position(column1, 3, 5) from large_arrays_values_without_nulls; +---- +1 3 NULL +NULL NULL NULL +NULL NULL NULL +NULL NULL NULL + # array_position with columns and scalars #2 (element is list) query III select array_position(make_array([1, 2, 3], [4, 5, 6], [11, 12, 13]), column2), array_position(column1, make_array(4, 5, 6)), array_position(column1, make_array(1, 2, 3), 2) from nested_arrays; @@ -1501,6 +2403,13 @@ select array_position(make_array([1, 2, 3], [4, 5, 6], [11, 12, 13]), column2), NULL 6 4 NULL 1 NULL +#TODO: add this test when #8305 is fixed +#query III +#select array_position(arrow_cast(make_array([1, 2, 3], [4, 5, 6], [11, 12, 13]), 'LargeList(List(Int64))'), column2), array_position(column1, make_array(4, 5, 6)), array_position(column1, make_array(1, 2, 3), 2) from large_nested_arrays; +#---- +#NULL 6 4 +#NULL 1 NULL + ## array_positions (aliases: `list_positions`) # array_positions scalar function #1 @@ -1509,18 +2418,33 @@ select array_positions(['h', 'e', 'l', 'l', 'o'], 'l'), array_positions([1, 2, 3 ---- [3, 4] [5] [1, 2, 3] +query ??? +select array_positions(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l'), array_positions(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5), array_positions(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1); +---- +[3, 4] [5] [1, 2, 3] + # array_positions scalar function #2 (element is list) query ? select array_positions(make_array([1, 2, 3], [2, 1, 3], [1, 5, 6], [2, 1, 3], [4, 5, 6]), [2, 1, 3]); ---- [2, 4] +query ? +select array_positions(arrow_cast(make_array([1, 2, 3], [2, 1, 3], [1, 5, 6], [2, 1, 3], [4, 5, 6]), 'LargeList(List(Int64))'), [2, 1, 3]); +---- +[2, 4] + # list_positions scalar function #3 (function alias `array_positions`) query ??? select list_positions(['h', 'e', 'l', 'l', 'o'], 'l'), list_positions([1, 2, 3, 4, 5], 5), list_positions([1, 1, 1], 1); ---- [3, 4] [5] [1, 2, 3] +query ??? +select list_positions(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l'), list_positions(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5), list_positions(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1); +---- +[3, 4] [5] [1, 2, 3] + # array_positions with columns #1 query ? select array_positions(column1, column2) from arrays_values_without_nulls; @@ -1530,6 +2454,14 @@ select array_positions(column1, column2) from arrays_values_without_nulls; [3] [4] +query ? +select array_positions(arrow_cast(column1, 'LargeList(Int64)'), column2) from arrays_values_without_nulls; +---- +[1] +[2] +[3] +[4] + # array_positions with columns #2 (element is list) query ? select array_positions(column1, column2) from nested_arrays; @@ -1537,6 +2469,12 @@ select array_positions(column1, column2) from nested_arrays; [3] [2, 5] +query ? +select array_positions(arrow_cast(column1, 'LargeList(List(Int64))'), column2) from nested_arrays; +---- +[3] +[2, 5] + # array_positions with columns and scalars #1 query ?? select array_positions(column1, 4), array_positions(array[1, 2, 23, 13, 33, 45], column2) from arrays_values_without_nulls; @@ -1546,6 +2484,14 @@ select array_positions(column1, 4), array_positions(array[1, 2, 23, 13, 33, 45], [] [3] [] [] +query ?? +select array_positions(arrow_cast(column1, 'LargeList(Int64)'), 4), array_positions(array[1, 2, 23, 13, 33, 45], column2) from arrays_values_without_nulls; +---- +[4] [1] +[] [] +[] [3] +[] [] + # array_positions with columns and scalars #2 (element is list) query ?? select array_positions(column1, make_array(4, 5, 6)), array_positions(make_array([1, 2, 3], [11, 12, 13], [4, 5, 6]), column2) from nested_arrays; @@ -1553,23 +2499,76 @@ select array_positions(column1, make_array(4, 5, 6)), array_positions(make_array [6] [] [1] [] +query ?? +select array_positions(arrow_cast(column1, 'LargeList(List(Int64))'), make_array(4, 5, 6)), array_positions(arrow_cast(make_array([1, 2, 3], [11, 12, 13], [4, 5, 6]), 'LargeList(List(Int64))'), column2) from nested_arrays; +---- +[6] [] +[1] [] + ## array_replace (aliases: `list_replace`) # array_replace scalar function #1 query ??? -select array_replace(make_array(1, 2, 3, 4), 2, 3), array_replace(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0), array_replace(make_array(1, 2, 3), 4, 0); +select + array_replace(make_array(1, 2, 3, 4), 2, 3), + array_replace(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0), + array_replace(make_array(1, 2, 3), 4, 0); +---- +[1, 3, 3, 4] [1, 0, 4, 5, 4, 6, 7] [1, 2, 3] + +query ??? +select + array_replace(arrow_cast(make_array(1, 2, 3, 4), 'LargeList(Int64)'), 2, 3), + array_replace(arrow_cast(make_array(1, 4, 4, 5, 4, 6, 7), 'LargeList(Int64)'), 4, 0), + array_replace(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4, 0); ---- [1, 3, 3, 4] [1, 0, 4, 5, 4, 6, 7] [1, 2, 3] # array_replace scalar function #2 (element is list) query ?? -select array_replace(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), [4, 5, 6], [1, 1, 1]), array_replace(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), [2, 3, 4], [3, 1, 4]); +select + array_replace( + make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), + [4, 5, 6], + [1, 1, 1] + ), + array_replace( + make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), + [2, 3, 4], + [3, 1, 4] + ); +---- +[[1, 2, 3], [1, 1, 1], [5, 5, 5], [4, 5, 6], [7, 8, 9]] [[1, 3, 2], [3, 1, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]] + +query ?? +select + array_replace( + arrow_cast(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), 'LargeList(List(Int64))'), + [4, 5, 6], + [1, 1, 1] + ), + array_replace( + arrow_cast(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), 'LargeList(List(Int64))'), + [2, 3, 4], + [3, 1, 4] + ); ---- [[1, 2, 3], [1, 1, 1], [5, 5, 5], [4, 5, 6], [7, 8, 9]] [[1, 3, 2], [3, 1, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]] # list_replace scalar function #3 (function alias `list_replace`) query ??? -select list_replace(make_array(1, 2, 3, 4), 2, 3), list_replace(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0), list_replace(make_array(1, 2, 3), 4, 0); +select list_replace( + make_array(1, 2, 3, 4), 2, 3), + list_replace(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0), + list_replace(make_array(1, 2, 3), 4, 0); +---- +[1, 3, 3, 4] [1, 0, 4, 5, 4, 6, 7] [1, 2, 3] + +query ??? +select list_replace( + arrow_cast(make_array(1, 2, 3, 4), 'LargeList(Int64)'), 2, 3), + list_replace(arrow_cast(make_array(1, 4, 4, 5, 4, 6, 7), 'LargeList(Int64)'), 4, 0), + list_replace(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4, 0); ---- [1, 3, 3, 4] [1, 0, 4, 5, 4, 6, 7] [1, 2, 3] @@ -1582,6 +2581,14 @@ select array_replace(column1, column2, column3) from arrays_with_repeating_eleme [10, 7, 7, 8, 7, 9, 7, 8, 7, 7] [13, 11, 12, 10, 11, 12, 10, 11, 12, 10] +query ? +select array_replace(column1, column2, column3) from large_arrays_with_repeating_elements; +---- +[1, 4, 1, 3, 2, 2, 1, 3, 2, 3] +[7, 4, 5, 5, 6, 5, 5, 5, 4, 4] +[10, 7, 7, 8, 7, 9, 7, 8, 7, 7] +[13, 11, 12, 10, 11, 12, 10, 11, 12, 10] + # array_replace scalar function with columns #2 (element is list) query ? select array_replace(column1, column2, column3) from nested_arrays_with_repeating_elements; @@ -1591,9 +2598,33 @@ select array_replace(column1, column2, column3) from nested_arrays_with_repeatin [[28, 29, 30], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[37, 38, 39], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] +query ? +select array_replace(column1, column2, column3) from large_nested_arrays_with_repeating_elements; +---- +[[1, 2, 3], [10, 11, 12], [1, 2, 3], [7, 8, 9], [4, 5, 6], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] +[[19, 20, 21], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] +[[28, 29, 30], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] +[[37, 38, 39], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] + # array_replace scalar function with columns and scalars #1 query ??? -select array_replace(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column2, column3), array_replace(column1, 1, column3), array_replace(column1, column2, 4) from arrays_with_repeating_elements; +select + array_replace(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column2, column3), + array_replace(column1, 1, column3), + array_replace(column1, column2, 4) +from arrays_with_repeating_elements; +---- +[1, 4, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8] [4, 2, 1, 3, 2, 2, 1, 3, 2, 3] [1, 4, 1, 3, 2, 2, 1, 3, 2, 3] +[1, 2, 2, 7, 5, 4, 4, 7, 7, 10, 7, 8] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] +[1, 2, 2, 4, 5, 4, 4, 10, 7, 10, 7, 8] [7, 7, 7, 8, 7, 9, 7, 8, 7, 7] [4, 7, 7, 8, 7, 9, 7, 8, 7, 7] +[1, 2, 2, 4, 5, 4, 4, 7, 7, 13, 7, 8] [10, 11, 12, 10, 11, 12, 10, 11, 12, 10] [4, 11, 12, 10, 11, 12, 10, 11, 12, 10] + +query ??? +select + array_replace(arrow_cast(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), 'LargeList(Int64)'), column2, column3), + array_replace(column1, 1, column3), + array_replace(column1, column2, 4) +from large_arrays_with_repeating_elements; ---- [1, 4, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8] [4, 2, 1, 3, 2, 2, 1, 3, 2, 3] [1, 4, 1, 3, 2, 2, 1, 3, 2, 3] [1, 2, 2, 7, 5, 4, 4, 7, 7, 10, 7, 8] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] @@ -1602,7 +2633,33 @@ select array_replace(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column2, c # array_replace scalar function with columns and scalars #2 (element is list) query ??? -select array_replace(make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]), column2, column3), array_replace(column1, make_array(1, 2, 3), column3), array_replace(column1, column2, make_array(11, 12, 13)) from nested_arrays_with_repeating_elements; +select + array_replace( + make_array( + [1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]), + column2, + column3 + ), + array_replace(column1, make_array(1, 2, 3), column3), + array_replace(column1, column2, make_array(11, 12, 13)) +from nested_arrays_with_repeating_elements; +---- +[[1, 2, 3], [10, 11, 12], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[1, 2, 3], [11, 12, 13], [1, 2, 3], [7, 8, 9], [4, 5, 6], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [19, 20, 21], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] [[11, 12, 13], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [28, 29, 30], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[11, 12, 13], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [37, 38, 39], [19, 20, 21], [22, 23, 24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] [[11, 12, 13], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] + +query ??? +select + array_replace( + arrow_cast(make_array( + [1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]),'LargeList(List(Int64))'), + column2, + column3 + ), + array_replace(column1, make_array(1, 2, 3), column3), + array_replace(column1, column2, make_array(11, 12, 13)) +from large_nested_arrays_with_repeating_elements; ---- [[1, 2, 3], [10, 11, 12], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[1, 2, 3], [11, 12, 13], [1, 2, 3], [7, 8, 9], [4, 5, 6], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[1, 2, 3], [4, 5, 6], [4, 5, 6], [19, 20, 21], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] [[11, 12, 13], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] @@ -1613,25 +2670,88 @@ select array_replace(make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [ # array_replace_n scalar function #1 query ??? -select array_replace_n(make_array(1, 2, 3, 4), 2, 3, 2), array_replace_n(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0, 2), array_replace_n(make_array(1, 2, 3), 4, 0, 3); +select + array_replace_n(make_array(1, 2, 3, 4), 2, 3, 2), + array_replace_n(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0, 2), + array_replace_n(make_array(1, 2, 3), 4, 0, 3); +---- +[1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] + +query ??? +select + array_replace_n(arrow_cast(make_array(1, 2, 3, 4), 'LargeList(Int64)'), 2, 3, 2), + array_replace_n(arrow_cast(make_array(1, 4, 4, 5, 4, 6, 7), 'LargeList(Int64)'), 4, 0, 2), + array_replace_n(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4, 0, 3); ---- [1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] # array_replace_n scalar function #2 (element is list) query ?? -select array_replace_n(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), [4, 5, 6], [1, 1, 1], 2), array_replace_n(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), [2, 3, 4], [3, 1, 4], 2); +select + array_replace_n( + make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), + [4, 5, 6], + [1, 1, 1], + 2 + ), + array_replace_n( + make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), + [2, 3, 4], + [3, 1, 4], + 2 + ); +---- +[[1, 2, 3], [1, 1, 1], [5, 5, 5], [1, 1, 1], [7, 8, 9]] [[1, 3, 2], [3, 1, 4], [3, 1, 4], [5, 3, 1], [1, 3, 2]] + +query ?? +select + array_replace_n( + arrow_cast(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), 'LargeList(List(Int64))'), + [4, 5, 6], + [1, 1, 1], + 2 + ), + array_replace_n( + arrow_cast(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), 'LargeList(List(Int64))'), + [2, 3, 4], + [3, 1, 4], + 2 + ); ---- [[1, 2, 3], [1, 1, 1], [5, 5, 5], [1, 1, 1], [7, 8, 9]] [[1, 3, 2], [3, 1, 4], [3, 1, 4], [5, 3, 1], [1, 3, 2]] # list_replace_n scalar function #3 (function alias `array_replace_n`) query ??? -select list_replace_n(make_array(1, 2, 3, 4), 2, 3, 2), list_replace_n(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0, 2), list_replace_n(make_array(1, 2, 3), 4, 0, 3); +select + list_replace_n(make_array(1, 2, 3, 4), 2, 3, 2), + list_replace_n(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0, 2), + list_replace_n(make_array(1, 2, 3), 4, 0, 3); +---- +[1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] + +query ??? +select + list_replace_n(arrow_cast(make_array(1, 2, 3, 4), 'LargeList(Int64)'), 2, 3, 2), + list_replace_n(arrow_cast(make_array(1, 4, 4, 5, 4, 6, 7), 'LargeList(Int64)'), 4, 0, 2), + list_replace_n(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4, 0, 3); ---- [1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] # array_replace_n scalar function with columns #1 query ? -select array_replace_n(column1, column2, column3, column4) from arrays_with_repeating_elements; +select + array_replace_n(column1, column2, column3, column4) +from arrays_with_repeating_elements; +---- +[1, 4, 1, 3, 4, 4, 1, 3, 2, 3] +[7, 7, 5, 5, 6, 5, 5, 5, 4, 4] +[10, 10, 10, 8, 10, 9, 10, 8, 7, 7] +[13, 11, 12, 13, 11, 12, 13, 11, 12, 13] + +query ? +select + array_replace_n(column1, column2, column3, column4) +from large_arrays_with_repeating_elements; ---- [1, 4, 1, 3, 4, 4, 1, 3, 2, 3] [7, 7, 5, 5, 6, 5, 5, 5, 4, 4] @@ -1640,16 +2760,47 @@ select array_replace_n(column1, column2, column3, column4) from arrays_with_repe # array_replace_n scalar function with columns #2 (element is list) query ? -select array_replace_n(column1, column2, column3, column4) from nested_arrays_with_repeating_elements; +select + array_replace_n(column1, column2, column3, column4) +from nested_arrays_with_repeating_elements; +---- +[[1, 2, 3], [10, 11, 12], [1, 2, 3], [7, 8, 9], [10, 11, 12], [10, 11, 12], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] +[[19, 20, 21], [19, 20, 21], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] +[[28, 29, 30], [28, 29, 30], [28, 29, 30], [22, 23, 24], [28, 29, 30], [25, 26, 27], [28, 29, 30], [22, 23, 24], [19, 20, 21], [19, 20, 21]] +[[37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39]] + +query ? +select + array_replace_n(column1, column2, column3, column4) +from large_nested_arrays_with_repeating_elements; ---- [[1, 2, 3], [10, 11, 12], [1, 2, 3], [7, 8, 9], [10, 11, 12], [10, 11, 12], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[19, 20, 21], [19, 20, 21], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] [[28, 29, 30], [28, 29, 30], [28, 29, 30], [22, 23, 24], [28, 29, 30], [25, 26, 27], [28, 29, 30], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39]] + # array_replace_n scalar function with columns and scalars #1 query ???? -select array_replace_n(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column2, column3, column4), array_replace_n(column1, 1, column3, column4), array_replace_n(column1, column2, 4, column4), array_replace_n(column1, column2, column3, 2) from arrays_with_repeating_elements; +select + array_replace_n(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column2, column3, column4), + array_replace_n(column1, 1, column3, column4), + array_replace_n(column1, column2, 4, column4), + array_replace_n(column1, column2, column3, 2) +from arrays_with_repeating_elements; +---- +[1, 4, 4, 4, 5, 4, 4, 7, 7, 10, 7, 8] [4, 2, 4, 3, 2, 2, 4, 3, 2, 3] [1, 4, 1, 3, 4, 4, 1, 3, 2, 3] [1, 4, 1, 3, 4, 2, 1, 3, 2, 3] +[1, 2, 2, 7, 5, 7, 4, 7, 7, 10, 7, 8] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] [7, 7, 5, 5, 6, 5, 5, 5, 4, 4] +[1, 2, 2, 4, 5, 4, 4, 10, 10, 10, 10, 8] [7, 7, 7, 8, 7, 9, 7, 8, 7, 7] [4, 4, 4, 8, 4, 9, 4, 8, 7, 7] [10, 10, 7, 8, 7, 9, 7, 8, 7, 7] +[1, 2, 2, 4, 5, 4, 4, 7, 7, 13, 7, 8] [10, 11, 12, 10, 11, 12, 10, 11, 12, 10] [4, 11, 12, 4, 11, 12, 4, 11, 12, 4] [13, 11, 12, 13, 11, 12, 10, 11, 12, 10] + +query ???? +select + array_replace_n(arrow_cast(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), 'LargeList(Int64)'), column2, column3, column4), + array_replace_n(column1, 1, column3, column4), + array_replace_n(column1, column2, 4, column4), + array_replace_n(column1, column2, column3, 2) +from large_arrays_with_repeating_elements; ---- [1, 4, 4, 4, 5, 4, 4, 7, 7, 10, 7, 8] [4, 2, 4, 3, 2, 2, 4, 3, 2, 3] [1, 4, 1, 3, 4, 4, 1, 3, 2, 3] [1, 4, 1, 3, 4, 2, 1, 3, 2, 3] [1, 2, 2, 7, 5, 7, 4, 7, 7, 10, 7, 8] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] [7, 7, 5, 5, 6, 5, 5, 5, 4, 4] @@ -1658,7 +2809,37 @@ select array_replace_n(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column2, # array_replace_n scalar function with columns and scalars #2 (element is list) query ???? -select array_replace_n(make_array([7, 8, 9], [2, 1, 3], [1, 5, 6], [10, 11, 12], [2, 1, 3], [7, 8, 9], [4, 5, 6]), column2, column3, column4), array_replace_n(column1, make_array(1, 2, 3), column3, column4), array_replace_n(column1, column2, make_array(11, 12, 13), column4), array_replace_n(column1, column2, column3, 2) from nested_arrays_with_repeating_elements; +select + array_replace_n( + make_array( + [7, 8, 9], [2, 1, 3], [1, 5, 6], [10, 11, 12], [2, 1, 3], [7, 8, 9], [4, 5, 6]), + column2, + column3, + column4 + ), + array_replace_n(column1, make_array(1, 2, 3), column3, column4), + array_replace_n(column1, column2, make_array(11, 12, 13), column4), + array_replace_n(column1, column2, column3, 2) +from nested_arrays_with_repeating_elements; +---- +[[7, 8, 9], [2, 1, 3], [1, 5, 6], [10, 11, 12], [2, 1, 3], [7, 8, 9], [10, 11, 12]] [[10, 11, 12], [4, 5, 6], [10, 11, 12], [7, 8, 9], [4, 5, 6], [4, 5, 6], [10, 11, 12], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[1, 2, 3], [11, 12, 13], [1, 2, 3], [7, 8, 9], [11, 12, 13], [11, 12, 13], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[1, 2, 3], [10, 11, 12], [1, 2, 3], [7, 8, 9], [10, 11, 12], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] +[[7, 8, 9], [2, 1, 3], [1, 5, 6], [19, 20, 21], [2, 1, 3], [7, 8, 9], [4, 5, 6]] [[10, 11, 12], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] [[11, 12, 13], [11, 12, 13], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] [[19, 20, 21], [19, 20, 21], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] +[[7, 8, 9], [2, 1, 3], [1, 5, 6], [10, 11, 12], [2, 1, 3], [7, 8, 9], [4, 5, 6]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[11, 12, 13], [11, 12, 13], [11, 12, 13], [22, 23, 24], [11, 12, 13], [25, 26, 27], [11, 12, 13], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[28, 29, 30], [28, 29, 30], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] +[[7, 8, 9], [2, 1, 3], [1, 5, 6], [10, 11, 12], [2, 1, 3], [7, 8, 9], [4, 5, 6]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] [[11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13]] [[37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] + +query ???? +select + array_replace_n( + arrow_cast(make_array( + [7, 8, 9], [2, 1, 3], [1, 5, 6], [10, 11, 12], [2, 1, 3], [7, 8, 9], [4, 5, 6]), 'LargeList(List(Int64))'), + column2, + column3, + column4 + ), + array_replace_n(column1, make_array(1, 2, 3), column3, column4), + array_replace_n(column1, column2, make_array(11, 12, 13), column4), + array_replace_n(column1, column2, column3, 2) +from large_nested_arrays_with_repeating_elements; ---- [[7, 8, 9], [2, 1, 3], [1, 5, 6], [10, 11, 12], [2, 1, 3], [7, 8, 9], [10, 11, 12]] [[10, 11, 12], [4, 5, 6], [10, 11, 12], [7, 8, 9], [4, 5, 6], [4, 5, 6], [10, 11, 12], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[1, 2, 3], [11, 12, 13], [1, 2, 3], [7, 8, 9], [11, 12, 13], [11, 12, 13], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[1, 2, 3], [10, 11, 12], [1, 2, 3], [7, 8, 9], [10, 11, 12], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[7, 8, 9], [2, 1, 3], [1, 5, 6], [19, 20, 21], [2, 1, 3], [7, 8, 9], [4, 5, 6]] [[10, 11, 12], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] [[11, 12, 13], [11, 12, 13], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] [[19, 20, 21], [19, 20, 21], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] @@ -1669,25 +2850,84 @@ select array_replace_n(make_array([7, 8, 9], [2, 1, 3], [1, 5, 6], [10, 11, 12], # array_replace_all scalar function #1 query ??? -select array_replace_all(make_array(1, 2, 3, 4), 2, 3), array_replace_all(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0), array_replace_all(make_array(1, 2, 3), 4, 0); +select + array_replace_all(make_array(1, 2, 3, 4), 2, 3), + array_replace_all(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0), + array_replace_all(make_array(1, 2, 3), 4, 0); +---- +[1, 3, 3, 4] [1, 0, 0, 5, 0, 6, 7] [1, 2, 3] + +query ??? +select + array_replace_all(arrow_cast(make_array(1, 2, 3, 4), 'LargeList(Int64)'), 2, 3), + array_replace_all(arrow_cast(make_array(1, 4, 4, 5, 4, 6, 7), 'LargeList(Int64)'), 4, 0), + array_replace_all(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4, 0); ---- [1, 3, 3, 4] [1, 0, 0, 5, 0, 6, 7] [1, 2, 3] # array_replace_all scalar function #2 (element is list) query ?? -select array_replace_all(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), [4, 5, 6], [1, 1, 1]), array_replace_all(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), [2, 3, 4], [3, 1, 4]); +select + array_replace_all( + make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), + [4, 5, 6], + [1, 1, 1] + ), + array_replace_all( + make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), + [2, 3, 4], + [3, 1, 4] + ); +---- +[[1, 2, 3], [1, 1, 1], [5, 5, 5], [1, 1, 1], [7, 8, 9]] [[1, 3, 2], [3, 1, 4], [3, 1, 4], [5, 3, 1], [1, 3, 2]] + +query ?? +select + array_replace_all( + arrow_cast(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), 'LargeList(List(Int64))'), + [4, 5, 6], + [1, 1, 1] + ), + array_replace_all( + arrow_cast(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), 'LargeList(List(Int64))'), + [2, 3, 4], + [3, 1, 4] + ); ---- [[1, 2, 3], [1, 1, 1], [5, 5, 5], [1, 1, 1], [7, 8, 9]] [[1, 3, 2], [3, 1, 4], [3, 1, 4], [5, 3, 1], [1, 3, 2]] # list_replace_all scalar function #3 (function alias `array_replace_all`) query ??? -select list_replace_all(make_array(1, 2, 3, 4), 2, 3), list_replace_all(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0), list_replace_all(make_array(1, 2, 3), 4, 0); +select + list_replace_all(make_array(1, 2, 3, 4), 2, 3), + list_replace_all(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0), + list_replace_all(make_array(1, 2, 3), 4, 0); +---- +[1, 3, 3, 4] [1, 0, 0, 5, 0, 6, 7] [1, 2, 3] + +query ??? +select + list_replace_all(arrow_cast(make_array(1, 2, 3, 4), 'LargeList(Int64)'), 2, 3), + list_replace_all(arrow_cast(make_array(1, 4, 4, 5, 4, 6, 7), 'LargeList(Int64)'), 4, 0), + list_replace_all(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4, 0); ---- [1, 3, 3, 4] [1, 0, 0, 5, 0, 6, 7] [1, 2, 3] # array_replace_all scalar function with columns #1 query ? -select array_replace_all(column1, column2, column3) from arrays_with_repeating_elements; +select + array_replace_all(column1, column2, column3) +from arrays_with_repeating_elements; +---- +[1, 4, 1, 3, 4, 4, 1, 3, 4, 3] +[7, 7, 5, 5, 6, 5, 5, 5, 7, 7] +[10, 10, 10, 8, 10, 9, 10, 8, 10, 10] +[13, 11, 12, 13, 11, 12, 13, 11, 12, 13] + +query ? +select + array_replace_all(column1, column2, column3) +from large_arrays_with_repeating_elements; ---- [1, 4, 1, 3, 4, 4, 1, 3, 4, 3] [7, 7, 5, 5, 6, 5, 5, 5, 7, 7] @@ -1696,7 +2936,19 @@ select array_replace_all(column1, column2, column3) from arrays_with_repeating_e # array_replace_all scalar function with columns #2 (element is list) query ? -select array_replace_all(column1, column2, column3) from nested_arrays_with_repeating_elements; +select + array_replace_all(column1, column2, column3) +from nested_arrays_with_repeating_elements; +---- +[[1, 2, 3], [10, 11, 12], [1, 2, 3], [7, 8, 9], [10, 11, 12], [10, 11, 12], [1, 2, 3], [7, 8, 9], [10, 11, 12], [7, 8, 9]] +[[19, 20, 21], [19, 20, 21], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [19, 20, 21], [19, 20, 21]] +[[28, 29, 30], [28, 29, 30], [28, 29, 30], [22, 23, 24], [28, 29, 30], [25, 26, 27], [28, 29, 30], [22, 23, 24], [28, 29, 30], [28, 29, 30]] +[[37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39]] + +query ? +select + array_replace_all(column1, column2, column3) +from large_nested_arrays_with_repeating_elements; ---- [[1, 2, 3], [10, 11, 12], [1, 2, 3], [7, 8, 9], [10, 11, 12], [10, 11, 12], [1, 2, 3], [7, 8, 9], [10, 11, 12], [7, 8, 9]] [[19, 20, 21], [19, 20, 21], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [19, 20, 21], [19, 20, 21]] @@ -1705,7 +2957,23 @@ select array_replace_all(column1, column2, column3) from nested_arrays_with_repe # array_replace_all scalar function with columns and scalars #1 query ??? -select array_replace_all(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column2, column3), array_replace_all(column1, 1, column3), array_replace_all(column1, column2, 4) from arrays_with_repeating_elements; +select + array_replace_all(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column2, column3), + array_replace_all(column1, 1, column3), + array_replace_all(column1, column2, 4) +from arrays_with_repeating_elements; +---- +[1, 4, 4, 4, 5, 4, 4, 7, 7, 10, 7, 8] [4, 2, 4, 3, 2, 2, 4, 3, 2, 3] [1, 4, 1, 3, 4, 4, 1, 3, 4, 3] +[1, 2, 2, 7, 5, 7, 7, 7, 7, 10, 7, 8] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] +[1, 2, 2, 4, 5, 4, 4, 10, 10, 10, 10, 8] [7, 7, 7, 8, 7, 9, 7, 8, 7, 7] [4, 4, 4, 8, 4, 9, 4, 8, 4, 4] +[1, 2, 2, 4, 5, 4, 4, 7, 7, 13, 7, 8] [10, 11, 12, 10, 11, 12, 10, 11, 12, 10] [4, 11, 12, 4, 11, 12, 4, 11, 12, 4] + +query ??? +select + array_replace_all(arrow_cast(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), 'LargeList(Int64)'), column2, column3), + array_replace_all(column1, 1, column3), + array_replace_all(column1, column2, 4) +from large_arrays_with_repeating_elements; ---- [1, 4, 4, 4, 5, 4, 4, 7, 7, 10, 7, 8] [4, 2, 4, 3, 2, 2, 4, 3, 2, 3] [1, 4, 1, 3, 4, 4, 1, 3, 4, 3] [1, 2, 2, 7, 5, 7, 7, 7, 7, 10, 7, 8] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] @@ -1714,13 +2982,68 @@ select array_replace_all(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column # array_replace_all scalar function with columns and scalars #2 (element is list) query ??? -select array_replace_all(make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]), column2, column3), array_replace_all(column1, make_array(1, 2, 3), column3), array_replace_all(column1, column2, make_array(11, 12, 13)) from nested_arrays_with_repeating_elements; +select + array_replace_all( + make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]), + column2, + column3 + ), + array_replace_all(column1, make_array(1, 2, 3), column3), + array_replace_all(column1, column2, make_array(11, 12, 13)) +from nested_arrays_with_repeating_elements; ---- [[1, 2, 3], [10, 11, 12], [10, 11, 12], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [4, 5, 6], [10, 11, 12], [7, 8, 9], [4, 5, 6], [4, 5, 6], [10, 11, 12], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[1, 2, 3], [11, 12, 13], [1, 2, 3], [7, 8, 9], [11, 12, 13], [11, 12, 13], [1, 2, 3], [7, 8, 9], [11, 12, 13], [7, 8, 9]] [[1, 2, 3], [4, 5, 6], [4, 5, 6], [19, 20, 21], [13, 14, 15], [19, 20, 21], [19, 20, 21], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] [[11, 12, 13], [11, 12, 13], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [11, 12, 13], [11, 12, 13]] [[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [28, 29, 30], [28, 29, 30], [28, 29, 30], [28, 29, 30], [22, 23, 24]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[11, 12, 13], [11, 12, 13], [11, 12, 13], [22, 23, 24], [11, 12, 13], [25, 26, 27], [11, 12, 13], [22, 23, 24], [11, 12, 13], [11, 12, 13]] [[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [37, 38, 39], [19, 20, 21], [22, 23, 24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] [[11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13]] +query ??? +select + array_replace_all( + arrow_cast(make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]), 'LargeList(List(Int64))'), + column2, + column3 + ), + array_replace_all(column1, make_array(1, 2, 3), column3), + array_replace_all(column1, column2, make_array(11, 12, 13)) +from nested_arrays_with_repeating_elements; +---- +[[1, 2, 3], [10, 11, 12], [10, 11, 12], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [4, 5, 6], [10, 11, 12], [7, 8, 9], [4, 5, 6], [4, 5, 6], [10, 11, 12], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[1, 2, 3], [11, 12, 13], [1, 2, 3], [7, 8, 9], [11, 12, 13], [11, 12, 13], [1, 2, 3], [7, 8, 9], [11, 12, 13], [7, 8, 9]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [19, 20, 21], [13, 14, 15], [19, 20, 21], [19, 20, 21], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] [[11, 12, 13], [11, 12, 13], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [11, 12, 13], [11, 12, 13]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [28, 29, 30], [28, 29, 30], [28, 29, 30], [28, 29, 30], [22, 23, 24]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[11, 12, 13], [11, 12, 13], [11, 12, 13], [22, 23, 24], [11, 12, 13], [25, 26, 27], [11, 12, 13], [22, 23, 24], [11, 12, 13], [11, 12, 13]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [37, 38, 39], [19, 20, 21], [22, 23, 24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] [[11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13]] + +# array_replace with null handling + +statement ok +create table t as values + (make_array(3, 1, NULL, 3), 3, 4, 2), + (make_array(3, 1, NULL, 3), NULL, 5, 2), + (NULL, 3, 2, 1), + (make_array(3, 1, 3), 3, NULL, 1) +; + + +# ([3, 1, NULL, 3], 3, 4, 2) => [4, 1, NULL, 4] NULL not matched +# ([3, 1, NULL, 3], NULL, 5, 2) => [3, 1, NULL, 3] NULL is replaced with 5 +# ([NULL], 3, 2, 1) => NULL +# ([3, 1, 3], 3, NULL, 1) => [NULL, 1 3] + +query ?III? +select column1, column2, column3, column4, array_replace_n(column1, column2, column3, column4) from t; +---- +[3, 1, , 3] 3 4 2 [4, 1, , 4] +[3, 1, , 3] NULL 5 2 [3, 1, 5, 3] +NULL 3 2 1 NULL +[3, 1, 3] 3 NULL 1 [, 1, 3] + + + +statement ok +drop table t; + + + ## array_to_string (aliases: `list_to_string`, `array_join`, `list_join`) # array_to_string scalar function #1 @@ -1741,51 +3064,244 @@ select array_to_string(make_array(), ',') ---- (empty) -# list_to_string scalar function #4 (function alias `array_to_string`) -query TTT -select list_to_string(['h', 'e', 'l', 'l', 'o'], ','), list_to_string([1, 2, 3, 4, 5], '-'), list_to_string([1.0, 2.0, 3.0], '|'); ----- -h,e,l,l,o 1-2-3-4-5 1|2|3 -# array_join scalar function #5 (function alias `array_to_string`) -query TTT -select array_join(['h', 'e', 'l', 'l', 'o'], ','), array_join([1, 2, 3, 4, 5], '-'), array_join([1.0, 2.0, 3.0], '|'); +## array_union (aliases: `list_union`) + +# array_union scalar function #1 +query ? +select array_union([1, 2, 3, 4], [5, 6, 3, 4]); ---- -h,e,l,l,o 1-2-3-4-5 1|2|3 +[1, 2, 3, 4, 5, 6] -# list_join scalar function #6 (function alias `list_join`) -query TTT -select list_join(['h', 'e', 'l', 'l', 'o'], ','), list_join([1, 2, 3, 4, 5], '-'), list_join([1.0, 2.0, 3.0], '|'); +query ? +select array_union(arrow_cast([1, 2, 3, 4], 'LargeList(Int64)'), arrow_cast([5, 6, 3, 4], 'LargeList(Int64)')); ---- -h,e,l,l,o 1-2-3-4-5 1|2|3 +[1, 2, 3, 4, 5, 6] -# array_to_string scalar function with nulls #1 -query TTT -select array_to_string(make_array('h', NULL, 'l', NULL, 'o'), ','), array_to_string(make_array(1, NULL, 3, NULL, 5), '-'), array_to_string(make_array(NULL, 2.0, 3.0), '|'); +# array_union scalar function #2 +query ? +select array_union([1, 2, 3, 4], [5, 6, 7, 8]); ---- -h,l,o 1-3-5 2|3 +[1, 2, 3, 4, 5, 6, 7, 8] -# array_to_string scalar function with nulls #2 -query TTT -select array_to_string(make_array('h', NULL, NULL, NULL, 'o'), ',', '-'), array_to_string(make_array(NULL, 2, NULL, 4, 5), '-', 'nil'), array_to_string(make_array(1.0, NULL, 3.0), '|', '0'); +query ? +select array_union(arrow_cast([1, 2, 3, 4], 'LargeList(Int64)'), arrow_cast([5, 6, 7, 8], 'LargeList(Int64)')); ---- -h,-,-,-,o nil-2-nil-4-5 1|0|3 +[1, 2, 3, 4, 5, 6, 7, 8] -# array_to_string with columns #1 +# array_union scalar function #3 +query ? +select array_union([1,2,3], []); +---- +[1, 2, 3] -# For reference -# select column1, column4 from arrays_values; -# ---- -# [, 2, 3, 4, 5, 6, 7, 8, 9, 10] , -# [11, 12, 13, 14, 15, 16, 17, 18, , 20] . -# [21, 22, 23, , 25, 26, 27, 28, 29, 30] - -# [31, 32, 33, 34, 35, , 37, 38, 39, 40] ok -# NULL @ -# [41, 42, 43, 44, 45, 46, 47, 48, 49, 50] $ -# [51, 52, , 54, 55, 56, 57, 58, 59, 60] ^ -# [61, 62, 63, 64, 65, 66, 67, 68, 69, 70] NULL +query ? +select array_union(arrow_cast([1,2,3], 'LargeList(Int64)'), arrow_cast([], 'LargeList(Null)')); +---- +[1, 2, 3] -query T +# array_union scalar function #4 +query ? +select array_union([1, 2, 3, 4], [5, 4]); +---- +[1, 2, 3, 4, 5] + +query ? +select array_union(arrow_cast([1, 2, 3, 4], 'LargeList(Int64)'), arrow_cast([5, 4], 'LargeList(Int64)')); +---- +[1, 2, 3, 4, 5] + +# array_union scalar function #5 +statement ok +CREATE TABLE arrays_with_repeating_elements_for_union +AS VALUES + ([1], [2]), + ([2, 3], [3]), + ([3], [3, 4]) +; + +query ? +select array_union(column1, column2) from arrays_with_repeating_elements_for_union; +---- +[1, 2] +[2, 3] +[3, 4] + +query ? +select array_union(arrow_cast(column1, 'LargeList(Int64)'), arrow_cast(column2, 'LargeList(Int64)')) from arrays_with_repeating_elements_for_union; +---- +[1, 2] +[2, 3] +[3, 4] + +statement ok +drop table arrays_with_repeating_elements_for_union; + +# array_union scalar function #6 +query ? +select array_union([], []); +---- +[] + +query ? +select array_union(arrow_cast([], 'LargeList(Null)'), arrow_cast([], 'LargeList(Null)')); +---- +[] + +# array_union scalar function #7 +query ? +select array_union([[null]], []); +---- +[[]] + +query ? +select array_union(arrow_cast([[null]], 'LargeList(List(Null))'), arrow_cast([], 'LargeList(Null)')); +---- +[[]] + +# array_union scalar function #8 +query ? +select array_union([null], [null]); +---- +[] + +query ? +select array_union(arrow_cast([[null]], 'LargeList(List(Null))'), arrow_cast([[null]], 'LargeList(List(Null))')); +---- +[[]] + +# array_union scalar function #9 +query ? +select array_union(null, []); +---- +[] + +query ? +select array_union(null, arrow_cast([], 'LargeList(Null)')); +---- +[] + +# array_union scalar function #10 +query ? +select array_union(null, null); +---- +NULL + +# array_union scalar function #11 +query ? +select array_union([1, 1, 2, 2, 3, 3], null); +---- +[1, 2, 3] + +query ? +select array_union(arrow_cast([1, 1, 2, 2, 3, 3], 'LargeList(Int64)'), null); +---- +[1, 2, 3] + +# array_union scalar function #12 +query ? +select array_union(null, [1, 1, 2, 2, 3, 3]); +---- +[1, 2, 3] + +query ? +select array_union(null, arrow_cast([1, 1, 2, 2, 3, 3], 'LargeList(Int64)')); +---- +[1, 2, 3] + +# array_union scalar function #13 +query ? +select array_union([1.2, 3.0], [1.2, 3.0, 5.7]); +---- +[1.2, 3.0, 5.7] + +query ? +select array_union(arrow_cast([1.2, 3.0], 'LargeList(Float64)'), arrow_cast([1.2, 3.0, 5.7], 'LargeList(Float64)')); +---- +[1.2, 3.0, 5.7] + +# array_union scalar function #14 +query ? +select array_union(['hello'], ['hello','datafusion']); +---- +[hello, datafusion] + +query ? +select array_union(arrow_cast(['hello'], 'LargeList(Utf8)'), arrow_cast(['hello','datafusion'], 'LargeList(Utf8)')); +---- +[hello, datafusion] + + +# list_to_string scalar function #4 (function alias `array_to_string`) +query TTT +select list_to_string(['h', 'e', 'l', 'l', 'o'], ','), list_to_string([1, 2, 3, 4, 5], '-'), list_to_string([1.0, 2.0, 3.0], '|'); +---- +h,e,l,l,o 1-2-3-4-5 1|2|3 + +query TTT +select list_to_string(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), ','), list_to_string(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), '-'), list_to_string(arrow_cast([1.0, 2.0, 3.0], 'LargeList(Float64)'), '|'); +---- +h,e,l,l,o 1-2-3-4-5 1|2|3 + +# array_join scalar function #5 (function alias `array_to_string`) +query TTT +select array_join(['h', 'e', 'l', 'l', 'o'], ','), array_join([1, 2, 3, 4, 5], '-'), array_join([1.0, 2.0, 3.0], '|'); +---- +h,e,l,l,o 1-2-3-4-5 1|2|3 + +query TTT +select array_join(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), ','), array_join(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), '-'), array_join(arrow_cast([1.0, 2.0, 3.0], 'LargeList(Float64)'), '|'); +---- +h,e,l,l,o 1-2-3-4-5 1|2|3 + +# list_join scalar function #6 (function alias `list_join`) +query TTT +select list_join(['h', 'e', 'l', 'l', 'o'], ','), list_join([1, 2, 3, 4, 5], '-'), list_join([1.0, 2.0, 3.0], '|'); +---- +h,e,l,l,o 1-2-3-4-5 1|2|3 + +query TTT +select list_join(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), ','), list_join(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), '-'), list_join(arrow_cast([1.0, 2.0, 3.0], 'LargeList(Float64)'), '|'); +---- +h,e,l,l,o 1-2-3-4-5 1|2|3 + +# array_to_string scalar function with nulls #1 +query TTT +select array_to_string(make_array('h', NULL, 'l', NULL, 'o'), ','), array_to_string(make_array(1, NULL, 3, NULL, 5), '-'), array_to_string(make_array(NULL, 2.0, 3.0), '|'); +---- +h,l,o 1-3-5 2|3 + +query TTT +select array_to_string(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), ','), array_to_string(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), '-'), array_to_string(arrow_cast([1.0, 2.0, 3.0], 'LargeList(Float64)'), '|'); +---- +h,e,l,l,o 1-2-3-4-5 1|2|3 + +# array_to_string scalar function with nulls #2 +query TTT +select array_to_string(make_array('h', NULL, NULL, NULL, 'o'), ',', '-'), array_to_string(make_array(NULL, 2, NULL, 4, 5), '-', 'nil'), array_to_string(make_array(1.0, NULL, 3.0), '|', '0'); +---- +h,-,-,-,o nil-2-nil-4-5 1|0|3 + +query TTT +select array_to_string(arrow_cast(make_array('h', NULL, NULL, NULL, 'o'), 'LargeList(Utf8)'), ',', '-'), array_to_string(arrow_cast(make_array(NULL, 2, NULL, 4, 5), 'LargeList(Int64)'), '-', 'nil'), array_to_string(arrow_cast(make_array(1.0, NULL, 3.0), 'LargeList(Float64)'), '|', '0'); +---- +h,-,-,-,o nil-2-nil-4-5 1|0|3 + +# array_to_string with columns #1 + +# For reference +# select column1, column4 from arrays_values; +# ---- +# [, 2, 3, 4, 5, 6, 7, 8, 9, 10] , +# [11, 12, 13, 14, 15, 16, 17, 18, , 20] . +# [21, 22, 23, , 25, 26, 27, 28, 29, 30] - +# [31, 32, 33, 34, 35, , 37, 38, 39, 40] ok +# NULL @ +# [41, 42, 43, 44, 45, 46, 47, 48, 49, 50] $ +# [51, 52, , 54, 55, 56, 57, 58, 59, 60] ^ +# [61, 62, 63, 64, 65, 66, 67, 68, 69, 70] NULL + +query T select array_to_string(column1, column4) from arrays_values; ---- 2,3,4,5,6,7,8,9,10 @@ -1797,6 +3313,18 @@ NULL 51^52^54^55^56^57^58^59^60 NULL +query T +select array_to_string(column1, column4) from large_arrays_values; +---- +2,3,4,5,6,7,8,9,10 +11.12.13.14.15.16.17.18.20 +21-22-23-25-26-27-28-29-30 +31ok32ok33ok34ok35ok37ok38ok39ok40 +NULL +41$42$43$44$45$46$47$48$49$50 +51^52^54^55^56^57^58^59^60 +NULL + query TT select array_to_string(column1, '_'), array_to_string(make_array(1,2,3), '/') from arrays_values; ---- @@ -1809,6 +3337,18 @@ NULL 1/2/3 51_52_54_55_56_57_58_59_60 1/2/3 61_62_63_64_65_66_67_68_69_70 1/2/3 +query TT +select array_to_string(column1, '_'), array_to_string(make_array(1,2,3), '/') from large_arrays_values; +---- +2_3_4_5_6_7_8_9_10 1/2/3 +11_12_13_14_15_16_17_18_20 1/2/3 +21_22_23_25_26_27_28_29_30 1/2/3 +31_32_33_34_35_37_38_39_40 1/2/3 +NULL 1/2/3 +41_42_43_44_45_46_47_48_49_50 1/2/3 +51_52_54_55_56_57_58_59_60 1/2/3 +61_62_63_64_65_66_67_68_69_70 1/2/3 + query TT select array_to_string(column1, '_', '*'), array_to_string(make_array(make_array(1,2,3)), '.') from arrays_values; ---- @@ -1821,6 +3361,18 @@ NULL 1.2.3 51_52_*_54_55_56_57_58_59_60 1.2.3 61_62_63_64_65_66_67_68_69_70 1.2.3 +query TT +select array_to_string(column1, '_', '*'), array_to_string(make_array(make_array(1,2,3)), '.') from large_arrays_values; +---- +*_2_3_4_5_6_7_8_9_10 1.2.3 +11_12_13_14_15_16_17_18_*_20 1.2.3 +21_22_23_*_25_26_27_28_29_30 1.2.3 +31_32_33_34_35_*_37_38_39_40 1.2.3 +NULL 1.2.3 +41_42_43_44_45_46_47_48_49_50 1.2.3 +51_52_*_54_55_56_57_58_59_60 1.2.3 +61_62_63_64_65_66_67_68_69_70 1.2.3 + ## cardinality # cardinality scalar function @@ -1829,18 +3381,33 @@ select cardinality(make_array(1, 2, 3, 4, 5)), cardinality([1, 3, 5]), cardinali ---- 5 3 5 +query III +select cardinality(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)')), cardinality(arrow_cast([1, 3, 5], 'LargeList(Int64)')), cardinality(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)')); +---- +5 3 5 + # cardinality scalar function #2 query II select cardinality(make_array([1, 2], [3, 4], [5, 6])), cardinality(array_repeat(array_repeat(array_repeat(3, 3), 2), 3)); ---- 6 18 +query I +select cardinality(arrow_cast(make_array([1, 2], [3, 4], [5, 6]), 'LargeList(List(Int64))')); +---- +6 + # cardinality scalar function #3 query II select cardinality(make_array()), cardinality(make_array(make_array())) ---- NULL 0 +query II +select cardinality(arrow_cast(make_array(), 'LargeList(Null)')), cardinality(arrow_cast(make_array(make_array()), 'LargeList(List(Null))')) +---- +NULL 0 + # cardinality with columns query III select cardinality(column1), cardinality(column2), cardinality(column3) from arrays; @@ -1853,6 +3420,17 @@ NULL 3 4 4 NULL 1 4 3 NULL +query III +select cardinality(column1), cardinality(column2), cardinality(column3) from large_arrays; +---- +4 3 5 +4 3 5 +4 3 5 +4 3 3 +NULL 3 4 +4 NULL 1 +4 3 NULL + ## array_remove (aliases: `list_remove`) # array_remove scalar function #1 @@ -1861,6 +3439,20 @@ select array_remove(make_array(1, 2, 2, 1, 1), 2), array_remove(make_array(1.0, ---- [1, 2, 1, 1] [2.0, 2.0, 1.0, 1.0] [h, e, l, o] +query ??? +select + array_remove(make_array(1, null, 2, 3), 2), + array_remove(make_array(1.1, null, 2.2, 3.3), 1.1), + array_remove(make_array('a', null, 'bc'), 'a'); +---- +[1, , 3] [, 2.2, 3.3] [, bc] + +# TODO: https://github.com/apache/arrow-datafusion/issues/7142 +# query +# select +# array_remove(make_array(1, null, 2), null), +# array_remove(make_array(1, null, 2, null), null); + # array_remove scalar function #2 (element is list) query ?? select array_remove(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), [4, 5, 6]), array_remove(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), [2, 3, 4]); @@ -2031,33 +3623,64 @@ select array_length(make_array(1, 2, 3, 4, 5)), array_length(make_array(1, 2, 3) ---- 5 3 3 +query III +select array_length(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)')), array_length(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)')), array_length(arrow_cast(make_array([1, 2], [3, 4], [5, 6]), 'LargeList(List(Int64))')); +---- +5 3 3 + # array_length scalar function #2 query III select array_length(make_array(1, 2, 3, 4, 5), 1), array_length(make_array(1, 2, 3), 1), array_length(make_array([1, 2], [3, 4], [5, 6]), 1); ---- 5 3 3 +query III +select array_length(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 1), array_length(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 1), array_length(arrow_cast(make_array([1, 2], [3, 4], [5, 6]), 'LargeList(List(Int64))'), 1); +---- +5 3 3 + # array_length scalar function #3 query III select array_length(make_array(1, 2, 3, 4, 5), 2), array_length(make_array(1, 2, 3), 2), array_length(make_array([1, 2], [3, 4], [5, 6]), 2); ---- NULL NULL 2 +query III +select array_length(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_length(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 2), array_length(arrow_cast(make_array([1, 2], [3, 4], [5, 6]), 'LargeList(List(Int64))'), 2); +---- +NULL NULL 2 + # array_length scalar function #4 query II select array_length(array_repeat(array_repeat(array_repeat(3, 5), 2), 3), 1), array_length(array_repeat(array_repeat(array_repeat(3, 5), 2), 3), 2); ---- 3 2 +query II +select array_length(arrow_cast(array_repeat(array_repeat(array_repeat(3, 5), 2), 3), 'LargeList(List(List(Int64)))'), 1), array_length(arrow_cast(array_repeat(array_repeat(array_repeat(3, 5), 2), 3), 'LargeList(List(List(Int64)))'), 2); +---- +3 2 + # array_length scalar function #5 query III select array_length(make_array()), array_length(make_array(), 1), array_length(make_array(), 2) ---- 0 0 NULL -# list_length scalar function #6 (function alias `array_length`) +# array_length scalar function #6 nested array +query III +select array_length([[1, 2, 3, 4], [5, 6, 7, 8]]), array_length([[1, 2, 3, 4], [5, 6, 7, 8]], 1), array_length([[1, 2, 3, 4], [5, 6, 7, 8]], 2); +---- +2 2 4 + +# list_length scalar function #7 (function alias `array_length`) +query IIII +select list_length(make_array(1, 2, 3, 4, 5)), list_length(make_array(1, 2, 3)), list_length(make_array([1, 2], [3, 4], [5, 6])), array_length([[1, 2, 3, 4], [5, 6, 7, 8]], 3); +---- +5 3 3 NULL + query III -select list_length(make_array(1, 2, 3, 4, 5)), list_length(make_array(1, 2, 3)), list_length(make_array([1, 2], [3, 4], [5, 6])); +select list_length(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)')), list_length(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)')), list_length(arrow_cast(make_array([1, 2], [3, 4], [5, 6]), 'LargeList(List(Int64))')); ---- 5 3 3 @@ -2074,6 +3697,18 @@ NULL NULL NULL +query I +select array_length(arrow_cast(column1, 'LargeList(Int64)'), column3) from arrays_values; +---- +10 +NULL +NULL +NULL +NULL +NULL +NULL +NULL + # array_length with columns and scalars query II select array_length(array[array[1, 2], array[3, 4]], column3), array_length(column1, 1) from arrays_values; @@ -2087,11 +3722,22 @@ NULL 10 NULL 10 NULL 10 +query II +select array_length(arrow_cast(array[array[1, 2], array[3, 4]], 'LargeList(List(Int64))'), column3), array_length(arrow_cast(column1, 'LargeList(Int64)'), 1) from arrays_values; +---- +2 10 +2 10 +NULL 10 +NULL 10 +NULL NULL +NULL 10 +NULL 10 +NULL 10 + ## array_dims (aliases: `list_dims`) # array dims error -# TODO this is a separate bug -query error Internal error: could not cast value to arrow_array::array::list_array::GenericListArray\. +query error Execution error: array_dims does not support type 'Int64' select array_dims(1); # array_dims scalar function @@ -2100,9 +3746,14 @@ select array_dims(make_array(1, 2, 3)), array_dims(make_array([1, 2], [3, 4])), ---- [3] [2, 2] [1, 1, 1, 2, 1] -# array_dims scalar function #2 -query ?? -select array_dims(array_repeat(array_repeat(array_repeat(2, 3), 2), 1)), array_dims(array_repeat(array_repeat(array_repeat(3, 4), 5), 2)); +query ??? +select array_dims(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)')), array_dims(arrow_cast(make_array([1, 2], [3, 4]), 'LargeList(List(Int64))')), array_dims(arrow_cast(make_array([[[[1], [2]]]]), 'LargeList(List(List(List(List(Int64)))))')); +---- +[3] [2, 2] [1, 1, 1, 2, 1] + +# array_dims scalar function #2 +query ?? +select array_dims(array_repeat(array_repeat(array_repeat(2, 3), 2), 1)), array_dims(array_repeat(array_repeat(array_repeat(3, 4), 5), 2)); ---- [1, 2, 3] [2, 5, 4] @@ -2112,12 +3763,22 @@ select array_dims(make_array()), array_dims(make_array(make_array())) ---- NULL [1, 0] +query ?? +select array_dims(arrow_cast(make_array(), 'LargeList(Null)')), array_dims(arrow_cast(make_array(make_array()), 'LargeList(List(Null))')) +---- +NULL [1, 0] + # list_dims scalar function #4 (function alias `array_dims`) query ??? select list_dims(make_array(1, 2, 3)), list_dims(make_array([1, 2], [3, 4])), list_dims(make_array([[[[1], [2]]]])); ---- [3] [2, 2] [1, 1, 1, 2, 1] +query ??? +select list_dims(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)')), list_dims(arrow_cast(make_array([1, 2], [3, 4]), 'LargeList(List(Int64))')), list_dims(arrow_cast(make_array([[[[1], [2]]]]), 'LargeList(List(List(List(List(Int64)))))')); +---- +[3] [2, 2] [1, 1, 1, 2, 1] + # array_dims with columns query ??? select array_dims(column1), array_dims(column2), array_dims(column3) from arrays; @@ -2130,13 +3791,84 @@ NULL [3] [4] [2, 2] NULL [1] [2, 2] [3] NULL +query ??? +select array_dims(column1), array_dims(column2), array_dims(column3) from large_arrays; +---- +[2, 2] [3] [5] +[2, 2] [3] [5] +[2, 2] [3] [5] +[2, 2] [3] [3] +NULL [3] [4] +[2, 2] NULL [1] +[2, 2] [3] NULL + + ## array_ndims (aliases: `list_ndims`) # array_ndims scalar function #1 + query III -select array_ndims(make_array(1, 2, 3)), array_ndims(make_array([1, 2], [3, 4])), array_ndims(make_array([[[[1], [2]]]])); +select + array_ndims(1), + array_ndims(null), + array_ndims([2, 3]); ---- -1 2 5 +0 0 1 + +statement ok +CREATE TABLE array_ndims_table +AS VALUES + (1, [1, 2, 3], [[7]], [[[[[10]]]]]), + (2, [4, 5], [[8]], [[[[[10]]]]]), + (null, [6], [[9]], [[[[[10]]]]]), + (3, [6], [[9]], [[[[[10]]]]]) +; + +statement ok +CREATE TABLE large_array_ndims_table +AS SELECT + column1, + arrow_cast(column2, 'LargeList(Int64)') as column2, + arrow_cast(column3, 'LargeList(List(Int64))') as column3, + arrow_cast(column4, 'LargeList(List(List(List(List(Int64)))))') as column4 +FROM array_ndims_table; + +query IIII +select + array_ndims(column1), + array_ndims(column2), + array_ndims(column3), + array_ndims(column4) +from array_ndims_table; +---- +0 1 2 5 +0 1 2 5 +0 1 2 5 +0 1 2 5 + +query IIII +select + array_ndims(column1), + array_ndims(column2), + array_ndims(column3), + array_ndims(column4) +from large_array_ndims_table; +---- +0 1 2 5 +0 1 2 5 +0 1 2 5 +0 1 2 5 + +statement ok +drop table array_ndims_table; + +statement ok +drop table large_array_ndims_table + +query I +select array_ndims(arrow_cast([null], 'List(List(List(Int64)))')); +---- +3 # array_ndims scalar function #2 query II @@ -2148,7 +3880,12 @@ select array_ndims(array_repeat(array_repeat(array_repeat(1, 3), 2), 1)), array_ query II select array_ndims(make_array()), array_ndims(make_array(make_array())) ---- -NULL 2 +1 2 + +query II +select array_ndims(arrow_cast(make_array(), 'LargeList(Null)')), array_ndims(arrow_cast(make_array(make_array()), 'LargeList(List(Null))')) +---- +1 2 # list_ndims scalar function #4 (function alias `array_ndims`) query III @@ -2156,10 +3893,20 @@ select list_ndims(make_array(1, 2, 3)), list_ndims(make_array([1, 2], [3, 4])), ---- 1 2 5 +query III +select list_ndims(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)')), list_ndims(arrow_cast(make_array([1, 2], [3, 4]), 'LargeList(List(Int64))')), list_ndims(arrow_cast(make_array([[[[1], [2]]]]), 'LargeList(List(List(List(List(Int64)))))')); +---- +1 2 5 + query II -select array_ndims(make_array()), array_ndims(make_array(make_array())) +select list_ndims(make_array()), list_ndims(make_array(make_array())) +---- +1 2 + +query II +select list_ndims(arrow_cast(make_array(), 'LargeList(Null)')), list_ndims(arrow_cast(make_array(make_array()), 'LargeList(List(Null))')) ---- -NULL 2 +1 2 # array_ndims with columns query III @@ -2173,6 +3920,17 @@ NULL 1 1 2 NULL 1 2 1 NULL +query III +select array_ndims(column1), array_ndims(column2), array_ndims(column3) from large_arrays; +---- +2 1 1 +2 1 1 +2 1 1 +2 1 1 +NULL 1 1 +2 NULL 1 +2 1 NULL + ## array_has/array_has_all/array_has_any query BBBBBBBBBBBB @@ -2190,101 +3948,486 @@ select array_has(make_array(1,2), 1), list_contains(make_array(1,2,3), 0) ; ---- -true true true true true false true false true false true false +true true true true true false true false true false true false + +query BBBBBBBBBBBB +select array_has(arrow_cast(make_array(1,2), 'LargeList(Int64)'), 1), + array_has(arrow_cast(make_array(1,2,NULL), 'LargeList(Int64)'), 1), + array_has(arrow_cast(make_array([2,3], [3,4]), 'LargeList(List(Int64))'), make_array(2,3)), + array_has(arrow_cast(make_array([[1], [2,3]], [[4,5], [6]]), 'LargeList(List(List(Int64)))'), make_array([1], [2,3])), + array_has(arrow_cast(make_array([[1], [2,3]], [[4,5], [6]]), 'LargeList(List(List(Int64)))'), make_array([4,5], [6])), + array_has(arrow_cast(make_array([[1], [2,3]], [[4,5], [6]]), 'LargeList(List(List(Int64)))'), make_array([1])), + array_has(arrow_cast(make_array([[[1]]]), 'LargeList(List(List(List(Int64))))'), make_array([[1]])), + array_has(arrow_cast(make_array([[[1]]], [[[1], [2]]]), 'LargeList(List(List(List(Int64))))'), make_array([[2]])), + array_has(arrow_cast(make_array([[[1]]], [[[1], [2]]]), 'LargeList(List(List(List(Int64))))'), make_array([[1], [2]])), + list_has(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), 4), + array_contains(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), 3), + list_contains(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), 0) +; +---- +true true true true true false true false true false true false + +query BBB +select array_has(column1, column2), + array_has_all(column3, column4), + array_has_any(column5, column6) +from array_has_table_1D; +---- +true true true +false false false + +query BBB +select array_has(arrow_cast(column1, 'LargeList(Int64)'), column2), + array_has_all(arrow_cast(column3, 'LargeList(Int64)'), arrow_cast(column4, 'LargeList(Int64)')), + array_has_any(arrow_cast(column5, 'LargeList(Int64)'), arrow_cast(column6, 'LargeList(Int64)')) +from array_has_table_1D; +---- +true true true +false false false + +query BBB +select array_has(column1, column2), + array_has_all(column3, column4), + array_has_any(column5, column6) +from array_has_table_1D_Float; +---- +true true false +false false true + +query BBB +select array_has(arrow_cast(column1, 'LargeList(Float64)'), column2), + array_has_all(arrow_cast(column3, 'LargeList(Float64)'), arrow_cast(column4, 'LargeList(Float64)')), + array_has_any(arrow_cast(column5, 'LargeList(Float64)'), arrow_cast(column6, 'LargeList(Float64)')) +from array_has_table_1D_Float; +---- +true true false +false false true + +query BBB +select array_has(column1, column2), + array_has_all(column3, column4), + array_has_any(column5, column6) +from array_has_table_1D_Boolean; +---- +false true true +true true true + +query BBB +select array_has(arrow_cast(column1, 'LargeList(Boolean)'), column2), + array_has_all(arrow_cast(column3, 'LargeList(Boolean)'), arrow_cast(column4, 'LargeList(Boolean)')), + array_has_any(arrow_cast(column5, 'LargeList(Boolean)'), arrow_cast(column6, 'LargeList(Boolean)')) +from array_has_table_1D_Boolean; +---- +false true true +true true true + +query BBB +select array_has(column1, column2), + array_has_all(column3, column4), + array_has_any(column5, column6) +from array_has_table_1D_UTF8; +---- +true true false +false false true + +query BBB +select array_has(arrow_cast(column1, 'LargeList(Utf8)'), column2), + array_has_all(arrow_cast(column3, 'LargeList(Utf8)'), arrow_cast(column4, 'LargeList(Utf8)')), + array_has_any(arrow_cast(column5, 'LargeList(Utf8)'), arrow_cast(column6, 'LargeList(Utf8)')) +from array_has_table_1D_UTF8; +---- +true true false +false false true + +query BB +select array_has(column1, column2), + array_has_all(column3, column4) +from array_has_table_2D; +---- +false true +true false + +query BB +select array_has(arrow_cast(column1, 'LargeList(List(Int64))'), column2), + array_has_all(arrow_cast(column3, 'LargeList(List(Int64))'), arrow_cast(column4, 'LargeList(List(Int64))')) +from array_has_table_2D; +---- +false true +true false + +query B +select array_has_all(column1, column2) +from array_has_table_2D_float; +---- +true +false + +query B +select array_has_all(arrow_cast(column1, 'LargeList(List(Float64))'), arrow_cast(column2, 'LargeList(List(Float64))')) +from array_has_table_2D_float; +---- +true +false + +query B +select array_has(column1, column2) from array_has_table_3D; +---- +false +true +false +false +true +false +true + +query B +select array_has(arrow_cast(column1, 'LargeList(List(List(Int64)))'), column2) from array_has_table_3D; +---- +false +true +false +false +true +false +true + +query BBBB +select array_has(column1, make_array(5, 6)), + array_has(column1, make_array(7, NULL)), + array_has(column2, 5.5), + array_has(column3, 'o') +from arrays; +---- +false false false true +true false true false +true false false true +false true false false +false false false false +false false false false + +query BBBB +select array_has(arrow_cast(column1, 'LargeList(List(Int64))'), make_array(5, 6)), + array_has(arrow_cast(column1, 'LargeList(List(Int64))'), make_array(7, NULL)), + array_has(arrow_cast(column2, 'LargeList(Float64)'), 5.5), + array_has(arrow_cast(column3, 'LargeList(Utf8)'), 'o') +from arrays; +---- +false false false true +true false true false +true false false true +false true false false +false false false false +false false false false + +query BBBBBBBBBBBBB +select array_has_all(make_array(1,2,3), make_array(1,3)), + array_has_all(make_array(1,2,3), make_array(1,4)), + array_has_all(make_array([1,2], [3,4]), make_array([1,2])), + array_has_all(make_array([1,2], [3,4]), make_array([1,3])), + array_has_all(make_array([1,2], [3,4]), make_array([1,2], [3,4], [5,6])), + array_has_all(make_array([[1,2,3]]), make_array([[1]])), + array_has_all(make_array([[1,2,3]]), make_array([[1,2,3]])), + array_has_any(make_array(1,2,3), make_array(1,10,100)), + array_has_any(make_array(1,2,3), make_array(10,100)), + array_has_any(make_array([1,2], [3,4]), make_array([1,10], [10,4])), + array_has_any(make_array([1,2], [3,4]), make_array([10,20], [3,4])), + array_has_any(make_array([[1,2,3]]), make_array([[1,2,3], [4,5,6]])), + array_has_any(make_array([[1,2,3]]), make_array([[1,2,3]], [[4,5,6]])) +; +---- +true false true false false false true true false false true false true + +query BBBBBBBBBBBBB +select array_has_all(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), arrow_cast(make_array(1,3), 'LargeList(Int64)')), + array_has_all(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(1,4), 'LargeList(Int64)')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,2]), 'LargeList(List(Int64))')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,3]), 'LargeList(List(Int64))')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,2], [3,4], [5,6]), 'LargeList(List(Int64))')), + array_has_all(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1]]), 'LargeList(List(List(Int64)))')), + array_has_all(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))')), + array_has_any(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(1,10,100), 'LargeList(Int64)')), + array_has_any(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(10,100),'LargeList(Int64)')), + array_has_any(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,10], [10,4]), 'LargeList(List(Int64))')), + array_has_any(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([10,20], [3,4]), 'LargeList(List(Int64))')), + array_has_any(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3], [4,5,6]]), 'LargeList(List(List(Int64)))')), + array_has_any(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3]], [[4,5,6]]), 'LargeList(List(List(Int64)))')) +; +---- +true false true false false false true true false false true false true + +query BBBBBBBBBBBBB +select array_has_all(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), arrow_cast(make_array(1,3), 'LargeList(Int64)')), + array_has_all(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(1,4), 'LargeList(Int64)')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,2]), 'LargeList(List(Int64))')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,3]), 'LargeList(List(Int64))')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,2], [3,4], [5,6]), 'LargeList(List(Int64))')), + array_has_all(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1]]), 'LargeList(List(List(Int64)))')), + array_has_all(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))')), + array_has_any(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(1,10,100), 'LargeList(Int64)')), + array_has_any(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(10,100),'LargeList(Int64)')), + array_has_any(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,10], [10,4]), 'LargeList(List(Int64))')), + array_has_any(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([10,20], [3,4]), 'LargeList(List(Int64))')), + array_has_any(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3], [4,5,6]]), 'LargeList(List(List(Int64)))')), + array_has_any(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3]], [[4,5,6]]), 'LargeList(List(List(Int64)))')) +; +---- +true false true false false false true true false false true false true + +## array_distinct + +query ? +select array_distinct(null); +---- +NULL + +query ? +select array_distinct([]); +---- +[] + +query ? +select array_distinct([[], []]); +---- +[[]] + +query ? +select array_distinct(column1) +from array_distinct_table_1D; +---- +[1, 2, 3] +[1, 2, 3, 4, 5] +[3, 5] + +query ? +select array_distinct(column1) +from array_distinct_table_1D_UTF8; +---- +[a, bc, def] +[a, bc, def, defg] +[defg] + +query ? +select array_distinct(column1) +from array_distinct_table_2D; +---- +[[1, 2], [3, 4], [5, 6]] +[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]] +[, [5, 6]] + +query ? +select array_distinct(column1) +from array_distinct_table_1D_large; +---- +[1, 2, 3] +[1, 2, 3, 4, 5] +[3, 5] + +query ??? +select array_intersect(column1, column2), + array_intersect(column3, column4), + array_intersect(column5, column6) +from array_intersect_table_1D; +---- +[1] [1, 3] [1, 3] +[11] [11, 33] [11, 33] + +query ??? +select array_intersect(column1, column2), + array_intersect(column3, column4), + array_intersect(column5, column6) +from large_array_intersect_table_1D; +---- +[1] [1, 3] [1, 3] +[11] [11, 33] [11, 33] + +query ??? +select array_intersect(column1, column2), + array_intersect(column3, column4), + array_intersect(column5, column6) +from array_intersect_table_1D_Float; +---- +[1.0] [1.0, 3.0] [] +[] [2.0] [1.11] + +query ??? +select array_intersect(column1, column2), + array_intersect(column3, column4), + array_intersect(column5, column6) +from array_intersect_table_1D_Boolean; +---- +[] [false, true] [false] +[false] [true] [true] + +query ??? +select array_intersect(column1, column2), + array_intersect(column3, column4), + array_intersect(column5, column6) +from large_array_intersect_table_1D_Boolean; +---- +[] [false, true] [false] +[false] [true] [true] + +query ??? +select array_intersect(column1, column2), + array_intersect(column3, column4), + array_intersect(column5, column6) +from array_intersect_table_1D_UTF8; +---- +[bc] [arrow, rust] [] +[] [arrow, datafusion, rust] [arrow, rust] + +query ??? +select array_intersect(column1, column2), + array_intersect(column3, column4), + array_intersect(column5, column6) +from large_array_intersect_table_1D_UTF8; +---- +[bc] [arrow, rust] [] +[] [arrow, datafusion, rust] [arrow, rust] + +query ?? +select array_intersect(column1, column2), + array_intersect(column3, column4) +from array_intersect_table_2D; +---- +[] [[4, 5], [6, 7]] +[[3, 4]] [[5, 6, 7], [8, 9, 10]] + +query ?? +select array_intersect(column1, column2), + array_intersect(column3, column4) +from large_array_intersect_table_2D; +---- +[] [[4, 5], [6, 7]] +[[3, 4]] [[5, 6, 7], [8, 9, 10]] + + +query ? +select array_intersect(column1, column2) +from array_intersect_table_2D_float; +---- +[[1.1, 2.2], [3.3]] +[[1.1, 2.2], [3.3]] + +query ? +select array_intersect(column1, column2) +from large_array_intersect_table_2D_float; +---- +[[1.1, 2.2], [3.3]] +[[1.1, 2.2], [3.3]] + +query ? +select array_intersect(column1, column2) +from array_intersect_table_3D; +---- +[] +[[[1, 2]]] + +query ? +select array_intersect(column1, column2) +from large_array_intersect_table_3D; +---- +[] +[[[1, 2]]] + +query ?????? +SELECT array_intersect(make_array(1,2,3), make_array(2,3,4)), + array_intersect(make_array(1,3,5), make_array(2,4,6)), + array_intersect(make_array('aa','bb','cc'), make_array('cc','aa','dd')), + array_intersect(make_array(true, false), make_array(true)), + array_intersect(make_array(1.1, 2.2, 3.3), make_array(2.2, 3.3, 4.4)), + array_intersect(make_array([1, 1], [2, 2], [3, 3]), make_array([2, 2], [3, 3], [4, 4])) +; +---- +[2, 3] [] [aa, cc] [true] [2.2, 3.3] [[2, 2], [3, 3]] + +query ?????? +SELECT array_intersect(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), arrow_cast(make_array(2,3,4), 'LargeList(Int64)')), + array_intersect(arrow_cast(make_array(1,3,5), 'LargeList(Int64)'), arrow_cast(make_array(2,4,6), 'LargeList(Int64)')), + array_intersect(arrow_cast(make_array('aa','bb','cc'), 'LargeList(Utf8)'), arrow_cast(make_array('cc','aa','dd'), 'LargeList(Utf8)')), + array_intersect(arrow_cast(make_array(true, false), 'LargeList(Boolean)'), arrow_cast(make_array(true), 'LargeList(Boolean)')), + array_intersect(arrow_cast(make_array(1.1, 2.2, 3.3), 'LargeList(Float64)'), arrow_cast(make_array(2.2, 3.3, 4.4), 'LargeList(Float64)')), + array_intersect(arrow_cast(make_array([1, 1], [2, 2], [3, 3]), 'LargeList(List(Int64))'), arrow_cast(make_array([2, 2], [3, 3], [4, 4]), 'LargeList(List(Int64))')) +; +---- +[2, 3] [] [aa, cc] [true] [2.2, 3.3] [[2, 2], [3, 3]] + +query ? +select array_intersect([], []); +---- +[] + +query ? +select array_intersect(arrow_cast([], 'LargeList(Null)'), arrow_cast([], 'LargeList(Null)')); +---- +[] + +query ? +select array_intersect([1, 1, 2, 2, 3, 3], null); +---- +[] + +query ? +select array_intersect(arrow_cast([1, 1, 2, 2, 3, 3], 'LargeList(Int64)'), null); +---- +[] -query BBB -select array_has(column1, column2), - array_has_all(column3, column4), - array_has_any(column5, column6) -from array_has_table_1D; +query ? +select array_intersect(null, [1, 1, 2, 2, 3, 3]); ---- -true true true -false false false +NULL -query BBB -select array_has(column1, column2), - array_has_all(column3, column4), - array_has_any(column5, column6) -from array_has_table_1D_Float; +query ? +select array_intersect(null, arrow_cast([1, 1, 2, 2, 3, 3], 'LargeList(Int64)')); ---- -true true false -false false true +NULL -query BBB -select array_has(column1, column2), - array_has_all(column3, column4), - array_has_any(column5, column6) -from array_has_table_1D_Boolean; +query ? +select array_intersect([], null); ---- -false true true -true true true +[] -query BBB -select array_has(column1, column2), - array_has_all(column3, column4), - array_has_any(column5, column6) -from array_has_table_1D_UTF8; +query ? +select array_intersect(arrow_cast([], 'LargeList(Null)'), null); ---- -true true false -false false true +[] -query BB -select array_has(column1, column2), - array_has_all(column3, column4) -from array_has_table_2D; +query ? +select array_intersect(null, []); ---- -false true -true false +NULL -query B -select array_has_all(column1, column2) -from array_has_table_2D_float; +query ? +select array_intersect(null, arrow_cast([], 'LargeList(Null)')); ---- -true -false +NULL -query B -select array_has(column1, column2) from array_has_table_3D; +query ? +select array_intersect(null, null); ---- -false -true -false -false -true -false -true +NULL -query BBBB -select array_has(column1, make_array(5, 6)), - array_has(column1, make_array(7, NULL)), - array_has(column2, 5.5), - array_has(column3, 'o') -from arrays; +query ?????? +SELECT list_intersect(make_array(1,2,3), make_array(2,3,4)), + list_intersect(make_array(1,3,5), make_array(2,4,6)), + list_intersect(make_array('aa','bb','cc'), make_array('cc','aa','dd')), + list_intersect(make_array(true, false), make_array(true)), + list_intersect(make_array(1.1, 2.2, 3.3), make_array(2.2, 3.3, 4.4)), + list_intersect(make_array([1, 1], [2, 2], [3, 3]), make_array([2, 2], [3, 3], [4, 4])) +; ---- -false false false true -true false true false -true false false true -false true false false -false false false false -false false false false +[2, 3] [] [aa, cc] [true] [2.2, 3.3] [[2, 2], [3, 3]] -query BBBBBBBBBBBBB -select array_has_all(make_array(1,2,3), make_array(1,3)), - array_has_all(make_array(1,2,3), make_array(1,4)), - array_has_all(make_array([1,2], [3,4]), make_array([1,2])), - array_has_all(make_array([1,2], [3,4]), make_array([1,3])), - array_has_all(make_array([1,2], [3,4]), make_array([1,2], [3,4], [5,6])), - array_has_all(make_array([[1,2,3]]), make_array([[1]])), - array_has_all(make_array([[1,2,3]]), make_array([[1,2,3]])), - array_has_any(make_array(1,2,3), make_array(1,10,100)), - array_has_any(make_array(1,2,3), make_array(10,100)), - array_has_any(make_array([1,2], [3,4]), make_array([1,10], [10,4])), - array_has_any(make_array([1,2], [3,4]), make_array([10,20], [3,4])), - array_has_any(make_array([[1,2,3]]), make_array([[1,2,3], [4,5,6]])), - array_has_any(make_array([[1,2,3]]), make_array([[1,2,3]], [[4,5,6]])) +query ?????? +SELECT list_intersect(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), arrow_cast(make_array(2,3,4), 'LargeList(Int64)')), + list_intersect(arrow_cast(make_array(1,3,5), 'LargeList(Int64)'), arrow_cast(make_array(2,4,6), 'LargeList(Int64)')), + list_intersect(arrow_cast(make_array('aa','bb','cc'), 'LargeList(Utf8)'), arrow_cast(make_array('cc','aa','dd'), 'LargeList(Utf8)')), + list_intersect(arrow_cast(make_array(true, false), 'LargeList(Boolean)'), arrow_cast(make_array(true), 'LargeList(Boolean)')), + list_intersect(arrow_cast(make_array(1.1, 2.2, 3.3), 'LargeList(Float64)'), arrow_cast(make_array(2.2, 3.3, 4.4), 'LargeList(Float64)')), + list_intersect(arrow_cast(make_array([1, 1], [2, 2], [3, 3]), 'LargeList(List(Int64))'), arrow_cast(make_array([2, 2], [3, 3], [4, 4]), 'LargeList(List(Int64))')) ; ---- -true false true false false false true true false false true false true +[2, 3] [] [aa, cc] [true] [2.2, 3.3] [[2, 2], [3, 3]] query BBBB select list_has_all(make_array(1,2,3), make_array(4,5,6)), @@ -2295,6 +4438,161 @@ select list_has_all(make_array(1,2,3), make_array(4,5,6)), ---- false true false true +query ??? +select range(column2), + range(column1, column2), + range(column1, column2, column3) +from arrays_range; +---- +[0, 1, 2, 3, 4, 5, 6, 7, 8, 9] [3, 4, 5, 6, 7, 8, 9] [3, 5, 7, 9] +[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] [4, 5, 6, 7, 8, 9, 10, 11, 12] [4, 7, 10] + +query ?????? +select range(5), + range(2, 5), + range(2, 10, 3), + range(1, 5, -1), + range(1, -5, 1), + range(1, -5, -1) +; +---- +[0, 1, 2, 3, 4] [2, 3, 4] [2, 5, 8] [] [] [1, 0, -1, -2, -3, -4] + +query ??? +select generate_series(5), + generate_series(2, 5), + generate_series(2, 10, 3) +; +---- +[0, 1, 2, 3, 4] [2, 3, 4] [2, 5, 8] + +## array_except + +statement ok +CREATE TABLE array_except_table +AS VALUES + ([1, 2, 2, 3], [2, 3, 4]), + ([2, 3, 3], [3]), + ([3], [3, 3, 4]), + (null, [3, 4]), + ([1, 2], null), + (null, null) +; + +query ? +select array_except(column1, column2) from array_except_table; +---- +[1] +[2] +[] +NULL +[1, 2] +NULL + +statement ok +drop table array_except_table; + +statement ok +CREATE TABLE array_except_nested_list_table +AS VALUES + ([[1, 2], [3]], [[2], [3], [4, 5]]), + ([[1, 2], [3]], [[2], [1, 2]]), + ([[1, 2], [3]], null), + (null, [[1], [2, 3], [4, 5, 6]]), + ([[1], [2, 3], [4, 5, 6]], [[2, 3], [4, 5, 6], [1]]) +; + +query ? +select array_except(column1, column2) from array_except_nested_list_table; +---- +[[1, 2]] +[[3]] +[[1, 2], [3]] +NULL +[] + +statement ok +drop table array_except_nested_list_table; + +statement ok +CREATE TABLE array_except_table_float +AS VALUES + ([1.1, 2.2, 3.3], [2.2]), + ([1.1, 2.2, 3.3], [4.4]), + ([1.1, 2.2, 3.3], [3.3, 2.2, 1.1]) +; + +query ? +select array_except(column1, column2) from array_except_table_float; +---- +[1.1, 3.3] +[1.1, 2.2, 3.3] +[] + +statement ok +drop table array_except_table_float; + +statement ok +CREATE TABLE array_except_table_ut8 +AS VALUES + (['a', 'b', 'c'], ['a']), + (['a', 'bc', 'def'], ['g', 'def']), + (['a', 'bc', 'def'], null), + (null, ['a']) +; + +query ? +select array_except(column1, column2) from array_except_table_ut8; +---- +[b, c] +[a, bc] +[a, bc, def] +NULL + +statement ok +drop table array_except_table_ut8; + +statement ok +CREATE TABLE array_except_table_bool +AS VALUES + ([true, false, false], [false]), + ([true, true, true], [false]), + ([false, false, false], [true]), + ([true, false], null), + (null, [true, false]) +; + +query ? +select array_except(column1, column2) from array_except_table_bool; +---- +[true] +[true] +[false] +[true, false] +NULL + +statement ok +drop table array_except_table_bool; + +query ? +select array_except([], null); +---- +[] + +query ? +select array_except([], []); +---- +[] + +query ? +select array_except(null, []); +---- +NULL + +query ? +select array_except(null, null) +---- +NULL ### Array operators tests @@ -2319,10 +4617,49 @@ select 1 || make_array(2, 3, 4), 1.0 || make_array(2.0, 3.0, 4.0), 'h' || make_a ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] +# array concatenate operator with scalars #4 (mixed) +query ? +select 0 || [1,2,3] || 4 || [5] || [6,7]; +---- +[0, 1, 2, 3, 4, 5, 6, 7] + +# array concatenate operator with nd-list #5 (mixed) +query ? +select 0 || [1,2,3] || [[4,5]] || [[6,7,8]] || [9,10]; +---- +[[0, 1, 2, 3], [4, 5], [6, 7, 8], [9, 10]] + +# array concatenate operator non-valid cases +## concat 2D with scalar is not valid +query error +select 0 || [1,2,3] || [[4,5]] || [[6,7,8]] || [9,10] || 11; + +## concat scalar with 2D is not valid +query error +select 0 || [[1,2,3]]; + +# array concatenate operator with column + +statement ok +CREATE TABLE array_concat_operator_table +AS VALUES + (0, [1, 2, 2, 3], 4, [5, 6, 5]), + (-1, [4, 5, 6], 7, [8, 1, 1]) +; + +query ? +select column1 || column2 || column3 || column4 from array_concat_operator_table; +---- +[0, 1, 2, 2, 3, 4, 5, 6, 5] +[-1, 4, 5, 6, 7, 8, 1, 1] + +statement ok +drop table array_concat_operator_table; + ## array containment operator # array containment operator with scalars #1 (at arrow) -query ??????? +query BBBBBBB select make_array(1,2,3) @> make_array(1,3), make_array(1,2,3) @> make_array(1,4), make_array([1,2], [3,4]) @> make_array([1,2]), @@ -2334,7 +4671,7 @@ select make_array(1,2,3) @> make_array(1,3), true false true false false false true # array containment operator with scalars #2 (arrow at) -query ??????? +query BBBBBBB select make_array(1,3) <@ make_array(1,2,3), make_array(1,4) <@ make_array(1,2,3), make_array([1,2]) <@ make_array([1,2], [3,4]), @@ -2455,17 +4792,32 @@ select empty(make_array(1)); ---- false +query B +select empty(arrow_cast(make_array(1), 'LargeList(Int64)')); +---- +false + # empty scalar function #2 query B select empty(make_array()); ---- true +query B +select empty(arrow_cast(make_array(), 'LargeList(Null)')); +---- +true + # empty scalar function #3 query B select empty(make_array(NULL)); ---- -true +false + +query B +select empty(arrow_cast(make_array(NULL), 'LargeList(Null)')); +---- +false # empty scalar function #4 query B @@ -2485,6 +4837,17 @@ NULL false false +query B +select empty(arrow_cast(column1, 'LargeList(List(Int64))')) from arrays; +---- +false +false +false +false +NULL +false +false + query ? SELECT string_to_array('abcxxxdef', 'xxx') ---- @@ -2545,6 +4908,9 @@ drop table nested_arrays; statement ok drop table arrays; +statement ok +drop table large_arrays; + statement ok drop table slices; @@ -2578,14 +4944,65 @@ drop table array_has_table_2D_float; statement ok drop table array_has_table_3D; +statement ok +drop table array_intersect_table_1D; + +statement ok +drop table large_array_intersect_table_1D; + +statement ok +drop table array_intersect_table_1D_Float; + +statement ok +drop table large_array_intersect_table_1D_Float; + +statement ok +drop table array_intersect_table_1D_Boolean; + +statement ok +drop table large_array_intersect_table_1D_Boolean; + +statement ok +drop table array_intersect_table_1D_UTF8; + +statement ok +drop table large_array_intersect_table_1D_UTF8; + +statement ok +drop table array_intersect_table_2D; + +statement ok +drop table large_array_intersect_table_2D; + +statement ok +drop table array_intersect_table_2D_float; + +statement ok +drop table large_array_intersect_table_2D_float; + +statement ok +drop table array_intersect_table_3D; + +statement ok +drop table large_array_intersect_table_3D; + statement ok drop table arrays_values_without_nulls; +statement ok +drop table arrays_range; + statement ok drop table arrays_with_repeating_elements; +statement ok +drop table large_arrays_with_repeating_elements; + statement ok drop table nested_arrays_with_repeating_elements; +statement ok +drop table large_nested_arrays_with_repeating_elements; + statement ok drop table flatten_table; diff --git a/datafusion/sqllogictest/test_files/arrow_files.slt b/datafusion/sqllogictest/test_files/arrow_files.slt new file mode 100644 index 000000000000..5c1b6fb726ed --- /dev/null +++ b/datafusion/sqllogictest/test_files/arrow_files.slt @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +############# +## Arrow Files Format support +############# + + +statement ok + +CREATE EXTERNAL TABLE arrow_simple +STORED AS ARROW +LOCATION '../core/tests/data/example.arrow'; + + +# physical plan +query TT +EXPLAIN SELECT * FROM arrow_simple +---- +logical_plan TableScan: arrow_simple projection=[f0, f1, f2] +physical_plan ArrowExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.arrow]]}, projection=[f0, f1, f2] + +# correct content +query ITB +SELECT * FROM arrow_simple +---- +1 foo true +2 bar NULL +3 baz false +4 NULL true diff --git a/datafusion/sqllogictest/test_files/arrow_typeof.slt b/datafusion/sqllogictest/test_files/arrow_typeof.slt index e485251b7342..6a623e6c92f9 100644 --- a/datafusion/sqllogictest/test_files/arrow_typeof.slt +++ b/datafusion/sqllogictest/test_files/arrow_typeof.slt @@ -338,3 +338,41 @@ select arrow_cast(timestamp '2000-01-01T00:00:00Z', 'Timestamp(Nanosecond, Some( statement error Arrow error: Parser error: Invalid timezone "\+25:00": '\+25:00' is not a valid timezone select arrow_cast(timestamp '2000-01-01T00:00:00', 'Timestamp(Nanosecond, Some( "+25:00" ))'); + + +## List + + +query ? +select arrow_cast('1', 'List(Int64)'); +---- +[1] + +query ? +select arrow_cast(make_array(1, 2, 3), 'List(Int64)'); +---- +[1, 2, 3] + +query T +select arrow_typeof(arrow_cast(make_array(1, 2, 3), 'List(Int64)')); +---- +List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) + + +## LargeList + + +query ? +select arrow_cast('1', 'LargeList(Int64)'); +---- +[1] + +query ? +select arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'); +---- +[1, 2, 3] + +query T +select arrow_typeof(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)')); +---- +LargeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) diff --git a/datafusion/sqllogictest/test_files/avro.slt b/datafusion/sqllogictest/test_files/avro.slt index ede11406e1a9..3f21274c009f 100644 --- a/datafusion/sqllogictest/test_files/avro.slt +++ b/datafusion/sqllogictest/test_files/avro.slt @@ -34,6 +34,78 @@ STORED AS AVRO WITH HEADER ROW LOCATION '../../testing/data/avro/alltypes_plain.avro' +statement ok +CREATE EXTERNAL TABLE alltypes_plain_snappy ( + id INT NOT NULL, + bool_col BOOLEAN NOT NULL, + tinyint_col TINYINT NOT NULL, + smallint_col SMALLINT NOT NULL, + int_col INT NOT NULL, + bigint_col BIGINT NOT NULL, + float_col FLOAT NOT NULL, + double_col DOUBLE NOT NULL, + date_string_col BYTEA NOT NULL, + string_col VARCHAR NOT NULL, + timestamp_col TIMESTAMP NOT NULL, +) +STORED AS AVRO +WITH HEADER ROW +LOCATION '../../testing/data/avro/alltypes_plain.snappy.avro' + +statement ok +CREATE EXTERNAL TABLE alltypes_plain_bzip2 ( + id INT NOT NULL, + bool_col BOOLEAN NOT NULL, + tinyint_col TINYINT NOT NULL, + smallint_col SMALLINT NOT NULL, + int_col INT NOT NULL, + bigint_col BIGINT NOT NULL, + float_col FLOAT NOT NULL, + double_col DOUBLE NOT NULL, + date_string_col BYTEA NOT NULL, + string_col VARCHAR NOT NULL, + timestamp_col TIMESTAMP NOT NULL, +) +STORED AS AVRO +WITH HEADER ROW +LOCATION '../../testing/data/avro/alltypes_plain.bzip2.avro' + +statement ok +CREATE EXTERNAL TABLE alltypes_plain_xz ( + id INT NOT NULL, + bool_col BOOLEAN NOT NULL, + tinyint_col TINYINT NOT NULL, + smallint_col SMALLINT NOT NULL, + int_col INT NOT NULL, + bigint_col BIGINT NOT NULL, + float_col FLOAT NOT NULL, + double_col DOUBLE NOT NULL, + date_string_col BYTEA NOT NULL, + string_col VARCHAR NOT NULL, + timestamp_col TIMESTAMP NOT NULL, +) +STORED AS AVRO +WITH HEADER ROW +LOCATION '../../testing/data/avro/alltypes_plain.xz.avro' + +statement ok +CREATE EXTERNAL TABLE alltypes_plain_zstandard ( + id INT NOT NULL, + bool_col BOOLEAN NOT NULL, + tinyint_col TINYINT NOT NULL, + smallint_col SMALLINT NOT NULL, + int_col INT NOT NULL, + bigint_col BIGINT NOT NULL, + float_col FLOAT NOT NULL, + double_col DOUBLE NOT NULL, + date_string_col BYTEA NOT NULL, + string_col VARCHAR NOT NULL, + timestamp_col TIMESTAMP NOT NULL, +) +STORED AS AVRO +WITH HEADER ROW +LOCATION '../../testing/data/avro/alltypes_plain.zstandard.avro' + statement ok CREATE EXTERNAL TABLE single_nan ( mycol FLOAT @@ -73,6 +145,58 @@ SELECT id, CAST(string_col AS varchar) FROM alltypes_plain 0 0 1 1 +# test avro query with snappy +query IT +SELECT id, CAST(string_col AS varchar) FROM alltypes_plain_snappy +---- +4 0 +5 1 +6 0 +7 1 +2 0 +3 1 +0 0 +1 1 + +# test avro query with bzip2 +query IT +SELECT id, CAST(string_col AS varchar) FROM alltypes_plain_bzip2 +---- +4 0 +5 1 +6 0 +7 1 +2 0 +3 1 +0 0 +1 1 + +# test avro query with xz +query IT +SELECT id, CAST(string_col AS varchar) FROM alltypes_plain_xz +---- +4 0 +5 1 +6 0 +7 1 +2 0 +3 1 +0 0 +1 1 + +# test avro query with zstandard +query IT +SELECT id, CAST(string_col AS varchar) FROM alltypes_plain_zstandard +---- +4 0 +5 1 +6 0 +7 1 +2 0 +3 1 +0 0 +1 1 + # test avro single nan schema query R SELECT mycol FROM single_nan @@ -101,11 +225,11 @@ SELECT id, CAST(string_col AS varchar) FROM alltypes_plain_multi_files 1 1 # test avro nested records -query ?? -SELECT f1, f2 FROM nested_records +query ???? +SELECT f1, f2, f3, f4 FROM nested_records ---- -{ns2.record2.f1_1: aaa, ns2.record2.f1_2: 10, ns2.record2.f1_3: {ns3.record3.f1_3_1: 3.14}} [{ns4.record4.f2_1: true, ns4.record4.f2_2: 1.2}, {ns4.record4.f2_1: true, ns4.record4.f2_2: 2.2}] -{ns2.record2.f1_1: bbb, ns2.record2.f1_2: 20, ns2.record2.f1_3: {ns3.record3.f1_3_1: 3.14}} [{ns4.record4.f2_1: false, ns4.record4.f2_2: 10.2}] +{f1_1: aaa, f1_2: 10, f1_3: {f1_3_1: 3.14}} [{f2_1: true, f2_2: 1.2}, {f2_1: true, f2_2: 2.2}] {f3_1: xyz} [{f4_1: 200}, ] +{f1_1: bbb, f1_2: 20, f1_3: {f1_3_1: 3.14}} [{f2_1: false, f2_2: 10.2}] NULL [, {f4_1: 300}] # test avro enum query TTT @@ -129,10 +253,10 @@ EXPLAIN SELECT count(*) from alltypes_plain ---- logical_plan Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] ---TableScan: alltypes_plain projection=[id] +--TableScan: alltypes_plain projection=[] physical_plan AggregateExec: mode=Final, gby=[], aggr=[COUNT(*)] --CoalescePartitionsExec ----AggregateExec: mode=Partial, gby=[], aggr=[COUNT(*)] ------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------AvroExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/avro/alltypes_plain.avro]]}, projection=[id] +--------AvroExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/avro/alltypes_plain.avro]]} diff --git a/datafusion/sqllogictest/test_files/binary.slt b/datafusion/sqllogictest/test_files/binary.slt index d3a7e8c19334..0568ada3ad7d 100644 --- a/datafusion/sqllogictest/test_files/binary.slt +++ b/datafusion/sqllogictest/test_files/binary.slt @@ -155,3 +155,113 @@ drop table t_source statement ok drop table t + + +############# +## Tests for binary that contains strings +############# + +statement ok +CREATE TABLE t_source +AS VALUES + ('Foo'), + (NULL), + ('Bar'), + ('FooBar') +; + +# Create a table with Binary, LargeBinary but really has strings +statement ok +CREATE TABLE t +AS SELECT + arrow_cast(column1, 'Binary') as "binary", + arrow_cast(column1, 'LargeBinary') as "largebinary" +FROM t_source; + +query ??TT +SELECT binary, largebinary, cast(binary as varchar) as binary_str, cast(largebinary as varchar) as binary_largestr from t; +---- +466f6f 466f6f Foo Foo +NULL NULL NULL NULL +426172 426172 Bar Bar +466f6f426172 466f6f426172 FooBar FooBar + +# ensure coercion works for = and <> +query ?T +SELECT binary, cast(binary as varchar) as str FROM t WHERE binary = 'Foo'; +---- +466f6f Foo + +query ?T +SELECT binary, cast(binary as varchar) as str FROM t WHERE binary <> 'Foo'; +---- +426172 Bar +466f6f426172 FooBar + +# order by +query ? +SELECT binary FROM t ORDER BY binary; +---- +426172 +466f6f +466f6f426172 +NULL + +# order by +query ? +SELECT largebinary FROM t ORDER BY largebinary; +---- +426172 +466f6f +466f6f426172 +NULL + +# LIKE +query ? +SELECT binary FROM t where binary LIKE '%F%'; +---- +466f6f +466f6f426172 + +query ? +SELECT largebinary FROM t where largebinary LIKE '%F%'; +---- +466f6f +466f6f426172 + +# character_length function +query TITI +SELECT + cast(binary as varchar) as str, + character_length(binary) as binary_len, + cast(largebinary as varchar) as large_str, + character_length(binary) as largebinary_len +from t; +---- +Foo 3 Foo 3 +NULL NULL NULL NULL +Bar 3 Bar 3 +FooBar 6 FooBar 6 + +query I +SELECT character_length(X'20'); +---- +1 + +# still errors on values that can not be coerced to utf8 +query error Encountered non UTF\-8 data: invalid utf\-8 sequence of 1 bytes from index 0 +SELECT character_length(X'c328'); + +# regexp_replace +query TTTT +SELECT + cast(binary as varchar) as str, + regexp_replace(binary, 'F', 'f') as binary_replaced, + cast(largebinary as varchar) as large_str, + regexp_replace(largebinary, 'F', 'f') as large_binary_replaced +from t; +---- +Foo foo Foo foo +NULL NULL NULL NULL +Bar Bar Bar Bar +FooBar fooBar FooBar fooBar diff --git a/datafusion/sqllogictest/test_files/copy.slt b/datafusion/sqllogictest/test_files/copy.slt index a41d1fca66a4..89b23917884c 100644 --- a/datafusion/sqllogictest/test_files/copy.slt +++ b/datafusion/sqllogictest/test_files/copy.slt @@ -32,8 +32,8 @@ logical_plan CopyTo: format=parquet output_url=test_files/scratch/copy/table single_file_output=false options: (compression 'zstd(10)') --TableScan: source_table projection=[col1, col2] physical_plan -InsertExec: sink=ParquetSink(writer_mode=PutMultipart, file_groups=[]) ---MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +FileSinkExec: sink=ParquetSink(file_groups=[]) +--MemoryExec: partitions=1, partition_sizes=[1] # Error case query error DataFusion error: Invalid or Unsupported Configuration: Format not explicitly set and unable to get file extension! @@ -66,8 +66,8 @@ select * from validate_parquet; # Copy parquet with all supported statment overrides query IT -COPY source_table -TO 'test_files/scratch/copy/table_with_options' +COPY source_table +TO 'test_files/scratch/copy/table_with_options' (format parquet, single_file_output false, compression snappy, @@ -206,11 +206,11 @@ select * from validate_single_json; # COPY csv files with all options set query IT -COPY source_table -to 'test_files/scratch/copy/table_csv_with_options' -(format csv, -single_file_output false, -header false, +COPY source_table +to 'test_files/scratch/copy/table_csv_with_options' +(format csv, +single_file_output false, +header false, compression 'uncompressed', datetime_format '%FT%H:%M:%S.%9f', delimiter ';', @@ -220,8 +220,8 @@ null_value 'NULLVAL'); # Validate single csv output statement ok -CREATE EXTERNAL TABLE validate_csv_with_options -STORED AS csv +CREATE EXTERNAL TABLE validate_csv_with_options +STORED AS csv LOCATION 'test_files/scratch/copy/table_csv_with_options'; query T @@ -230,6 +230,62 @@ select * from validate_csv_with_options; 1;Foo 2;Bar +# Copy from table to single arrow file +query IT +COPY source_table to 'test_files/scratch/copy/table.arrow'; +---- +2 + +# Validate single csv output +statement ok +CREATE EXTERNAL TABLE validate_arrow_file +STORED AS arrow +LOCATION 'test_files/scratch/copy/table.arrow'; + +query IT +select * from validate_arrow_file; +---- +1 Foo +2 Bar + +# Copy from dict encoded values to single arrow file +query T? +COPY (values +('c', arrow_cast('foo', 'Dictionary(Int32, Utf8)')), ('d', arrow_cast('bar', 'Dictionary(Int32, Utf8)'))) +to 'test_files/scratch/copy/table_dict.arrow'; +---- +2 + +# Validate single csv output +statement ok +CREATE EXTERNAL TABLE validate_arrow_file_dict +STORED AS arrow +LOCATION 'test_files/scratch/copy/table_dict.arrow'; + +query T? +select * from validate_arrow_file_dict; +---- +c foo +d bar + + +# Copy from table to folder of json +query IT +COPY source_table to 'test_files/scratch/copy/table_arrow' (format arrow, single_file_output false); +---- +2 + +# Validate json output +statement ok +CREATE EXTERNAL TABLE validate_arrow STORED AS arrow LOCATION 'test_files/scratch/copy/table_arrow'; + +query IT +select * from validate_arrow; +---- +1 Foo +2 Bar + + # Error cases: # Copy from table with options diff --git a/datafusion/sqllogictest/test_files/csv_files.slt b/datafusion/sqllogictest/test_files/csv_files.slt new file mode 100644 index 000000000000..9facb064bf32 --- /dev/null +++ b/datafusion/sqllogictest/test_files/csv_files.slt @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# create_external_table_with_quote_escape +statement ok +CREATE EXTERNAL TABLE csv_with_quote ( +c1 VARCHAR, +c2 VARCHAR +) STORED AS CSV +WITH HEADER ROW +DELIMITER ',' +OPTIONS ('quote' '~') +LOCATION '../core/tests/data/quote.csv'; + +statement ok +CREATE EXTERNAL TABLE csv_with_escape ( +c1 VARCHAR, +c2 VARCHAR +) STORED AS CSV +WITH HEADER ROW +DELIMITER ',' +OPTIONS ('escape' '\"') +LOCATION '../core/tests/data/escape.csv'; + +query TT +select * from csv_with_quote; +---- +id0 value0 +id1 value1 +id2 value2 +id3 value3 +id4 value4 +id5 value5 +id6 value6 +id7 value7 +id8 value8 +id9 value9 + +query TT +select * from csv_with_escape; +---- +id0 value"0 +id1 value"1 +id2 value"2 +id3 value"3 +id4 value"4 +id5 value"5 +id6 value"6 +id7 value"7 +id8 value"8 +id9 value"9 diff --git a/datafusion/sqllogictest/test_files/ddl.slt b/datafusion/sqllogictest/test_files/ddl.slt index ed4f4b4a11ac..682972b5572a 100644 --- a/datafusion/sqllogictest/test_files/ddl.slt +++ b/datafusion/sqllogictest/test_files/ddl.slt @@ -750,7 +750,7 @@ query TT explain select c1 from t; ---- logical_plan TableScan: t projection=[c1] -physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/empty.csv]]}, projection=[c1], infinite_source=true, has_header=true +physical_plan StreamingTableExec: partition_sizes=1, projection=[c1], infinite_source=true statement ok drop table t; diff --git a/datafusion/sqllogictest/test_files/decimal.slt b/datafusion/sqllogictest/test_files/decimal.slt index a326a0cc4941..c220a5fc9a52 100644 --- a/datafusion/sqllogictest/test_files/decimal.slt +++ b/datafusion/sqllogictest/test_files/decimal.slt @@ -507,27 +507,26 @@ select * from decimal_simple where c1 >= 0.00004 order by c1; query RRIBR -select * from decimal_simple where c1 >= 0.00004 order by c1 limit 10; +select * from decimal_simple where c1 >= 0.00004 order by c1, c3 limit 10; ---- 0.00004 0.000000000004 5 true 0.000044 +0.00004 0.000000000004 8 false 0.000044 0.00004 0.000000000004 12 false 0.00004 0.00004 0.000000000004 14 true 0.00004 -0.00004 0.000000000004 8 false 0.000044 -0.00005 0.000000000005 9 true 0.000052 +0.00005 0.000000000005 1 false 0.0001 0.00005 0.000000000005 4 true 0.000078 0.00005 0.000000000005 8 false 0.000033 +0.00005 0.000000000005 9 true 0.000052 0.00005 0.000000000005 100 true 0.000068 -0.00005 0.000000000005 1 false 0.0001 - query RRIBR -select * from decimal_simple where c1 >= 0.00004 order by c1 limit 5; +select * from decimal_simple where c1 >= 0.00004 order by c1, c3 limit 5; ---- 0.00004 0.000000000004 5 true 0.000044 +0.00004 0.000000000004 8 false 0.000044 0.00004 0.000000000004 12 false 0.00004 0.00004 0.000000000004 14 true 0.00004 -0.00004 0.000000000004 8 false 0.000044 -0.00005 0.000000000005 9 true 0.000052 +0.00005 0.000000000005 1 false 0.0001 query RRIBR @@ -623,8 +622,103 @@ create table t as values (arrow_cast(123, 'Decimal256(5,2)')); statement ok set datafusion.execution.target_partitions = 1; -query error DataFusion error: This feature is not implemented: AvgAccumulator for \(Decimal256\(5, 2\) --> Decimal256\(9, 6\)\) +query R select AVG(column1) from t; +---- +123 statement ok drop table t; + +statement ok +CREATE EXTERNAL TABLE decimal256_simple ( +c1 DECIMAL(50,6) NOT NULL, +c2 DOUBLE NOT NULL, +c3 BIGINT NOT NULL, +c4 BOOLEAN NOT NULL, +c5 DECIMAL(52,7) NOT NULL +) +STORED AS CSV +WITH HEADER ROW +LOCATION '../core/tests/data/decimal_data.csv'; + +query TT +select arrow_typeof(c1), arrow_typeof(c5) from decimal256_simple limit 1; +---- +Decimal256(50, 6) Decimal256(52, 7) + +query R rowsort +SELECT c1 from decimal256_simple; +---- +0.00001 +0.00002 +0.00002 +0.00003 +0.00003 +0.00003 +0.00004 +0.00004 +0.00004 +0.00004 +0.00005 +0.00005 +0.00005 +0.00005 +0.00005 + +query R rowsort +select c1 from decimal256_simple where c1 > 0.000030; +---- +0.00004 +0.00004 +0.00004 +0.00004 +0.00005 +0.00005 +0.00005 +0.00005 +0.00005 + +query RRIBR rowsort +select * from decimal256_simple where c1 > c5; +---- +0.00002 0.000000000002 3 false 0.000019 +0.00003 0.000000000003 5 true 0.000011 +0.00005 0.000000000005 8 false 0.000033 + +query TR +select arrow_typeof(avg(c1)), avg(c1) from decimal256_simple; +---- +Decimal256(54, 10) 0.0000366666 + +query TR +select arrow_typeof(min(c1)), min(c1) from decimal256_simple where c4=false; +---- +Decimal256(50, 6) 0.00002 + +query TR +select arrow_typeof(max(c1)), max(c1) from decimal256_simple where c4=false; +---- +Decimal256(50, 6) 0.00005 + +query TR +select arrow_typeof(sum(c1)), sum(c1) from decimal256_simple; +---- +Decimal256(60, 6) 0.00055 + +query TR +select arrow_typeof(median(c1)), median(c1) from decimal256_simple; +---- +Decimal256(50, 6) 0.00004 + +query IR +select count(*),c1 from decimal256_simple group by c1 order by c1; +---- +1 0.00001 +2 0.00002 +3 0.00003 +4 0.00004 +5 0.00005 + +statement ok +drop table decimal256_simple; diff --git a/datafusion/sqllogictest/test_files/describe.slt b/datafusion/sqllogictest/test_files/describe.slt index 007aec443cbc..f94a2e453884 100644 --- a/datafusion/sqllogictest/test_files/describe.slt +++ b/datafusion/sqllogictest/test_files/describe.slt @@ -62,3 +62,27 @@ DROP TABLE aggregate_simple; statement error Error during planning: table 'datafusion.public.../core/tests/data/aggregate_simple.csv' not found DESCRIBE '../core/tests/data/aggregate_simple.csv'; + +########## +# Describe command +########## + +statement ok +CREATE EXTERNAL TABLE alltypes_tiny_pages STORED AS PARQUET LOCATION '../../parquet-testing/data/alltypes_tiny_pages.parquet'; + +query TTT +describe alltypes_tiny_pages; +---- +id Int32 YES +bool_col Boolean YES +tinyint_col Int8 YES +smallint_col Int16 YES +int_col Int32 YES +bigint_col Int64 YES +float_col Float32 YES +double_col Float64 YES +date_string_col Utf8 YES +string_col Utf8 YES +timestamp_col Timestamp(Nanosecond, None) YES +year Int32 YES +month Int32 YES diff --git a/datafusion/sqllogictest/test_files/dictionary.slt b/datafusion/sqllogictest/test_files/dictionary.slt new file mode 100644 index 000000000000..002aade2528e --- /dev/null +++ b/datafusion/sqllogictest/test_files/dictionary.slt @@ -0,0 +1,282 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Tests for querying on dictionary encoded data + +# Note: These tables model data as is common for timeseries, such as in InfluxDB IOx +# There are three types of columns: +# 1. tag columns, which are string dictionaries, often with low cardinality +# 2. field columns, which are typed, +# 3. a `time` columns, which is a nanosecond timestamp + +# It is common to group and filter on the "tag" columns (and thus on dictionary +# encoded values) + +# Table m1 with a tag column `tag_id` 4 fields `f1` - `f4`, and `time` + +statement ok +CREATE VIEW m1 AS +SELECT + arrow_cast(column1, 'Dictionary(Int32, Utf8)') as tag_id, + arrow_cast(column2, 'Float64') as f1, + arrow_cast(column3, 'Utf8') as f2, + arrow_cast(column4, 'Utf8') as f3, + arrow_cast(column5, 'Float64') as f4, + arrow_cast(column6, 'Timestamp(Nanosecond, None)') as time +FROM ( + VALUES + -- equivalent to the following line protocol data + -- m1,tag_id=1000 f1=32,f2="foo",f3="True",f4=1.0 1703030400000000000 + -- m1,tag_id=1000 f1=32,f2="foo",f3="True",f4=2.0 1703031000000000000 + -- m1,tag_id=1000 f1=32,f2="foo",f3="True",f4=3.0 1703031600000000000 + -- m1,tag_id=1000 f1=32,f2="foo",f3="True",f4=4.0 1703032200000000000 + -- m1,tag_id=1000 f1=32,f2="foo",f3="True",f4=5.0 1703032800000000000 + -- m1,tag_id=1000 f1=32,f2="foo",f3="True",f4=6.0 1703033400000000000 + -- m1,tag_id=1000 f1=32,f2="foo",f3="True",f4=7.0 1703034000000000000 + -- m1,tag_id=1000 f1=32,f2="foo",f3="True",f4=8.0 1703034600000000000 + -- m1,tag_id=1000 f1=32,f2="foo",f3="True",f4=9.0 1703035200000000000 + -- m1,tag_id=1000 f1=32,f2="foo",f3="True",f4=10.0 1703035800000000000 + ('1000', 32, 'foo', 'True', 1.0, 1703030400000000000), + ('1000', 32, 'foo', 'True', 2.0, 1703031000000000000), + ('1000', 32, 'foo', 'True', 3.0, 1703031600000000000), + ('1000', 32, 'foo', 'True', 4.0, 1703032200000000000), + ('1000', 32, 'foo', 'True', 5.0, 1703032800000000000), + ('1000', 32, 'foo', 'True', 6.0, 1703033400000000000), + ('1000', 32, 'foo', 'True', 7.0, 1703034000000000000), + ('1000', 32, 'foo', 'True', 8.0, 1703034600000000000), + ('1000', 32, 'foo', 'True', 9.0, 1703035200000000000), + ('1000', 32, 'foo', 'True', 10.0, 1703035800000000000) +); + +query ?RTTRP +SELECT * FROM m1; +---- +1000 32 foo True 1 2023-12-20T00:00:00 +1000 32 foo True 2 2023-12-20T00:10:00 +1000 32 foo True 3 2023-12-20T00:20:00 +1000 32 foo True 4 2023-12-20T00:30:00 +1000 32 foo True 5 2023-12-20T00:40:00 +1000 32 foo True 6 2023-12-20T00:50:00 +1000 32 foo True 7 2023-12-20T01:00:00 +1000 32 foo True 8 2023-12-20T01:10:00 +1000 32 foo True 9 2023-12-20T01:20:00 +1000 32 foo True 10 2023-12-20T01:30:00 + +# Note that te type of the tag column is `Dictionary(Int32, Utf8)` +query TTT +DESCRIBE m1; +---- +tag_id Dictionary(Int32, Utf8) YES +f1 Float64 YES +f2 Utf8 YES +f3 Utf8 YES +f4 Float64 YES +time Timestamp(Nanosecond, None) YES + + +# Table m2 with a tag columns `tag_id` and `type`, a field column `f5`, and `time` +statement ok +CREATE VIEW m2 AS +SELECT + arrow_cast(column1, 'Dictionary(Int32, Utf8)') as type, + arrow_cast(column2, 'Dictionary(Int32, Utf8)') as tag_id, + arrow_cast(column3, 'Float64') as f5, + arrow_cast(column4, 'Timestamp(Nanosecond, None)') as time +FROM ( + VALUES + -- equivalent to the following line protocol data + -- m2,type=active,tag_id=1000 f5=100 1701648000000000000 + -- m2,type=active,tag_id=1000 f5=200 1701648600000000000 + -- m2,type=active,tag_id=1000 f5=300 1701649200000000000 + -- m2,type=active,tag_id=1000 f5=400 1701649800000000000 + -- m2,type=active,tag_id=1000 f5=500 1701650400000000000 + -- m2,type=active,tag_id=1000 f5=600 1701651000000000000 + -- m2,type=passive,tag_id=2000 f5=700 1701651600000000000 + -- m2,type=passive,tag_id=1000 f5=800 1701652200000000000 + -- m2,type=passive,tag_id=1000 f5=900 1701652800000000000 + -- m2,type=passive,tag_id=1000 f5=1000 1701653400000000000 + ('active', '1000', 100, 1701648000000000000), + ('active', '1000', 200, 1701648600000000000), + ('active', '1000', 300, 1701649200000000000), + ('active', '1000', 400, 1701649800000000000), + ('active', '1000', 500, 1701650400000000000), + ('active', '1000', 600, 1701651000000000000), + ('passive', '1000', 700, 1701651600000000000), + ('passive', '1000', 800, 1701652200000000000), + ('passive', '1000', 900, 1701652800000000000), + ('passive', '1000', 1000, 1701653400000000000) +); + +query ??RP +SELECT * FROM m2; +---- +active 1000 100 2023-12-04T00:00:00 +active 1000 200 2023-12-04T00:10:00 +active 1000 300 2023-12-04T00:20:00 +active 1000 400 2023-12-04T00:30:00 +active 1000 500 2023-12-04T00:40:00 +active 1000 600 2023-12-04T00:50:00 +passive 1000 700 2023-12-04T01:00:00 +passive 1000 800 2023-12-04T01:10:00 +passive 1000 900 2023-12-04T01:20:00 +passive 1000 1000 2023-12-04T01:30:00 + +query TTT +DESCRIBE m2; +---- +type Dictionary(Int32, Utf8) YES +tag_id Dictionary(Int32, Utf8) YES +f5 Float64 YES +time Timestamp(Nanosecond, None) YES + +query I +select count(*) from m1 where tag_id = '1000' and time < '2024-01-03T14:46:35+01:00'; +---- +10 + +query RRR rowsort +select min(f5), max(f5), avg(f5) from m2 where tag_id = '1000' and time < '2024-01-03T14:46:35+01:00' group by type; +---- +100 600 350 +700 1000 850 + +query IRRRP +select count(*), min(f5), max(f5), avg(f5), date_bin('30 minutes', time) as "time" +from m2 where tag_id = '1000' and time < '2024-01-03T14:46:35+01:00' +group by date_bin('30 minutes', time) +order by date_bin('30 minutes', time) DESC +---- +1 1000 1000 1000 2023-12-04T01:30:00 +3 700 900 800 2023-12-04T01:00:00 +3 400 600 500 2023-12-04T00:30:00 +3 100 300 200 2023-12-04T00:00:00 + + + +# Reproducer for https://github.com/apache/arrow-datafusion/issues/8738 +# This query should work correctly +query P?TT rowsort +SELECT + "data"."timestamp" as "time", + "data"."tag_id", + "data"."field", + "data"."value" +FROM ( + ( + SELECT "m2"."time" as "timestamp", "m2"."tag_id", 'active_power' as "field", "m2"."f5" as "value" + FROM "m2" + WHERE "m2"."time" >= '2023-12-05T14:46:35+01:00' AND "m2"."time" < '2024-01-03T14:46:35+01:00' + AND "m2"."f5" IS NOT NULL + AND "m2"."type" IN ('active') + AND "m2"."tag_id" IN ('1000') + ) UNION ( + SELECT "m1"."time" as "timestamp", "m1"."tag_id", 'f1' as "field", "m1"."f1" as "value" + FROM "m1" + WHERE "m1"."time" >= '2023-12-05T14:46:35+01:00' AND "m1"."time" < '2024-01-03T14:46:35+01:00' + AND "m1"."f1" IS NOT NULL + AND "m1"."tag_id" IN ('1000') + ) UNION ( + SELECT "m1"."time" as "timestamp", "m1"."tag_id", 'f2' as "field", "m1"."f2" as "value" + FROM "m1" + WHERE "m1"."time" >= '2023-12-05T14:46:35+01:00' AND "m1"."time" < '2024-01-03T14:46:35+01:00' + AND "m1"."f2" IS NOT NULL + AND "m1"."tag_id" IN ('1000') + ) +) as "data" +ORDER BY + "time", + "data"."tag_id" +; +---- +2023-12-20T00:00:00 1000 f1 32.0 +2023-12-20T00:00:00 1000 f2 foo +2023-12-20T00:10:00 1000 f1 32.0 +2023-12-20T00:10:00 1000 f2 foo +2023-12-20T00:20:00 1000 f1 32.0 +2023-12-20T00:20:00 1000 f2 foo +2023-12-20T00:30:00 1000 f1 32.0 +2023-12-20T00:30:00 1000 f2 foo +2023-12-20T00:40:00 1000 f1 32.0 +2023-12-20T00:40:00 1000 f2 foo +2023-12-20T00:50:00 1000 f1 32.0 +2023-12-20T00:50:00 1000 f2 foo +2023-12-20T01:00:00 1000 f1 32.0 +2023-12-20T01:00:00 1000 f2 foo +2023-12-20T01:10:00 1000 f1 32.0 +2023-12-20T01:10:00 1000 f2 foo +2023-12-20T01:20:00 1000 f1 32.0 +2023-12-20T01:20:00 1000 f2 foo +2023-12-20T01:30:00 1000 f1 32.0 +2023-12-20T01:30:00 1000 f2 foo + + +# deterministic sort (so we can avoid rowsort) +query P?TT +SELECT + "data"."timestamp" as "time", + "data"."tag_id", + "data"."field", + "data"."value" +FROM ( + ( + SELECT "m2"."time" as "timestamp", "m2"."tag_id", 'active_power' as "field", "m2"."f5" as "value" + FROM "m2" + WHERE "m2"."time" >= '2023-12-05T14:46:35+01:00' AND "m2"."time" < '2024-01-03T14:46:35+01:00' + AND "m2"."f5" IS NOT NULL + AND "m2"."type" IN ('active') + AND "m2"."tag_id" IN ('1000') + ) UNION ( + SELECT "m1"."time" as "timestamp", "m1"."tag_id", 'f1' as "field", "m1"."f1" as "value" + FROM "m1" + WHERE "m1"."time" >= '2023-12-05T14:46:35+01:00' AND "m1"."time" < '2024-01-03T14:46:35+01:00' + AND "m1"."f1" IS NOT NULL + AND "m1"."tag_id" IN ('1000') + ) UNION ( + SELECT "m1"."time" as "timestamp", "m1"."tag_id", 'f2' as "field", "m1"."f2" as "value" + FROM "m1" + WHERE "m1"."time" >= '2023-12-05T14:46:35+01:00' AND "m1"."time" < '2024-01-03T14:46:35+01:00' + AND "m1"."f2" IS NOT NULL + AND "m1"."tag_id" IN ('1000') + ) +) as "data" +ORDER BY + "time", + "data"."tag_id", + "data"."field", + "data"."value" +; +---- +2023-12-20T00:00:00 1000 f1 32.0 +2023-12-20T00:00:00 1000 f2 foo +2023-12-20T00:10:00 1000 f1 32.0 +2023-12-20T00:10:00 1000 f2 foo +2023-12-20T00:20:00 1000 f1 32.0 +2023-12-20T00:20:00 1000 f2 foo +2023-12-20T00:30:00 1000 f1 32.0 +2023-12-20T00:30:00 1000 f2 foo +2023-12-20T00:40:00 1000 f1 32.0 +2023-12-20T00:40:00 1000 f2 foo +2023-12-20T00:50:00 1000 f1 32.0 +2023-12-20T00:50:00 1000 f2 foo +2023-12-20T01:00:00 1000 f1 32.0 +2023-12-20T01:00:00 1000 f2 foo +2023-12-20T01:10:00 1000 f1 32.0 +2023-12-20T01:10:00 1000 f2 foo +2023-12-20T01:20:00 1000 f1 32.0 +2023-12-20T01:20:00 1000 f2 foo +2023-12-20T01:30:00 1000 f1 32.0 +2023-12-20T01:30:00 1000 f2 foo diff --git a/datafusion/sqllogictest/test_files/distinct_on.slt b/datafusion/sqllogictest/test_files/distinct_on.slt new file mode 100644 index 000000000000..3f609e254839 --- /dev/null +++ b/datafusion/sqllogictest/test_files/distinct_on.slt @@ -0,0 +1,145 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +statement ok +CREATE EXTERNAL TABLE aggregate_test_100 ( + c1 VARCHAR NOT NULL, + c2 TINYINT NOT NULL, + c3 SMALLINT NOT NULL, + c4 SMALLINT, + c5 INT, + c6 BIGINT NOT NULL, + c7 SMALLINT NOT NULL, + c8 INT NOT NULL, + c9 BIGINT UNSIGNED NOT NULL, + c10 VARCHAR NOT NULL, + c11 FLOAT NOT NULL, + c12 DOUBLE NOT NULL, + c13 VARCHAR NOT NULL +) +STORED AS CSV +WITH HEADER ROW +LOCATION '../../testing/data/csv/aggregate_test_100.csv' + +# Basic example: distinct on the first column project the second one, and +# order by the third +query TI +SELECT DISTINCT ON (c1) c1, c2 FROM aggregate_test_100 ORDER BY c1, c3, c9; +---- +a 4 +b 4 +c 2 +d 1 +e 3 + +# Basic example + reverse order of the selected column +query TI +SELECT DISTINCT ON (c1) c1, c2 FROM aggregate_test_100 ORDER BY c1, c3 DESC, c9; +---- +a 1 +b 5 +c 4 +d 1 +e 1 + +# Basic example + reverse order of the ON column +query TI +SELECT DISTINCT ON (c1) c1, c2 FROM aggregate_test_100 ORDER BY c1 DESC, c3, c9; +---- +e 3 +d 1 +c 2 +b 4 +a 4 + +# Basic example + reverse order of both columns + limit +query TI +SELECT DISTINCT ON (c1) c1, c2 FROM aggregate_test_100 ORDER BY c1 DESC, c3 DESC LIMIT 3; +---- +e 1 +d 1 +c 4 + +# Basic example + omit ON column from selection +query I +SELECT DISTINCT ON (c1) c2 FROM aggregate_test_100 ORDER BY c1, c3; +---- +4 +4 +2 +1 +3 + +# Test explain makes sense +query TT +EXPLAIN SELECT DISTINCT ON (c1) c3, c2 FROM aggregate_test_100 ORDER BY c1, c3; +---- +logical_plan +Projection: FIRST_VALUE(aggregate_test_100.c3) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST] AS c3, FIRST_VALUE(aggregate_test_100.c2) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST] AS c2 +--Sort: aggregate_test_100.c1 ASC NULLS LAST +----Aggregate: groupBy=[[aggregate_test_100.c1]], aggr=[[FIRST_VALUE(aggregate_test_100.c3) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST], FIRST_VALUE(aggregate_test_100.c2) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST]]] +------TableScan: aggregate_test_100 projection=[c1, c2, c3] +physical_plan +ProjectionExec: expr=[FIRST_VALUE(aggregate_test_100.c3) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST]@1 as c3, FIRST_VALUE(aggregate_test_100.c2) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST]@2 as c2] +--SortPreservingMergeExec: [c1@0 ASC NULLS LAST] +----SortExec: expr=[c1@0 ASC NULLS LAST] +------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[FIRST_VALUE(aggregate_test_100.c3), FIRST_VALUE(aggregate_test_100.c2)] +--------CoalesceBatchesExec: target_batch_size=8192 +----------RepartitionExec: partitioning=Hash([c1@0], 4), input_partitions=4 +------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[FIRST_VALUE(aggregate_test_100.c3), FIRST_VALUE(aggregate_test_100.c2)] +--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3], has_header=true + +# ON expressions are not a sub-set of the ORDER BY expressions +query error SELECT DISTINCT ON expressions must match initial ORDER BY expressions +SELECT DISTINCT ON (c2 % 2 = 0) c2, c3 - 100 FROM aggregate_test_100 ORDER BY c2, c3; + +# ON expressions are empty +query error DataFusion error: Error during planning: No `ON` expressions provided +SELECT DISTINCT ON () c1, c2 FROM aggregate_test_100 ORDER BY c1, c2; + +# Use expressions in the ON and ORDER BY clauses, as well as the selection +query II +SELECT DISTINCT ON (c2 % 2 = 0) c2, c3 - 100 FROM aggregate_test_100 ORDER BY c2 % 2 = 0, c3 DESC; +---- +1 25 +4 23 + +# Multiple complex expressions +query TIB +SELECT DISTINCT ON (chr(ascii(c1) + 3), c2 % 2) chr(ascii(upper(c1)) + 3), c2 % 2, c3 > 80 AND c2 % 2 = 1 +FROM aggregate_test_100 +WHERE c1 IN ('a', 'b') +ORDER BY chr(ascii(c1) + 3), c2 % 2, c3 DESC; +---- +D 0 false +D 1 true +E 0 false +E 1 false + +# Joins using CTEs +query II +WITH t1 AS (SELECT * FROM aggregate_test_100), +t2 AS (SELECT * FROM aggregate_test_100) +SELECT DISTINCT ON (t1.c1, t2.c2) t2.c3, t1.c4 +FROM t1 INNER JOIN t2 ON t1.c13 = t2.c13 +ORDER BY t1.c1, t2.c2, t2.c5 +LIMIT 3; +---- +-25 15295 +45 15673 +-72 -11122 diff --git a/datafusion/sqllogictest/test_files/errors.slt b/datafusion/sqllogictest/test_files/errors.slt index 1380ac2f2bfd..e3b2610e51be 100644 --- a/datafusion/sqllogictest/test_files/errors.slt +++ b/datafusion/sqllogictest/test_files/errors.slt @@ -130,3 +130,7 @@ c9, nth_value(c5, 2, 3) over (order by c9) as nv1 from aggregate_test_100 order by c9 + + +statement error Inconsistent data type across values list at row 1 column 0. Was Int64 but found Utf8 +create table foo as values (1), ('foo'); diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index b1ba1eb36d11..2a39e3138869 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -93,8 +93,8 @@ query TT EXPLAIN select count(*) from (values ('a', 1, 100), ('a', 2, 150)) as t (c1,c2,c3) ---- physical_plan -ProjectionExec: expr=[2 as COUNT(UInt8(1))] ---EmptyExec: produce_one_row=true +ProjectionExec: expr=[2 as COUNT(*)] +--PlaceholderRowExec statement ok set datafusion.explain.physical_plan_only = false @@ -140,7 +140,7 @@ physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/te # create a sink table, path is same with aggregate_test_100 table # we do not overwrite this file, we only assert plan. statement ok -CREATE EXTERNAL TABLE sink_table ( +CREATE UNBOUNDED EXTERNAL TABLE sink_table ( c1 VARCHAR NOT NULL, c2 TINYINT NOT NULL, c3 SMALLINT NOT NULL, @@ -168,10 +168,9 @@ Dml: op=[Insert Into] table=[sink_table] ----Sort: aggregate_test_100.c1 ASC NULLS LAST ------TableScan: aggregate_test_100 projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13] physical_plan -InsertExec: sink=CsvSink(writer_mode=Append, file_groups=[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]) ---ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3, c4@3 as c4, c5@4 as c5, c6@5 as c6, c7@6 as c7, c8@7 as c8, c9@8 as c9, c10@9 as c10, c11@10 as c11, c12@11 as c12, c13@12 as c13] -----SortExec: expr=[c1@0 ASC NULLS LAST] -------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], has_header=true +FileSinkExec: sink=StreamWrite { location: "../../testing/data/csv/aggregate_test_100.csv", batch_size: 8192, encoding: Csv, header: true, .. } +--SortExec: expr=[c1@0 ASC NULLS LAST] +----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], has_header=true # test EXPLAIN VERBOSE query TT @@ -181,9 +180,11 @@ initial_logical_plan Projection: simple_explain_test.a, simple_explain_test.b, simple_explain_test.c --TableScan: simple_explain_test logical_plan after inline_table_scan SAME TEXT AS ABOVE +logical_plan after operator_to_function SAME TEXT AS ABOVE logical_plan after type_coercion SAME TEXT AS ABOVE logical_plan after count_wildcard_rule SAME TEXT AS ABOVE analyzed_logical_plan SAME TEXT AS ABOVE +logical_plan after eliminate_nested_union SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE @@ -192,7 +193,6 @@ logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE logical_plan after scalar_subquery_to_join SAME TEXT AS ABOVE logical_plan after extract_equijoin_predicate SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE -logical_plan after merge_projection SAME TEXT AS ABOVE logical_plan after rewrite_disjunctive_predicate SAME TEXT AS ABOVE logical_plan after eliminate_duplicated_expr SAME TEXT AS ABOVE logical_plan after eliminate_filter SAME TEXT AS ABOVE @@ -200,6 +200,7 @@ logical_plan after eliminate_cross_join SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE logical_plan after eliminate_limit SAME TEXT AS ABOVE logical_plan after propagate_empty_relation SAME TEXT AS ABOVE +logical_plan after eliminate_one_union SAME TEXT AS ABOVE logical_plan after filter_null_join_keys SAME TEXT AS ABOVE logical_plan after eliminate_outer_join SAME TEXT AS ABOVE logical_plan after push_down_limit SAME TEXT AS ABOVE @@ -208,11 +209,8 @@ logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE -logical_plan after push_down_projection -Projection: simple_explain_test.a, simple_explain_test.b, simple_explain_test.c ---TableScan: simple_explain_test projection=[a, b, c] -logical_plan after eliminate_projection TableScan: simple_explain_test projection=[a, b, c] -logical_plan after push_down_limit SAME TEXT AS ABOVE +logical_plan after optimize_projections TableScan: simple_explain_test projection=[a, b, c] +logical_plan after eliminate_nested_union SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE @@ -221,7 +219,6 @@ logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE logical_plan after scalar_subquery_to_join SAME TEXT AS ABOVE logical_plan after extract_equijoin_predicate SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE -logical_plan after merge_projection SAME TEXT AS ABOVE logical_plan after rewrite_disjunctive_predicate SAME TEXT AS ABOVE logical_plan after eliminate_duplicated_expr SAME TEXT AS ABOVE logical_plan after eliminate_filter SAME TEXT AS ABOVE @@ -229,6 +226,7 @@ logical_plan after eliminate_cross_join SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE logical_plan after eliminate_limit SAME TEXT AS ABOVE logical_plan after propagate_empty_relation SAME TEXT AS ABOVE +logical_plan after eliminate_one_union SAME TEXT AS ABOVE logical_plan after filter_null_join_keys SAME TEXT AS ABOVE logical_plan after eliminate_outer_join SAME TEXT AS ABOVE logical_plan after push_down_limit SAME TEXT AS ABOVE @@ -237,20 +235,26 @@ logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE -logical_plan after push_down_projection SAME TEXT AS ABOVE -logical_plan after eliminate_projection SAME TEXT AS ABOVE -logical_plan after push_down_limit SAME TEXT AS ABOVE +logical_plan after optimize_projections SAME TEXT AS ABOVE logical_plan TableScan: simple_explain_test projection=[a, b, c] initial_physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true +initial_physical_plan_with_stats CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:)]] +physical_plan after OutputRequirements +OutputRequirementExec +--CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true physical_plan after aggregate_statistics SAME TEXT AS ABOVE physical_plan after join_selection SAME TEXT AS ABOVE +physical_plan after LimitedDistinctAggregation SAME TEXT AS ABOVE physical_plan after EnforceDistribution SAME TEXT AS ABOVE physical_plan after CombinePartialFinalAggregate SAME TEXT AS ABOVE physical_plan after EnforceSorting SAME TEXT AS ABOVE physical_plan after coalesce_batches SAME TEXT AS ABOVE +physical_plan after OutputRequirements CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true physical_plan after PipelineChecker SAME TEXT AS ABOVE physical_plan after LimitAggregation SAME TEXT AS ABOVE +physical_plan after ProjectionPushdown SAME TEXT AS ABOVE physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true +physical_plan_with_stats CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:)]] ### tests for EXPLAIN with display statistics enabled @@ -265,8 +269,8 @@ query TT EXPLAIN SELECT a, b, c FROM simple_explain_test limit 10; ---- physical_plan -GlobalLimitExec: skip=0, fetch=10, statistics=[rows=10, bytes=None, exact=false] ---CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], limit=10, has_header=true, statistics=[] +GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Inexact(10), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:)]] +--CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], limit=10, has_header=true, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:)]] # Parquet scan with statistics collected statement ok @@ -279,11 +283,100 @@ query TT EXPLAIN SELECT * FROM alltypes_plain limit 10; ---- physical_plan -GlobalLimitExec: skip=0, fetch=10, statistics=[rows=8, bytes=None, exact=true] ---ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[rows=8, bytes=None, exact=true] +GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] + +# explain verbose with both collect & show statistics on +query TT +EXPLAIN VERBOSE SELECT * FROM alltypes_plain limit 10; +---- +initial_physical_plan +GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +physical_plan after OutputRequirements +OutputRequirementExec, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +--GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +----ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +physical_plan after aggregate_statistics SAME TEXT AS ABOVE +physical_plan after join_selection SAME TEXT AS ABOVE +physical_plan after LimitedDistinctAggregation SAME TEXT AS ABOVE +physical_plan after EnforceDistribution SAME TEXT AS ABOVE +physical_plan after CombinePartialFinalAggregate SAME TEXT AS ABOVE +physical_plan after EnforceSorting SAME TEXT AS ABOVE +physical_plan after coalesce_batches SAME TEXT AS ABOVE +physical_plan after OutputRequirements +GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +physical_plan after PipelineChecker SAME TEXT AS ABOVE +physical_plan after LimitAggregation SAME TEXT AS ABOVE +physical_plan after ProjectionPushdown SAME TEXT AS ABOVE +physical_plan +GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] + + +statement ok +set datafusion.explain.show_statistics = false; + +# explain verbose with collect on and & show statistics off: still has stats +query TT +EXPLAIN VERBOSE SELECT * FROM alltypes_plain limit 10; +---- +initial_physical_plan +GlobalLimitExec: skip=0, fetch=10 +--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10 +initial_physical_plan_with_stats +GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +physical_plan after OutputRequirements +OutputRequirementExec +--GlobalLimitExec: skip=0, fetch=10 +----ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10 +physical_plan after aggregate_statistics SAME TEXT AS ABOVE +physical_plan after join_selection SAME TEXT AS ABOVE +physical_plan after LimitedDistinctAggregation SAME TEXT AS ABOVE +physical_plan after EnforceDistribution SAME TEXT AS ABOVE +physical_plan after CombinePartialFinalAggregate SAME TEXT AS ABOVE +physical_plan after EnforceSorting SAME TEXT AS ABOVE +physical_plan after coalesce_batches SAME TEXT AS ABOVE +physical_plan after OutputRequirements +GlobalLimitExec: skip=0, fetch=10 +--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10 +physical_plan after PipelineChecker SAME TEXT AS ABOVE +physical_plan after LimitAggregation SAME TEXT AS ABOVE +physical_plan after ProjectionPushdown SAME TEXT AS ABOVE +physical_plan +GlobalLimitExec: skip=0, fetch=10 +--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10 +physical_plan_with_stats +GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] + statement ok set datafusion.execution.collect_statistics = false; +# Explain ArrayFuncions + statement ok -set datafusion.explain.show_statistics = false; +set datafusion.explain.physical_plan_only = false + +query TT +explain select make_array(make_array(1, 2, 3), make_array(4, 5, 6)); +---- +logical_plan +Projection: List([[1, 2, 3], [4, 5, 6]]) AS make_array(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(4),Int64(5),Int64(6))) +--EmptyRelation +physical_plan +ProjectionExec: expr=[[[1, 2, 3], [4, 5, 6]] as make_array(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(4),Int64(5),Int64(6)))] +--PlaceholderRowExec + +query TT +explain select [[1, 2, 3], [4, 5, 6]]; +---- +logical_plan +Projection: List([[1, 2, 3], [4, 5, 6]]) AS make_array(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(4),Int64(5),Int64(6))) +--EmptyRelation +physical_plan +ProjectionExec: expr=[[[1, 2, 3], [4, 5, 6]] as make_array(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(4),Int64(5),Int64(6)))] +--PlaceholderRowExec diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt new file mode 100644 index 000000000000..a2a8d9c6475c --- /dev/null +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -0,0 +1,1251 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# test_boolean_expressions +query BBBB +SELECT true, false, false = false, true = false +---- +true false true false + +# test_mathematical_expressions_with_null +query RRRRRRRRRRRRRRRRRR?RRRRRRRIRRRRRRBB +SELECT + sqrt(NULL), + cbrt(NULL), + sin(NULL), + cos(NULL), + tan(NULL), + asin(NULL), + acos(NULL), + atan(NULL), + sinh(NULL), + cosh(NULL), + tanh(NULL), + asinh(NULL), + acosh(NULL), + atanh(NULL), + floor(NULL), + ceil(NULL), + round(NULL), + trunc(NULL), + abs(NULL), + signum(NULL), + exp(NULL), + ln(NULL), + log2(NULL), + log10(NULL), + power(NULL, 2), + power(NULL, NULL), + power(2, NULL), + atan2(NULL, NULL), + atan2(1, NULL), + atan2(NULL, 1), + nanvl(NULL, NULL), + nanvl(1, NULL), + nanvl(NULL, 1), + isnan(NULL), + iszero(NULL) +---- +NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL + +# test_array_cast_invalid_timezone_will_panic +statement error Parser error: Invalid timezone "Foo": 'Foo' is not a valid timezone +SELECT arrow_cast('2021-01-02T03:04:00', 'Timestamp(Nanosecond, Some("Foo"))') + +# test_array_index +query III??IIIIII +SELECT + ([5,4,3,2,1])[1], + ([5,4,3,2,1])[2], + ([5,4,3,2,1])[5], + ([[1, 2], [2, 3], [3,4]])[1], + ([[1, 2], [2, 3], [3,4]])[3], + ([[1, 2], [2, 3], [3,4]])[1][1], + ([[1, 2], [2, 3], [3,4]])[2][2], + ([[1, 2], [2, 3], [3,4]])[3][2], + -- out of bounds + ([5,4,3,2,1])[0], + ([5,4,3,2,1])[6], + -- ([5,4,3,2,1])[-1], -- TODO: wrong answer + -- ([5,4,3,2,1])[null], -- TODO: not supported + ([5,4,3,2,1])[100] +---- +5 4 1 [1, 2] [3, 4] 1 3 4 NULL NULL NULL + +# test_array_literals +query ????? +SELECT + [1,2,3,4,5], + [true, false], + ['str1', 'str2'], + [[1,2], [3,4]], + [] +---- +[1, 2, 3, 4, 5] [true, false] [str1, str2] [[1, 2], [3, 4]] [] + +# test_struct_literals +query ?????? +SELECT + STRUCT(1,2,3,4,5), + STRUCT(Null), + STRUCT(2), + STRUCT('1',Null), + STRUCT(true, false), + STRUCT('str1', 'str2') +---- +{c0: 1, c1: 2, c2: 3, c3: 4, c4: 5} {c0: } {c0: 2} {c0: 1, c1: } {c0: true, c1: false} {c0: str1, c1: str2} + +# test binary_bitwise_shift +query IIII +SELECT + 2 << 10, + 2048 >> 10, + 2048 << NULL, + 2048 >> NULL +---- +2048 2 NULL NULL + +query ? +SELECT interval '1' +---- +0 years 0 mons 0 days 0 hours 0 mins 1.000000000 secs + +query ? +SELECT interval '1 second' +---- +0 years 0 mons 0 days 0 hours 0 mins 1.000000000 secs + +query ? +SELECT interval '500 milliseconds' +---- +0 years 0 mons 0 days 0 hours 0 mins 0.500000000 secs + +query ? +SELECT interval '5 second' +---- +0 years 0 mons 0 days 0 hours 0 mins 5.000000000 secs + +query ? +SELECT interval '0.5 minute' +---- +0 years 0 mons 0 days 0 hours 0 mins 30.000000000 secs + +query ? +SELECT interval '.5 minute' +---- +0 years 0 mons 0 days 0 hours 0 mins 30.000000000 secs + +query ? +SELECT interval '5 minute' +---- +0 years 0 mons 0 days 0 hours 5 mins 0.000000000 secs + +query ? +SELECT interval '5 minute 1 second' +---- +0 years 0 mons 0 days 0 hours 5 mins 1.000000000 secs + +query ? +SELECT interval '1 hour' +---- +0 years 0 mons 0 days 1 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '5 hour' +---- +0 years 0 mons 0 days 5 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '1 day' +---- +0 years 0 mons 1 days 0 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '1 week' +---- +0 years 0 mons 7 days 0 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '2 weeks' +---- +0 years 0 mons 14 days 0 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '1 day 1' +---- +0 years 0 mons 1 days 0 hours 0 mins 1.000000000 secs + +query ? +SELECT interval '0.5' +---- +0 years 0 mons 0 days 0 hours 0 mins 0.500000000 secs + +query ? +SELECT interval '0.5 day 1' +---- +0 years 0 mons 0 days 12 hours 0 mins 1.000000000 secs + +query ? +SELECT interval '0.49 day' +---- +0 years 0 mons 0 days 11 hours 45 mins 36.000000000 secs + +query ? +SELECT interval '0.499 day' +---- +0 years 0 mons 0 days 11 hours 58 mins 33.600000000 secs + +query ? +SELECT interval '0.4999 day' +---- +0 years 0 mons 0 days 11 hours 59 mins 51.360000000 secs + +query ? +SELECT interval '0.49999 day' +---- +0 years 0 mons 0 days 11 hours 59 mins 59.136000000 secs + +query ? +SELECT interval '0.49999999999 day' +---- +0 years 0 mons 0 days 11 hours 59 mins 59.999999136 secs + +query ? +SELECT interval '5 day' +---- +0 years 0 mons 5 days 0 hours 0 mins 0.000000000 secs + +# Hour is ignored, this matches PostgreSQL +query ? +SELECT interval '5 day' hour +---- +0 years 0 mons 5 days 0 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '5 day 4 hours 3 minutes 2 seconds 100 milliseconds' +---- +0 years 0 mons 5 days 4 hours 3 mins 2.100000000 secs + +query ? +SELECT interval '0.5 month' +---- +0 years 0 mons 15 days 0 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '0.5' month +---- +0 years 0 mons 15 days 0 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '1 month' +---- +0 years 1 mons 0 days 0 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '1' MONTH +---- +0 years 1 mons 0 days 0 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '5 month' +---- +0 years 5 mons 0 days 0 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '13 month' +---- +0 years 13 mons 0 days 0 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '0.5 year' +---- +0 years 6 mons 0 days 0 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '1 year' +---- +0 years 12 mons 0 days 0 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '1 decade' +---- +0 years 120 mons 0 days 0 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '2 decades' +---- +0 years 240 mons 0 days 0 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '1 century' +---- +0 years 1200 mons 0 days 0 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '2 year' +---- +0 years 24 mons 0 days 0 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '1 year 1 day' +---- +0 years 12 mons 1 days 0 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '1 year 1 day 1 hour' +---- +0 years 12 mons 1 days 1 hours 0 mins 0.000000000 secs + +query ? +SELECT interval '1 year 1 day 1 hour 1 minute' +---- +0 years 12 mons 1 days 1 hours 1 mins 0.000000000 secs + +query ? +SELECT interval '1 year 1 day 1 hour 1 minute 1 second' +---- +0 years 12 mons 1 days 1 hours 1 mins 1.000000000 secs + +query I +SELECT ascii('') +---- +0 + +query I +SELECT ascii('x') +---- +120 + +query I +SELECT ascii(NULL) +---- +NULL + +query I +SELECT bit_length('') +---- +0 + +query I +SELECT bit_length('chars') +---- +40 + +query I +SELECT bit_length('josé') +---- +40 + +query ? +SELECT bit_length(NULL) +---- +NULL + +query T +SELECT btrim(' xyxtrimyyx ', NULL) +---- +NULL + +query T +SELECT btrim(' xyxtrimyyx ') +---- +xyxtrimyyx + +query T +SELECT btrim('\n xyxtrimyyx \n') +---- +\n xyxtrimyyx \n + +query T +SELECT btrim('xyxtrimyyx', 'xyz') +---- +trim + +query T +SELECT btrim('\nxyxtrimyyx\n', 'xyz\n') +---- +trim + +query ? +SELECT btrim(NULL, 'xyz') +---- +NULL + +query T +SELECT chr(CAST(120 AS int)) +---- +x + +query T +SELECT chr(CAST(128175 AS int)) +---- +💯 + +query T +SELECT chr(CAST(NULL AS int)) +---- +NULL + +query T +SELECT concat('a','b','c') +---- +abc + +query T +SELECT concat('abcde', 2, NULL, 22) +---- +abcde222 + +query T +SELECT concat(NULL) +---- +(empty) + +query T +SELECT concat_ws(',', 'abcde', 2, NULL, 22) +---- +abcde,2,22 + +query T +SELECT concat_ws('|','a','b','c') +---- +a|b|c + +query T +SELECT concat_ws('|',NULL) +---- +(empty) + +query T +SELECT concat_ws(NULL,'a',NULL,'b','c') +---- +NULL + +query T +SELECT concat_ws('|','a',NULL) +---- +a + +query T +SELECT concat_ws('|','a',NULL,NULL) +---- +a + +query T +SELECT initcap('') +---- +(empty) + +query T +SELECT initcap('hi THOMAS') +---- +Hi Thomas + +query ? +SELECT initcap(NULL) +---- +NULL + +query T +SELECT lower('') +---- +(empty) + +query T +SELECT lower('TOM') +---- +tom + +query ? +SELECT lower(NULL) +---- +NULL + +query T +SELECT ltrim(' zzzytest ', NULL) +---- +NULL + +query T +SELECT ltrim(' zzzytest ') +---- +zzzytest + +query T +SELECT ltrim('zzzytest', 'xyz') +---- +test + +query ? +SELECT ltrim(NULL, 'xyz') +---- +NULL + +query I +SELECT octet_length('') +---- +0 + +query I +SELECT octet_length('chars') +---- +5 + +query I +SELECT octet_length('josé') +---- +5 + +query ? +SELECT octet_length(NULL) +---- +NULL + +query T +SELECT repeat('Pg', 4) +---- +PgPgPgPg + +query T +SELECT repeat('Pg', CAST(NULL AS INT)) +---- +NULL + +query ? +SELECT repeat(NULL, 4) +---- +NULL + +query T +SELECT replace('abcdefabcdef', 'cd', 'XX') +---- +abXXefabXXef + +query T +SELECT replace('abcdefabcdef', 'cd', NULL) +---- +NULL + +query T +SELECT replace('abcdefabcdef', 'notmatch', 'XX') +---- +abcdefabcdef + +query T +SELECT replace('abcdefabcdef', NULL, 'XX') +---- +NULL + +query ? +SELECT replace(NULL, 'cd', 'XX') +---- +NULL + +query T +SELECT rtrim(' testxxzx ') +---- + testxxzx + +query T +SELECT rtrim(' zzzytest ', NULL) +---- +NULL + +query T +SELECT rtrim('testxxzx', 'xyz') +---- +test + +query ? +SELECT rtrim(NULL, 'xyz') +---- +NULL + +query T +SELECT split_part('abc~@~def~@~ghi', '~@~', 2) +---- +def + +query T +SELECT split_part('abc~@~def~@~ghi', '~@~', 20) +---- +(empty) + +query ? +SELECT split_part(NULL, '~@~', 20) +---- +NULL + +query T +SELECT split_part('abc~@~def~@~ghi', NULL, 20) +---- +NULL + +query T +SELECT split_part('abc~@~def~@~ghi', '~@~', CAST(NULL AS INT)) +---- +NULL + +query B +SELECT starts_with('alphabet', 'alph') +---- +true + +query B +SELECT starts_with('alphabet', 'blph') +---- +false + +query B +SELECT starts_with(NULL, 'blph') +---- +NULL + +query B +SELECT starts_with('alphabet', NULL) +---- +NULL + +query T +SELECT to_hex(2147483647) +---- +7fffffff + +query T +SELECT to_hex(9223372036854775807) +---- +7fffffffffffffff + +query T +SELECT to_hex(CAST(NULL AS int)) +---- +NULL + +query T +SELECT trim(' tom ') +---- +tom + +query T +SELECT trim(LEADING ' tom ') +---- +tom + +query T +SELECT trim(TRAILING ' tom ') +---- + tom + +query T +SELECT trim(BOTH ' tom ') +---- +tom + +query T +SELECT trim(LEADING ' ' FROM ' tom ') +---- +tom + +query T +SELECT trim(TRAILING ' ' FROM ' tom ') +---- + tom + +query T +SELECT trim(BOTH ' ' FROM ' tom ') +---- +tom + +query T +SELECT trim(' ' FROM ' tom ') +---- +tom + +query T +SELECT trim(LEADING 'x' FROM 'xxxtomxxx') +---- +tomxxx + +query T +SELECT trim(TRAILING 'x' FROM 'xxxtomxxx') +---- +xxxtom + +query T +SELECT trim(BOTH 'x' FROM 'xxxtomxx') +---- +tom + +query T +SELECT trim('x' FROM 'xxxtomxx') +---- +tom + + +query T +SELECT trim(LEADING 'xy' FROM 'xyxabcxyzdefxyx') +---- +abcxyzdefxyx + +query T +SELECT trim(TRAILING 'xy' FROM 'xyxabcxyzdefxyx') +---- +xyxabcxyzdef + +query T +SELECT trim(BOTH 'xy' FROM 'xyxabcxyzdefxyx') +---- +abcxyzdef + +query T +SELECT trim('xy' FROM 'xyxabcxyzdefxyx') +---- +abcxyzdef + +query T +SELECT trim(' tom') +---- +tom + +query T +SELECT trim('') +---- +(empty) + +query T +SELECT trim('tom ') +---- +tom + +query T +SELECT upper('') +---- +(empty) + +query T +SELECT upper('tom') +---- +TOM + +query ? +SELECT upper(NULL) +---- +NULL + +# TODO issue: https://github.com/apache/arrow-datafusion/issues/6596 +# query ?? +#SELECT +# CAST([1,2,3,4] AS INT[]) as a, +# CAST([1,2,3,4] AS NUMERIC(10,4)[]) as b +#---- +#[1, 2, 3, 4] [1.0000, 2.0000, 3.0000, 4.0000] + +# test_random_expression +query BB +SELECT + random() BETWEEN 0.0 AND 1.0, + random() = random() +---- +true false + +# test_uuid_expression +query II +SELECT octet_length(uuid()), length(uuid()) +---- +36 36 + +# test_cast_expressions +query IIII +SELECT + CAST('0' AS INT) as a, + CAST(NULL AS INT) as b, + TRY_CAST('0' AS INT) as c, + TRY_CAST('x' AS INT) as d +---- +0 NULL 0 NULL + +# test_extract_date_part + +query R +SELECT date_part('YEAR', CAST('2000-01-01' AS DATE)) +---- +2000 + +query R +SELECT EXTRACT(year FROM timestamp '2020-09-08T12:00:00+00:00') +---- +2020 + +query R +SELECT date_part('QUARTER', CAST('2000-01-01' AS DATE)) +---- +1 + +query R +SELECT EXTRACT(quarter FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +3 + +query R +SELECT date_part('MONTH', CAST('2000-01-01' AS DATE)) +---- +1 + +query R +SELECT EXTRACT(month FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +9 + +query R +SELECT date_part('WEEK', CAST('2003-01-01' AS DATE)) +---- +1 + +query R +SELECT EXTRACT(WEEK FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +37 + +query R +SELECT date_part('DAY', CAST('2000-01-01' AS DATE)) +---- +1 + +query R +SELECT EXTRACT(day FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +8 + +query R +SELECT date_part('DOY', CAST('2000-01-01' AS DATE)) +---- +1 + +query R +SELECT EXTRACT(doy FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +252 + +query R +SELECT date_part('DOW', CAST('2000-01-01' AS DATE)) +---- +6 + +query R +SELECT EXTRACT(dow FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +2 + +query R +SELECT date_part('HOUR', CAST('2000-01-01' AS DATE)) +---- +0 + +query R +SELECT EXTRACT(hour FROM to_timestamp('2020-09-08T12:03:03+00:00')) +---- +12 + +query R +SELECT EXTRACT(minute FROM to_timestamp('2020-09-08T12:12:00+00:00')) +---- +12 + +query R +SELECT date_part('minute', to_timestamp('2020-09-08T12:12:00+00:00')) +---- +12 + +query R +SELECT EXTRACT(second FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12.12345678 + +query R +SELECT EXTRACT(millisecond FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123.45678 + +query R +SELECT EXTRACT(microsecond FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123456.78 + +query R +SELECT EXTRACT(nanosecond FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123456780 + +# Keep precision when coercing Utf8 to Timestamp +query R +SELECT date_part('second', timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12.12345678 + +query R +SELECT date_part('millisecond', timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123.45678 + +query R +SELECT date_part('microsecond', timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123456.78 + +query R +SELECT date_part('nanosecond', timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123456780 + +query R +SELECT date_part('second', '2020-09-08T12:00:12.12345678+00:00') +---- +12.12345678 + +query R +SELECT date_part('millisecond', '2020-09-08T12:00:12.12345678+00:00') +---- +12123.45678 + +query R +SELECT date_part('microsecond', '2020-09-08T12:00:12.12345678+00:00') +---- +12123456.78 + +query R +SELECT date_part('nanosecond', '2020-09-08T12:00:12.12345678+00:00') +---- +12123456780 + +# test_extract_epoch + +query R +SELECT extract(epoch from '1870-01-01T07:29:10.256'::timestamp) +---- +-3155646649.744 + +query R +SELECT extract(epoch from '2000-01-01T00:00:00.000'::timestamp) +---- +946684800 + +query R +SELECT extract(epoch from to_timestamp('2000-01-01T00:00:00+00:00')) +---- +946684800 + +query R +SELECT extract(epoch from NULL::timestamp) +---- +NULL + +query R +SELECT extract(epoch from arrow_cast('1970-01-01', 'Date32')) +---- +0 + +query R +SELECT extract(epoch from arrow_cast('1970-01-02', 'Date32')) +---- +86400 + +query R +SELECT extract(epoch from arrow_cast('1970-01-11', 'Date32')) +---- +864000 + +query R +SELECT extract(epoch from arrow_cast('1969-12-31', 'Date32')) +---- +-86400 + +query R +SELECT extract(epoch from arrow_cast('1970-01-01', 'Date64')) +---- +0 + +query R +SELECT extract(epoch from arrow_cast('1970-01-02', 'Date64')) +---- +86400 + +query R +SELECT extract(epoch from arrow_cast('1970-01-11', 'Date64')) +---- +864000 + +query R +SELECT extract(epoch from arrow_cast('1969-12-31', 'Date64')) +---- +-86400 + +# test_extract_date_part_func + +query B +SELECT (date_part('year', now()) = EXTRACT(year FROM now())) +---- +true + +query B +SELECT (date_part('quarter', now()) = EXTRACT(quarter FROM now())) +---- +true + +query B +SELECT (date_part('month', now()) = EXTRACT(month FROM now())) +---- +true + +query B +SELECT (date_part('week', now()) = EXTRACT(week FROM now())) +---- +true + +query B +SELECT (date_part('day', now()) = EXTRACT(day FROM now())) +---- +true + +query B +SELECT (date_part('hour', now()) = EXTRACT(hour FROM now())) +---- +true + +query B +SELECT (date_part('minute', now()) = EXTRACT(minute FROM now())) +---- +true + +query B +SELECT (date_part('second', now()) = EXTRACT(second FROM now())) +---- +true + +query B +SELECT (date_part('millisecond', now()) = EXTRACT(millisecond FROM now())) +---- +true + +query B +SELECT (date_part('microsecond', now()) = EXTRACT(microsecond FROM now())) +---- +true + +query B +SELECT (date_part('nanosecond', now()) = EXTRACT(nanosecond FROM now())) +---- +true + +query B +SELECT 'a' IN ('a','b') +---- +true + +query B +SELECT 'c' IN ('a','b') +---- +false + +query B +SELECT 'c' NOT IN ('a','b') +---- +true + +query B +SELECT 'a' NOT IN ('a','b') +---- +false + +query B +SELECT NULL IN ('a','b') +---- +NULL + +query B +SELECT NULL NOT IN ('a','b') +---- +NULL + +query B +SELECT 'a' IN ('a','b',NULL) +---- +true + +query B +SELECT 'c' IN ('a','b',NULL) +---- +NULL + +query B +SELECT 'a' NOT IN ('a','b',NULL) +---- +false + +query B +SELECT 'c' NOT IN ('a','b',NULL) +---- +NULL + +query B +SELECT 0 IN (0,1,2) +---- +true + +query B +SELECT 3 IN (0,1,2) +---- +false + +query B +SELECT 3 NOT IN (0,1,2) +---- +true + +query B +SELECT 0 NOT IN (0,1,2) +---- +false + +query B +SELECT NULL IN (0,1,2) +---- +NULL + +query B +SELECT NULL NOT IN (0,1,2) +---- +NULL + +query B +SELECT 0 IN (0,1,2,NULL) +---- +true + +query B +SELECT 3 IN (0,1,2,NULL) +---- +NULL + +query B +SELECT 0 NOT IN (0,1,2,NULL) +---- +false + +query B +SELECT 3 NOT IN (0,1,2,NULL) +---- +NULL + +query B +SELECT 0.0 IN (0.0,0.1,0.2) +---- +true + +query B +SELECT 0.3 IN (0.0,0.1,0.2) +---- +false + +query B +SELECT 0.3 NOT IN (0.0,0.1,0.2) +---- +true + +query B +SELECT 0.0 NOT IN (0.0,0.1,0.2) +---- +false + +query B +SELECT NULL IN (0.0,0.1,0.2) +---- +NULL + +query B +SELECT NULL NOT IN (0.0,0.1,0.2) +---- +NULL + +query B +SELECT 0.0 IN (0.0,0.1,0.2,NULL) +---- +true + +query B +SELECT 0.3 IN (0.0,0.1,0.2,NULL) +---- +NULL + +query B +SELECT 0.0 NOT IN (0.0,0.1,0.2,NULL) +---- +false + +query B +SELECT 0.3 NOT IN (0.0,0.1,0.2,NULL) +---- +NULL + +query B +SELECT '1' IN ('a','b',1) +---- +true + +query B +SELECT '2' IN ('a','b',1) +---- +false + +query B +SELECT '2' NOT IN ('a','b',1) +---- +true + +query B +SELECT '1' NOT IN ('a','b',1) +---- +false + +query B +SELECT NULL IN ('a','b',1) +---- +NULL + +query B +SELECT NULL NOT IN ('a','b',1) +---- +NULL + +query B +SELECT '1' IN ('a','b',NULL,1) +---- +true + +query B +SELECT '2' IN ('a','b',NULL,1) +---- +NULL + +query B +SELECT '1' NOT IN ('a','b',NULL,1) +---- +false + +query B +SELECT '2' NOT IN ('a','b',NULL,1) +---- +NULL diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index e3e39ef6cc4c..1903088b0748 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -494,6 +494,10 @@ SELECT counter(*) from test; statement error Did you mean 'STDDEV'? SELECT STDEV(v1) from test; +# Aggregate function +statement error Did you mean 'COVAR'? +SELECT COVARIA(1,1); + # Window function statement error Did you mean 'SUM'? SELECT v1, v2, SUMM(v2) OVER(ORDER BY v1) from test; @@ -784,7 +788,7 @@ INSERT INTO products (product_id, product_name, price) VALUES (1, 'OldBrand Product 1', 19.99), (2, 'OldBrand Product 2', 29.99), (3, 'OldBrand Product 3', 39.99), -(4, 'OldBrand Product 4', 49.99) +(4, 'OldBrand Product 4', 49.99) query ITR SELECT * REPLACE (price*2 AS price) FROM products @@ -811,3 +815,189 @@ SELECT products.* REPLACE (price*2 AS price, product_id+1000 AS product_id) FROM 1002 OldBrand Product 2 59.98 1003 OldBrand Product 3 79.98 1004 OldBrand Product 4 99.98 + +#overlay tests +statement ok +CREATE TABLE over_test( + str TEXT, + characters TEXT, + pos INT, + len INT +) as VALUES + ('123', 'abc', 4, 5), + ('abcdefg', 'qwertyasdfg', 1, 7), + ('xyz', 'ijk', 1, 2), + ('Txxxxas', 'hom', 2, 4), + (NULL, 'hom', 2, 4), + ('Txxxxas', 'hom', NULL, 4), + ('Txxxxas', 'hom', 2, NULL), + ('Txxxxas', NULL, 2, 4) +; + +query T +SELECT overlay(str placing characters from pos for len) from over_test +---- +abc +qwertyasdfg +ijkz +Thomas +NULL +NULL +NULL +NULL + +query T +SELECT overlay(str placing characters from pos) from over_test +---- +abc +qwertyasdfg +ijk +Thomxas +NULL +NULL +Thomxas +NULL + +query I +SELECT levenshtein('kitten', 'sitting') +---- +3 + +query I +SELECT levenshtein('kitten', NULL) +---- +NULL + +query ? +SELECT levenshtein(NULL, 'sitting') +---- +NULL + +query ? +SELECT levenshtein(NULL, NULL) +---- +NULL + +query T +SELECT substr_index('www.apache.org', '.', 1) +---- +www + +query T +SELECT substr_index('www.apache.org', '.', 2) +---- +www.apache + +query T +SELECT substr_index('www.apache.org', '.', -1) +---- +org + +query T +SELECT substr_index('www.apache.org', '.', -2) +---- +apache.org + +query T +SELECT substr_index('www.apache.org', 'ac', 1) +---- +www.ap + +query T +SELECT substr_index('www.apache.org', 'ac', -1) +---- +he.org + +query T +SELECT substr_index('www.apache.org', 'ac', 2) +---- +www.apache.org + +query T +SELECT substr_index('www.apache.org', 'ac', -2) +---- +www.apache.org + +query ? +SELECT substr_index(NULL, 'ac', 1) +---- +NULL + +query T +SELECT substr_index('www.apache.org', NULL, 1) +---- +NULL + +query T +SELECT substr_index('www.apache.org', 'ac', NULL) +---- +NULL + +query T +SELECT substr_index('', 'ac', 1) +---- +(empty) + +query T +SELECT substr_index('www.apache.org', '', 1) +---- +(empty) + +query T +SELECT substr_index('www.apache.org', 'ac', 0) +---- +(empty) + +query ? +SELECT substr_index(NULL, NULL, NULL) +---- +NULL + +query I +SELECT find_in_set('b', 'a,b,c,d') +---- +2 + + +query I +SELECT find_in_set('a', 'a,b,c,d,a') +---- +1 + +query I +SELECT find_in_set('', 'a,b,c,d,a') +---- +0 + +query I +SELECT find_in_set('a', '') +---- +0 + + +query I +SELECT find_in_set('', '') +---- +1 + +query ? +SELECT find_in_set(NULL, 'a,b,c,d') +---- +NULL + +query I +SELECT find_in_set('a', NULL) +---- +NULL + + +query ? +SELECT find_in_set(NULL, NULL) +---- +NULL + +# Verify that multiple calls to volatile functions like `random()` are not combined / optimized away +query B +SELECT r FROM (SELECT r1 == r2 r, r1, r2 FROM (SELECT random() r1, random() r2) WHERE r1 > 0 AND r2 > 0) +---- +false diff --git a/datafusion/sqllogictest/test_files/groupby.slt b/datafusion/sqllogictest/test_files/groupby.slt index ffef93837b27..b09ff79e88d5 100644 --- a/datafusion/sqllogictest/test_files/groupby.slt +++ b/datafusion/sqllogictest/test_files/groupby.slt @@ -2014,23 +2014,21 @@ Sort: l.col0 ASC NULLS LAST ----------TableScan: tab0 projection=[col0, col1] physical_plan SortPreservingMergeExec: [col0@0 ASC NULLS LAST] ---ProjectionExec: expr=[col0@0 as col0, LAST_VALUE(r.col1) ORDER BY [r.col0 ASC NULLS LAST]@3 as last_col1] -----AggregateExec: mode=FinalPartitioned, gby=[col0@0 as col0, col1@1 as col1, col2@2 as col2], aggr=[LAST_VALUE(r.col1)], ordering_mode=PartiallyOrdered -------SortExec: expr=[col0@0 ASC NULLS LAST] +--SortExec: expr=[col0@0 ASC NULLS LAST] +----ProjectionExec: expr=[col0@0 as col0, LAST_VALUE(r.col1) ORDER BY [r.col0 ASC NULLS LAST]@3 as last_col1] +------AggregateExec: mode=FinalPartitioned, gby=[col0@0 as col0, col1@1 as col1, col2@2 as col2], aggr=[LAST_VALUE(r.col1)] --------CoalesceBatchesExec: target_batch_size=8192 ----------RepartitionExec: partitioning=Hash([col0@0, col1@1, col2@2], 4), input_partitions=4 -------------AggregateExec: mode=Partial, gby=[col0@0 as col0, col1@1 as col1, col2@2 as col2], aggr=[LAST_VALUE(r.col1)], ordering_mode=PartiallyOrdered ---------------SortExec: expr=[col0@3 ASC NULLS LAST] +------------AggregateExec: mode=Partial, gby=[col0@0 as col0, col1@1 as col1, col2@2 as col2], aggr=[LAST_VALUE(r.col1)] +--------------ProjectionExec: expr=[col0@2 as col0, col1@3 as col1, col2@4 as col2, col0@0 as col0, col1@1 as col1] ----------------CoalesceBatchesExec: target_batch_size=8192 ------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col0@0, col0@0)] --------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------RepartitionExec: partitioning=Hash([col0@0], 4), input_partitions=4 -------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------MemoryExec: partitions=1, partition_sizes=[3] +----------------------RepartitionExec: partitioning=Hash([col0@0], 4), input_partitions=1 +------------------------MemoryExec: partitions=1, partition_sizes=[3] --------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------RepartitionExec: partitioning=Hash([col0@0], 4), input_partitions=4 -------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------MemoryExec: partitions=1, partition_sizes=[3] +----------------------RepartitionExec: partitioning=Hash([col0@0], 4), input_partitions=1 +------------------------MemoryExec: partitions=1, partition_sizes=[3] # Columns in the table are a,b,c,d. Source is CsvExec which is ordered by # a,b,c column. Column a has cardinality 2, column b has cardinality 4. @@ -2086,9 +2084,7 @@ logical_plan Projection: multiple_ordered_table.a --Sort: multiple_ordered_table.c ASC NULLS LAST ----TableScan: multiple_ordered_table projection=[a, c] -physical_plan -ProjectionExec: expr=[a@0 as a] ---CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c], output_ordering=[a@0 ASC NULLS LAST], has_header=true +physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a], output_ordering=[a@0 ASC NULLS LAST], has_header=true # Final plan shouldn't have SortExec a ASC, b ASC, # because table already satisfies this ordering. @@ -2099,9 +2095,7 @@ logical_plan Projection: multiple_ordered_table.a --Sort: multiple_ordered_table.a ASC NULLS LAST, multiple_ordered_table.b ASC NULLS LAST ----TableScan: multiple_ordered_table projection=[a, b] -physical_plan -ProjectionExec: expr=[a@0 as a] ---CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b], output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], has_header=true +physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a], output_ordering=[a@0 ASC NULLS LAST], has_header=true # test_window_agg_sort statement ok @@ -2120,8 +2114,8 @@ Projection: annotated_data_infinite2.a, annotated_data_infinite2.b, SUM(annotate ----TableScan: annotated_data_infinite2 projection=[a, b, c] physical_plan ProjectionExec: expr=[a@1 as a, b@0 as b, SUM(annotated_data_infinite2.c)@2 as summation1] ---AggregateExec: mode=Single, gby=[b@1 as b, a@0 as a], aggr=[SUM(annotated_data_infinite2.c)], ordering_mode=FullyOrdered -----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true +--AggregateExec: mode=Single, gby=[b@1 as b, a@0 as a], aggr=[SUM(annotated_data_infinite2.c)], ordering_mode=Sorted +----StreamingTableExec: partition_sizes=1, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] query III @@ -2151,8 +2145,8 @@ Projection: annotated_data_infinite2.a, annotated_data_infinite2.d, SUM(annotate ----TableScan: annotated_data_infinite2 projection=[a, c, d] physical_plan ProjectionExec: expr=[a@1 as a, d@0 as d, SUM(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]@2 as summation1] ---AggregateExec: mode=Single, gby=[d@2 as d, a@0 as a], aggr=[SUM(annotated_data_infinite2.c)], ordering_mode=PartiallyOrdered -----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true +--AggregateExec: mode=Single, gby=[d@2 as d, a@0 as a], aggr=[SUM(annotated_data_infinite2.c)], ordering_mode=PartiallySorted([1]) +----StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] query III SELECT a, d, @@ -2184,8 +2178,8 @@ Projection: annotated_data_infinite2.a, annotated_data_infinite2.b, FIRST_VALUE( ----TableScan: annotated_data_infinite2 projection=[a, b, c] physical_plan ProjectionExec: expr=[a@0 as a, b@1 as b, FIRST_VALUE(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]@2 as first_c] ---AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[FIRST_VALUE(annotated_data_infinite2.c)], ordering_mode=FullyOrdered -----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true +--AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[FIRST_VALUE(annotated_data_infinite2.c)], ordering_mode=Sorted +----StreamingTableExec: partition_sizes=1, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] query III SELECT a, b, FIRST_VALUE(c ORDER BY a DESC) as first_c @@ -2210,11 +2204,11 @@ Projection: annotated_data_infinite2.a, annotated_data_infinite2.b, LAST_VALUE(a ----TableScan: annotated_data_infinite2 projection=[a, b, c] physical_plan ProjectionExec: expr=[a@0 as a, b@1 as b, LAST_VALUE(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]@2 as last_c] ---AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[LAST_VALUE(annotated_data_infinite2.c)], ordering_mode=FullyOrdered -----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true +--AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[LAST_VALUE(annotated_data_infinite2.c)], ordering_mode=Sorted +----StreamingTableExec: partition_sizes=1, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] query III -SELECT a, b, LAST_VALUE(c ORDER BY a DESC) as last_c +SELECT a, b, LAST_VALUE(c ORDER BY a DESC, c ASC) as last_c FROM annotated_data_infinite2 GROUP BY a, b ---- @@ -2237,8 +2231,8 @@ Projection: annotated_data_infinite2.a, annotated_data_infinite2.b, LAST_VALUE(a ----TableScan: annotated_data_infinite2 projection=[a, b, c] physical_plan ProjectionExec: expr=[a@0 as a, b@1 as b, LAST_VALUE(annotated_data_infinite2.c)@2 as last_c] ---AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[LAST_VALUE(annotated_data_infinite2.c)], ordering_mode=FullyOrdered -----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true +--AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[LAST_VALUE(annotated_data_infinite2.c)], ordering_mode=Sorted +----StreamingTableExec: partition_sizes=1, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] query III SELECT a, b, LAST_VALUE(c) as last_c @@ -2335,15 +2329,15 @@ ProjectionExec: expr=[country@0 as country, ARRAY_AGG(s.amount) ORDER BY [s.amou ----SortExec: expr=[amount@1 DESC] ------MemoryExec: partitions=1, partition_sizes=[1] -query T?R +query T?R rowsort SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, SUM(s.amount) AS sum1 FROM sales_global AS s GROUP BY s.country ---- FRA [200.0, 50.0] 250 -TUR [100.0, 75.0] 175 GRC [80.0, 30.0] 110 +TUR [100.0, 75.0] 175 # test_ordering_sensitive_aggregation3 # When different aggregators have conflicting requirements, we cannot satisfy all of them in current implementation. @@ -2356,9 +2350,9 @@ SELECT ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, GROUP BY s.country # test_ordering_sensitive_aggregation4 -# If aggregators can work with bounded memory (FullyOrdered or PartiallyOrdered mode), we should append requirement to +# If aggregators can work with bounded memory (Sorted or PartiallySorted mode), we should append requirement to # the existing ordering. This enables us to still work with bounded memory, and also satisfy aggregation requirement. -# This test checks for whether we can satisfy aggregation requirement in FullyOrdered mode. +# This test checks for whether we can satisfy aggregation requirement in Sorted mode. query TT EXPLAIN SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, SUM(s.amount) AS sum1 @@ -2375,11 +2369,11 @@ Projection: s.country, ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST] --------TableScan: sales_global projection=[country, amount] physical_plan ProjectionExec: expr=[country@0 as country, ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST]@1 as amounts, SUM(s.amount)@2 as sum1] ---AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(s.amount), SUM(s.amount)], ordering_mode=FullyOrdered +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(s.amount), SUM(s.amount)], ordering_mode=Sorted ----SortExec: expr=[country@0 ASC NULLS LAST,amount@1 DESC] ------MemoryExec: partitions=1, partition_sizes=[1] -query T?R +query T?R rowsort SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, SUM(s.amount) AS sum1 FROM (SELECT * @@ -2392,9 +2386,9 @@ GRC [80.0, 30.0] 110 TUR [100.0, 75.0] 175 # test_ordering_sensitive_aggregation5 -# If aggregators can work with bounded memory (FullyOrdered or PartiallyOrdered mode), we should be append requirement to +# If aggregators can work with bounded memory (Sorted or PartiallySorted mode), we should be append requirement to # the existing ordering. This enables us to still work with bounded memory, and also satisfy aggregation requirement. -# This test checks for whether we can satisfy aggregation requirement in PartiallyOrdered mode. +# This test checks for whether we can satisfy aggregation requirement in PartiallySorted mode. query TT EXPLAIN SELECT s.country, s.zip_code, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, SUM(s.amount) AS sum1 @@ -2411,11 +2405,11 @@ Projection: s.country, s.zip_code, ARRAY_AGG(s.amount) ORDER BY [s.amount DESC N --------TableScan: sales_global projection=[zip_code, country, amount] physical_plan ProjectionExec: expr=[country@0 as country, zip_code@1 as zip_code, ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST]@2 as amounts, SUM(s.amount)@3 as sum1] ---AggregateExec: mode=Single, gby=[country@1 as country, zip_code@0 as zip_code], aggr=[ARRAY_AGG(s.amount), SUM(s.amount)], ordering_mode=PartiallyOrdered +--AggregateExec: mode=Single, gby=[country@1 as country, zip_code@0 as zip_code], aggr=[ARRAY_AGG(s.amount), SUM(s.amount)], ordering_mode=PartiallySorted([0]) ----SortExec: expr=[country@1 ASC NULLS LAST,amount@2 DESC] ------MemoryExec: partitions=1, partition_sizes=[1] -query TI?R +query TI?R rowsort SELECT s.country, s.zip_code, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, SUM(s.amount) AS sum1 FROM (SELECT * @@ -2428,7 +2422,7 @@ GRC 0 [80.0, 30.0] 110 TUR 1 [100.0, 75.0] 175 # test_ordering_sensitive_aggregation6 -# If aggregators can work with bounded memory (FullyOrdered or PartiallyOrdered mode), we should be append requirement to +# If aggregators can work with bounded memory (FullySorted or PartiallySorted mode), we should be append requirement to # the existing ordering. When group by expressions contain aggregation requirement, we shouldn't append redundant expression. # Hence in the final plan SortExec should be SortExec: expr=[country@0 DESC] not SortExec: expr=[country@0 ASC NULLS LAST,country@0 DESC] query TT @@ -2447,11 +2441,11 @@ Projection: s.country, ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST] --------TableScan: sales_global projection=[country, amount] physical_plan ProjectionExec: expr=[country@0 as country, ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST]@1 as amounts, SUM(s.amount)@2 as sum1] ---AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(s.amount), SUM(s.amount)], ordering_mode=FullyOrdered +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(s.amount), SUM(s.amount)], ordering_mode=Sorted ----SortExec: expr=[country@0 ASC NULLS LAST] ------MemoryExec: partitions=1, partition_sizes=[1] -query T?R +query T?R rowsort SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, SUM(s.amount) AS sum1 FROM (SELECT * @@ -2482,11 +2476,11 @@ Projection: s.country, ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST, --------TableScan: sales_global projection=[country, amount] physical_plan ProjectionExec: expr=[country@0 as country, ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST, s.amount DESC NULLS FIRST]@1 as amounts, SUM(s.amount)@2 as sum1] ---AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(s.amount), SUM(s.amount)], ordering_mode=FullyOrdered +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(s.amount), SUM(s.amount)], ordering_mode=Sorted ----SortExec: expr=[country@0 ASC NULLS LAST,amount@1 DESC] ------MemoryExec: partitions=1, partition_sizes=[1] -query T?R +query T?R rowsort SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.country DESC, s.amount DESC) AS amounts, SUM(s.amount) AS sum1 FROM (SELECT * @@ -2518,7 +2512,7 @@ ProjectionExec: expr=[country@0 as country, ARRAY_AGG(sales_global.amount) ORDER ----SortExec: expr=[amount@1 DESC] ------MemoryExec: partitions=1, partition_sizes=[1] -query T?RR +query T?RR rowsort SELECT country, ARRAY_AGG(amount ORDER BY amount DESC) AS amounts, FIRST_VALUE(amount ORDER BY amount ASC) AS fv1, LAST_VALUE(amount ORDER BY amount DESC) AS fv2 @@ -2526,8 +2520,8 @@ SELECT country, ARRAY_AGG(amount ORDER BY amount DESC) AS amounts, GROUP BY country ---- FRA [200.0, 50.0] 50 50 -TUR [100.0, 75.0] 75 75 GRC [80.0, 30.0] 30 30 +TUR [100.0, 75.0] 75 75 # test_reverse_aggregate_expr2 # Some of the Aggregators can be reversed, by this way we can still run aggregators without re-ordering @@ -2643,10 +2637,9 @@ Projection: sales_global.country, FIRST_VALUE(sales_global.amount) ORDER BY [sal physical_plan ProjectionExec: expr=[country@0 as country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@1 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@2 as lv1, SUM(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@3 as sum1] --AggregateExec: mode=Single, gby=[country@0 as country], aggr=[LAST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount), SUM(sales_global.amount)] -----SortExec: expr=[ts@1 ASC NULLS LAST] -------MemoryExec: partitions=1, partition_sizes=[1] +----MemoryExec: partitions=1, partition_sizes=[1] -query TRRR +query TRRR rowsort SELECT country, FIRST_VALUE(amount ORDER BY ts DESC) as fv1, LAST_VALUE(amount ORDER BY ts DESC) as lv1, SUM(amount ORDER BY ts DESC) as sum1 @@ -2655,8 +2648,8 @@ SELECT country, FIRST_VALUE(amount ORDER BY ts DESC) as fv1, ORDER BY ts ASC) GROUP BY country ---- -GRC 80 30 110 FRA 200 50 250 +GRC 80 30 110 TUR 100 75 175 # If existing ordering doesn't satisfy requirement, we should do calculations @@ -2677,19 +2670,18 @@ Projection: sales_global.country, FIRST_VALUE(sales_global.amount) ORDER BY [sal physical_plan ProjectionExec: expr=[country@0 as country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@1 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@2 as lv1, SUM(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@3 as sum1] --AggregateExec: mode=Single, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount), SUM(sales_global.amount)] -----SortExec: expr=[ts@1 DESC] -------MemoryExec: partitions=1, partition_sizes=[1] +----MemoryExec: partitions=1, partition_sizes=[1] -query TRRR +query TRRR rowsort SELECT country, FIRST_VALUE(amount ORDER BY ts DESC) as fv1, LAST_VALUE(amount ORDER BY ts DESC) as lv1, SUM(amount ORDER BY ts DESC) as sum1 FROM sales_global GROUP BY country ---- -TUR 100 75 175 -GRC 80 30 110 FRA 200 50 250 +GRC 80 30 110 +TUR 100 75 175 query TT EXPLAIN SELECT s.zip_code, s.country, s.sn, s.ts, s.currency, LAST_VALUE(e.amount ORDER BY e.sn) AS last_rate @@ -2714,14 +2706,13 @@ physical_plan SortExec: expr=[sn@2 ASC NULLS LAST] --ProjectionExec: expr=[zip_code@1 as zip_code, country@2 as country, sn@0 as sn, ts@3 as ts, currency@4 as currency, LAST_VALUE(e.amount) ORDER BY [e.sn ASC NULLS LAST]@5 as last_rate] ----AggregateExec: mode=Single, gby=[sn@2 as sn, zip_code@0 as zip_code, country@1 as country, ts@3 as ts, currency@4 as currency], aggr=[LAST_VALUE(e.amount)] -------SortExec: expr=[sn@5 ASC NULLS LAST] ---------ProjectionExec: expr=[zip_code@0 as zip_code, country@1 as country, sn@2 as sn, ts@3 as ts, currency@4 as currency, sn@5 as sn, amount@8 as amount] -----------CoalesceBatchesExec: target_batch_size=8192 -------------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(currency@4, currency@2)], filter=ts@0 >= ts@1 ---------------MemoryExec: partitions=1, partition_sizes=[1] ---------------MemoryExec: partitions=1, partition_sizes=[1] +------ProjectionExec: expr=[zip_code@4 as zip_code, country@5 as country, sn@6 as sn, ts@7 as ts, currency@8 as currency, sn@0 as sn, amount@3 as amount] +--------CoalesceBatchesExec: target_batch_size=8192 +----------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(currency@2, currency@4)], filter=ts@0 >= ts@1 +------------MemoryExec: partitions=1, partition_sizes=[1] +------------MemoryExec: partitions=1, partition_sizes=[1] -query ITIPTR +query ITIPTR rowsort SELECT s.zip_code, s.country, s.sn, s.ts, s.currency, LAST_VALUE(e.amount ORDER BY e.sn) AS last_rate FROM sales_global AS s JOIN sales_global AS e @@ -2731,10 +2722,10 @@ GROUP BY s.sn, s.zip_code, s.country, s.ts, s.currency ORDER BY s.sn ---- 0 GRC 0 2022-01-01T06:00:00 EUR 30 +0 GRC 4 2022-01-03T10:00:00 EUR 80 1 FRA 1 2022-01-01T08:00:00 EUR 50 -1 TUR 2 2022-01-01T11:30:00 TRY 75 1 FRA 3 2022-01-02T12:00:00 EUR 200 -0 GRC 4 2022-01-03T10:00:00 EUR 80 +1 TUR 2 2022-01-01T11:30:00 TRY 75 1 TUR 4 2022-01-03T10:00:00 TRY 100 # Run order-sensitive aggregators in multiple partitions @@ -2762,10 +2753,9 @@ SortPreservingMergeExec: [country@0 ASC NULLS LAST] ------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] --------CoalesceBatchesExec: target_batch_size=8192 ----------RepartitionExec: partitioning=Hash([country@0], 8), input_partitions=8 -------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] ---------------SortExec: expr=[ts@1 ASC NULLS LAST] -----------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 -------------------MemoryExec: partitions=1, partition_sizes=[1] +------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +--------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] +----------------MemoryExec: partitions=1, partition_sizes=[1] query TRR SELECT country, FIRST_VALUE(amount ORDER BY ts ASC) AS fv1, @@ -2796,13 +2786,12 @@ physical_plan SortPreservingMergeExec: [country@0 ASC NULLS LAST] --SortExec: expr=[country@0 ASC NULLS LAST] ----ProjectionExec: expr=[country@0 as country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@1 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@2 as fv2] -------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount)] +------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] --------CoalesceBatchesExec: target_batch_size=8192 ----------RepartitionExec: partitioning=Hash([country@0], 8), input_partitions=8 -------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount)] ---------------SortExec: expr=[ts@1 ASC NULLS LAST] -----------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 -------------------MemoryExec: partitions=1, partition_sizes=[1] +------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +--------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] +----------------MemoryExec: partitions=1, partition_sizes=[1] query TRR SELECT country, FIRST_VALUE(amount ORDER BY ts ASC) AS fv1, @@ -2815,6 +2804,11 @@ FRA 50 50 GRC 30 30 TUR 75 75 +# make sure that batch size is small. So that query below runs in multi partitions +# row number of the sales_global is 5. Hence we choose batch size 4 to make is smaller. +statement ok +set datafusion.execution.batch_size = 4; + # order-sensitive FIRST_VALUE and LAST_VALUE aggregators should work in # multi-partitions without group by also. query TT @@ -2831,16 +2825,15 @@ ProjectionExec: expr=[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts --AggregateExec: mode=Final, gby=[], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] ----CoalescePartitionsExec ------AggregateExec: mode=Partial, gby=[], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] ---------SortExec: expr=[ts@0 ASC NULLS LAST] -----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 -------------MemoryExec: partitions=1, partition_sizes=[1] +--------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +----------MemoryExec: partitions=1, partition_sizes=[1] query RR SELECT FIRST_VALUE(amount ORDER BY ts ASC) AS fv1, LAST_VALUE(amount ORDER BY ts ASC) AS fv2 FROM sales_global ---- -30 80 +30 100 # Conversion in between FIRST_VALUE and LAST_VALUE to resolve # contradictory requirements should work in multi partitions. @@ -2855,12 +2848,11 @@ Projection: FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS ----TableScan: sales_global projection=[ts, amount] physical_plan ProjectionExec: expr=[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@0 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@1 as fv2] ---AggregateExec: mode=Final, gby=[], aggr=[FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount)] +--AggregateExec: mode=Final, gby=[], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] ----CoalescePartitionsExec -------AggregateExec: mode=Partial, gby=[], aggr=[FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount)] ---------SortExec: expr=[ts@0 ASC NULLS LAST] -----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 -------------MemoryExec: partitions=1, partition_sizes=[1] +------AggregateExec: mode=Partial, gby=[], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] +--------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +----------MemoryExec: partitions=1, partition_sizes=[1] query RR SELECT FIRST_VALUE(amount ORDER BY ts ASC) AS fv1, @@ -2958,7 +2950,7 @@ SortPreservingMergeExec: [country@0 ASC NULLS LAST] --SortExec: expr=[country@0 ASC NULLS LAST] ----ProjectionExec: expr=[country@0 as country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@1 as array_agg1] ------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount)] ---------CoalesceBatchesExec: target_batch_size=8192 +--------CoalesceBatchesExec: target_batch_size=4 ----------RepartitionExec: partitioning=Hash([country@0], 8), input_partitions=8 ------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount)] --------------SortExec: expr=[amount@1 ASC NULLS LAST] @@ -2993,8 +2985,8 @@ physical_plan SortPreservingMergeExec: [country@0 ASC NULLS LAST] --SortExec: expr=[country@0 ASC NULLS LAST] ----ProjectionExec: expr=[country@0 as country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@1 as amounts, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@2 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@3 as fv2] -------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), LAST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] ---------CoalesceBatchesExec: target_batch_size=8192 +------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] +--------CoalesceBatchesExec: target_batch_size=4 ----------RepartitionExec: partitioning=Hash([country@0], 8), input_partitions=8 ------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), LAST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] --------------SortExec: expr=[amount@1 DESC] @@ -3093,6 +3085,55 @@ CREATE TABLE sales_global_with_pk_alternate (zip_code INT, (1, 'FRA', 3, '2022-01-02 12:00:00'::timestamp, 'EUR', 200.0), (1, 'TUR', 4, '2022-01-03 10:00:00'::timestamp, 'TRY', 100.0) +# we do not currently support foreign key constraints. +statement error DataFusion error: Error during planning: Foreign key constraints are not currently supported +CREATE TABLE sales_global_with_foreign_key (zip_code INT, + country VARCHAR(3), + sn INT references sales_global_with_pk_alternate(sn), + ts TIMESTAMP, + currency VARCHAR(3), + amount FLOAT +) as VALUES + (0, 'GRC', 0, '2022-01-01 06:00:00'::timestamp, 'EUR', 30.0), + (1, 'FRA', 1, '2022-01-01 08:00:00'::timestamp, 'EUR', 50.0), + (1, 'TUR', 2, '2022-01-01 11:30:00'::timestamp, 'TRY', 75.0), + (1, 'FRA', 3, '2022-01-02 12:00:00'::timestamp, 'EUR', 200.0), + (1, 'TUR', 4, '2022-01-03 10:00:00'::timestamp, 'TRY', 100.0) + +# we do not currently support foreign key +statement error DataFusion error: Error during planning: Foreign key constraints are not currently supported +CREATE TABLE sales_global_with_foreign_key (zip_code INT, + country VARCHAR(3), + sn INT REFERENCES sales_global_with_pk_alternate(sn), + ts TIMESTAMP, + currency VARCHAR(3), + amount FLOAT +) as VALUES + (0, 'GRC', 0, '2022-01-01 06:00:00'::timestamp, 'EUR', 30.0), + (1, 'FRA', 1, '2022-01-01 08:00:00'::timestamp, 'EUR', 50.0), + (1, 'TUR', 2, '2022-01-01 11:30:00'::timestamp, 'TRY', 75.0), + (1, 'FRA', 3, '2022-01-02 12:00:00'::timestamp, 'EUR', 200.0), + (1, 'TUR', 4, '2022-01-03 10:00:00'::timestamp, 'TRY', 100.0) + +# we do not currently support foreign key +# foreign key can be defined with a different syntax. +# we should get the same error. +statement error DataFusion error: Error during planning: Foreign key constraints are not currently supported +CREATE TABLE sales_global_with_foreign_key (zip_code INT, + country VARCHAR(3), + sn INT, + ts TIMESTAMP, + currency VARCHAR(3), + amount FLOAT, + FOREIGN KEY (sn) + REFERENCES sales_global_with_pk_alternate(sn) +) as VALUES + (0, 'GRC', 0, '2022-01-01 06:00:00'::timestamp, 'EUR', 30.0), + (1, 'FRA', 1, '2022-01-01 08:00:00'::timestamp, 'EUR', 50.0), + (1, 'TUR', 2, '2022-01-01 11:30:00'::timestamp, 'TRY', 75.0), + (1, 'FRA', 3, '2022-01-02 12:00:00'::timestamp, 'EUR', 200.0), + (1, 'TUR', 4, '2022-01-03 10:00:00'::timestamp, 'TRY', 100.0) + # create a table for testing, where primary key is composite statement ok CREATE TABLE sales_global_with_composite_pk (zip_code INT, @@ -3146,7 +3187,7 @@ SortPreservingMergeExec: [sn@0 ASC NULLS LAST] --SortExec: expr=[sn@0 ASC NULLS LAST] ----ProjectionExec: expr=[sn@0 as sn, amount@1 as amount, 2 * CAST(sn@0 AS Int64) as Int64(2) * s.sn] ------AggregateExec: mode=FinalPartitioned, gby=[sn@0 as sn, amount@1 as amount], aggr=[] ---------CoalesceBatchesExec: target_batch_size=8192 +--------CoalesceBatchesExec: target_batch_size=4 ----------RepartitionExec: partitioning=Hash([sn@0, amount@1], 8), input_partitions=8 ------------AggregateExec: mode=Partial, gby=[sn@0 as sn, amount@1 as amount], aggr=[] --------------MemoryExec: partitions=8, partition_sizes=[1, 0, 0, 0, 0, 0, 0, 0] @@ -3163,6 +3204,21 @@ SELECT s.sn, s.amount, 2*s.sn 3 200 6 4 100 8 +# we should be able to re-write group by expression +# using functional dependencies for complex expressions also. +# In this case, we use 2*s.amount instead of s.amount. +query IRI +SELECT s.sn, 2*s.amount, 2*s.sn + FROM sales_global_with_pk AS s + GROUP BY sn + ORDER BY sn +---- +0 60 0 +1 100 2 +2 150 4 +3 400 6 +4 200 8 + query IRI SELECT s.sn, s.amount, 2*s.sn FROM sales_global_with_pk_alternate AS s @@ -3199,7 +3255,7 @@ SortPreservingMergeExec: [sn@0 ASC NULLS LAST] --SortExec: expr=[sn@0 ASC NULLS LAST] ----ProjectionExec: expr=[sn@0 as sn, SUM(l.amount)@2 as SUM(l.amount), amount@1 as amount] ------AggregateExec: mode=FinalPartitioned, gby=[sn@0 as sn, amount@1 as amount], aggr=[SUM(l.amount)] ---------CoalesceBatchesExec: target_batch_size=8192 +--------CoalesceBatchesExec: target_batch_size=4 ----------RepartitionExec: partitioning=Hash([sn@0, amount@1], 8), input_partitions=8 ------------AggregateExec: mode=Partial, gby=[sn@1 as sn, amount@2 as amount], aggr=[SUM(l.amount)] --------------ProjectionExec: expr=[amount@1 as amount, sn@2 as sn, amount@3 as amount] @@ -3316,7 +3372,7 @@ SELECT column1, COUNT(*) as column2 FROM (VALUES (['a', 'b'], 1), (['c', 'd', 'e # primary key should be aware from which columns it is associated -statement error DataFusion error: Error during planning: Projection references non-aggregate values: Expression r.sn could not be resolved from available columns: l.sn, SUM\(l.amount\) +statement error DataFusion error: Error during planning: Projection references non-aggregate values: Expression r.sn could not be resolved from available columns: l.sn, l.zip_code, l.country, l.ts, l.currency, l.amount, SUM\(l.amount\) SELECT l.sn, r.sn, SUM(l.amount), r.amount FROM sales_global_with_pk AS l JOIN sales_global_with_pk AS r @@ -3347,7 +3403,7 @@ SortPreservingMergeExec: [sn@2 ASC NULLS LAST] --SortExec: expr=[sn@2 ASC NULLS LAST] ----ProjectionExec: expr=[zip_code@1 as zip_code, country@2 as country, sn@0 as sn, ts@3 as ts, currency@4 as currency, amount@5 as amount, sum_amount@6 as sum_amount] ------AggregateExec: mode=FinalPartitioned, gby=[sn@0 as sn, zip_code@1 as zip_code, country@2 as country, ts@3 as ts, currency@4 as currency, amount@5 as amount, sum_amount@6 as sum_amount], aggr=[] ---------CoalesceBatchesExec: target_batch_size=8192 +--------CoalesceBatchesExec: target_batch_size=4 ----------RepartitionExec: partitioning=Hash([sn@0, zip_code@1, country@2, ts@3, currency@4, amount@5, sum_amount@6], 8), input_partitions=8 ------------AggregateExec: mode=Partial, gby=[sn@2 as sn, zip_code@0 as zip_code, country@1 as country, ts@3 as ts, currency@4 as currency, amount@5 as amount, sum_amount@6 as sum_amount], aggr=[] --------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 @@ -3408,7 +3464,7 @@ ORDER BY r.sn 4 100 2022-01-03T10:00:00 # after join, new window expressions shouldn't be associated with primary keys -statement error DataFusion error: Error during planning: Projection references non-aggregate values: Expression rn1 could not be resolved from available columns: r.sn, SUM\(r.amount\) +statement error DataFusion error: Error during planning: Projection references non-aggregate values: Expression rn1 could not be resolved from available columns: r.sn, r.ts, r.amount, SUM\(r.amount\) SELECT r.sn, SUM(r.amount), rn1 FROM (SELECT r.ts, r.sn, r.amount, @@ -3517,6 +3573,12 @@ ORDER BY y; 2 1 3 1 +# Make sure to choose a batch size smaller than, row number of the table. +# In this case we choose 2 (Row number of the table is 3). +# otherwise we won't see parallelism in tests. +statement ok +set datafusion.execution.batch_size = 2; + # plan of the query above should contain partial # and final aggregation stages query TT @@ -3530,7 +3592,8 @@ physical_plan AggregateExec: mode=Final, gby=[], aggr=[LAST_VALUE(foo.x)] --CoalescePartitionsExec ----AggregateExec: mode=Partial, gby=[], aggr=[LAST_VALUE(foo.x)] -------MemoryExec: partitions=8, partition_sizes=[1, 0, 0, 0, 0, 0, 0, 0] +------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +--------MemoryExec: partitions=1, partition_sizes=[1] query I SELECT FIRST_VALUE(x) @@ -3551,8 +3614,53 @@ physical_plan AggregateExec: mode=Final, gby=[], aggr=[FIRST_VALUE(foo.x)] --CoalescePartitionsExec ----AggregateExec: mode=Partial, gby=[], aggr=[FIRST_VALUE(foo.x)] -------MemoryExec: partitions=8, partition_sizes=[1, 0, 0, 0, 0, 0, 0, 0] +------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +--------MemoryExec: partitions=1, partition_sizes=[1] +# Since both ordering requirements are satisfied, there shouldn't be +# any SortExec in the final plan. +query TT +EXPLAIN SELECT FIRST_VALUE(a ORDER BY a ASC) as first_a, + LAST_VALUE(c ORDER BY c DESC) as last_c +FROM multiple_ordered_table +GROUP BY d; +---- +logical_plan +Projection: FIRST_VALUE(multiple_ordered_table.a) ORDER BY [multiple_ordered_table.a ASC NULLS LAST] AS first_a, LAST_VALUE(multiple_ordered_table.c) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST] AS last_c +--Aggregate: groupBy=[[multiple_ordered_table.d]], aggr=[[FIRST_VALUE(multiple_ordered_table.a) ORDER BY [multiple_ordered_table.a ASC NULLS LAST], LAST_VALUE(multiple_ordered_table.c) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST]]] +----TableScan: multiple_ordered_table projection=[a, c, d] +physical_plan +ProjectionExec: expr=[FIRST_VALUE(multiple_ordered_table.a) ORDER BY [multiple_ordered_table.a ASC NULLS LAST]@1 as first_a, LAST_VALUE(multiple_ordered_table.c) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST]@2 as last_c] +--AggregateExec: mode=FinalPartitioned, gby=[d@0 as d], aggr=[FIRST_VALUE(multiple_ordered_table.a), FIRST_VALUE(multiple_ordered_table.c)] +----CoalesceBatchesExec: target_batch_size=2 +------RepartitionExec: partitioning=Hash([d@0], 8), input_partitions=8 +--------AggregateExec: mode=Partial, gby=[d@2 as d], aggr=[FIRST_VALUE(multiple_ordered_table.a), FIRST_VALUE(multiple_ordered_table.c)] +----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true + +query II rowsort +SELECT FIRST_VALUE(a ORDER BY a ASC) as first_a, + LAST_VALUE(c ORDER BY c DESC) as last_c +FROM multiple_ordered_table +GROUP BY d; +---- +0 0 +0 1 +0 15 +0 4 +0 9 + +query III rowsort +SELECT d, FIRST_VALUE(c ORDER BY a DESC, c DESC) as first_a, + LAST_VALUE(c ORDER BY c DESC) as last_c +FROM multiple_ordered_table +GROUP BY d; +---- +0 95 0 +1 90 4 +2 97 1 +3 99 15 +4 98 9 query TT EXPLAIN SELECT c @@ -3563,3 +3671,616 @@ logical_plan Sort: multiple_ordered_table.c ASC NULLS LAST --TableScan: multiple_ordered_table projection=[c] physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c], output_ordering=[c@0 ASC NULLS LAST], has_header=true + +statement ok +set datafusion.execution.target_partitions = 1; + +query TT +EXPLAIN SELECT LAST_VALUE(l.d ORDER BY l.a) AS amount_usd +FROM multiple_ordered_table AS l +INNER JOIN ( + SELECT *, ROW_NUMBER() OVER (ORDER BY r.a) as row_n FROM multiple_ordered_table AS r +) +ON l.d = r.d AND + l.a >= r.a - 10 +GROUP BY row_n +ORDER BY row_n +---- +logical_plan +Projection: amount_usd +--Sort: row_n ASC NULLS LAST +----Projection: LAST_VALUE(l.d) ORDER BY [l.a ASC NULLS LAST] AS amount_usd, row_n +------Aggregate: groupBy=[[row_n]], aggr=[[LAST_VALUE(l.d) ORDER BY [l.a ASC NULLS LAST]]] +--------Projection: l.a, l.d, row_n +----------Inner Join: l.d = r.d Filter: CAST(l.a AS Int64) >= CAST(r.a AS Int64) - Int64(10) +------------SubqueryAlias: l +--------------TableScan: multiple_ordered_table projection=[a, d] +------------Projection: r.a, r.d, ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS row_n +--------------WindowAggr: windowExpr=[[ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +----------------SubqueryAlias: r +------------------TableScan: multiple_ordered_table projection=[a, d] +physical_plan +ProjectionExec: expr=[LAST_VALUE(l.d) ORDER BY [l.a ASC NULLS LAST]@1 as amount_usd] +--AggregateExec: mode=Single, gby=[row_n@2 as row_n], aggr=[LAST_VALUE(l.d)], ordering_mode=Sorted +----ProjectionExec: expr=[a@0 as a, d@1 as d, row_n@4 as row_n] +------CoalesceBatchesExec: target_batch_size=2 +--------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(d@1, d@1)], filter=CAST(a@0 AS Int64) >= CAST(a@1 AS Int64) - 10 +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true +----------ProjectionExec: expr=[a@0 as a, d@1 as d, ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as row_n] +------------BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] +--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true + +# reset partition number to 8. +statement ok +set datafusion.execution.target_partitions = 8; + +# Create an external table with primary key +# column c +statement ok +CREATE EXTERNAL TABLE multiple_ordered_table_with_pk ( + a0 INTEGER, + a INTEGER, + b INTEGER, + c INTEGER, + d INTEGER, + primary key(c) +) +STORED AS CSV +WITH HEADER ROW +WITH ORDER (a ASC, b ASC) +WITH ORDER (c ASC) +LOCATION '../core/tests/data/window_2.csv'; + +# We can use column b during selection +# even if it is not among group by expressions +# because column c is primary key. +query TT +EXPLAIN SELECT c, b, SUM(d) +FROM multiple_ordered_table_with_pk +GROUP BY c; +---- +logical_plan +Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +--TableScan: multiple_ordered_table_with_pk projection=[b, c, d] +physical_plan +AggregateExec: mode=FinalPartitioned, gby=[c@0 as c, b@1 as b], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0]) +--SortExec: expr=[c@0 ASC NULLS LAST] +----CoalesceBatchesExec: target_batch_size=2 +------RepartitionExec: partitioning=Hash([c@0, b@1], 8), input_partitions=8 +--------AggregateExec: mode=Partial, gby=[c@1 as c, b@0 as b], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0]) +----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[b, c, d], output_ordering=[c@1 ASC NULLS LAST], has_header=true + +# drop table multiple_ordered_table_with_pk +statement ok +drop table multiple_ordered_table_with_pk; + +# Create an external table with primary key +# column c, in this case use alternative syntax +# for defining primary key +statement ok +CREATE EXTERNAL TABLE multiple_ordered_table_with_pk ( + a0 INTEGER, + a INTEGER, + b INTEGER, + c INTEGER primary key, + d INTEGER +) +STORED AS CSV +WITH HEADER ROW +WITH ORDER (a ASC, b ASC) +WITH ORDER (c ASC) +LOCATION '../core/tests/data/window_2.csv'; + +# We can use column b during selection +# even if it is not among group by expressions +# because column c is primary key. +query TT +EXPLAIN SELECT c, b, SUM(d) +FROM multiple_ordered_table_with_pk +GROUP BY c; +---- +logical_plan +Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +--TableScan: multiple_ordered_table_with_pk projection=[b, c, d] +physical_plan +AggregateExec: mode=FinalPartitioned, gby=[c@0 as c, b@1 as b], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0]) +--SortExec: expr=[c@0 ASC NULLS LAST] +----CoalesceBatchesExec: target_batch_size=2 +------RepartitionExec: partitioning=Hash([c@0, b@1], 8), input_partitions=8 +--------AggregateExec: mode=Partial, gby=[c@1 as c, b@0 as b], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0]) +----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[b, c, d], output_ordering=[c@1 ASC NULLS LAST], has_header=true + +statement ok +set datafusion.execution.target_partitions = 1; + +query TT +EXPLAIN SELECT c, sum1 + FROM + (SELECT c, b, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c) +GROUP BY c; +---- +logical_plan +Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, sum1]], aggr=[[]] +--Projection: multiple_ordered_table_with_pk.c, SUM(multiple_ordered_table_with_pk.d) AS sum1 +----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +------TableScan: multiple_ordered_table_with_pk projection=[c, d] +physical_plan +AggregateExec: mode=Single, gby=[c@0 as c, sum1@1 as sum1], aggr=[], ordering_mode=PartiallySorted([0]) +--ProjectionExec: expr=[c@0 as c, SUM(multiple_ordered_table_with_pk.d)@1 as sum1] +----AggregateExec: mode=Single, gby=[c@0 as c], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c, d], output_ordering=[c@0 ASC NULLS LAST], has_header=true + +query TT +EXPLAIN SELECT c, sum1, SUM(b) OVER() as sumb + FROM + (SELECT c, b, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c); +---- +logical_plan +Projection: multiple_ordered_table_with_pk.c, sum1, SUM(multiple_ordered_table_with_pk.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS sumb +--WindowAggr: windowExpr=[[SUM(CAST(multiple_ordered_table_with_pk.b AS Int64)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +----Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b, SUM(multiple_ordered_table_with_pk.d) AS sum1 +------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +--------TableScan: multiple_ordered_table_with_pk projection=[b, c, d] +physical_plan +ProjectionExec: expr=[c@0 as c, sum1@2 as sum1, SUM(multiple_ordered_table_with_pk.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@3 as sumb] +--WindowAggExec: wdw=[SUM(multiple_ordered_table_with_pk.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(multiple_ordered_table_with_pk.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)) }] +----ProjectionExec: expr=[c@0 as c, b@1 as b, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] +------AggregateExec: mode=Single, gby=[c@1 as c, b@0 as b], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0]) +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[b, c, d], output_ordering=[c@1 ASC NULLS LAST], has_header=true + +query TT +EXPLAIN SELECT lhs.c, rhs.c, lhs.sum1, rhs.sum1 + FROM + (SELECT c, b, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c) as lhs + JOIN + (SELECT c, b, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c) as rhs + ON lhs.b=rhs.b; +---- +logical_plan +Projection: lhs.c, rhs.c, lhs.sum1, rhs.sum1 +--Inner Join: lhs.b = rhs.b +----SubqueryAlias: lhs +------Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b, SUM(multiple_ordered_table_with_pk.d) AS sum1 +--------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +----------TableScan: multiple_ordered_table_with_pk projection=[b, c, d] +----SubqueryAlias: rhs +------Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b, SUM(multiple_ordered_table_with_pk.d) AS sum1 +--------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +----------TableScan: multiple_ordered_table_with_pk projection=[b, c, d] +physical_plan +ProjectionExec: expr=[c@0 as c, c@3 as c, sum1@2 as sum1, sum1@5 as sum1] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(b@1, b@1)] +------ProjectionExec: expr=[c@0 as c, b@1 as b, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] +--------AggregateExec: mode=Single, gby=[c@1 as c, b@0 as b], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0]) +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[b, c, d], output_ordering=[c@1 ASC NULLS LAST], has_header=true +------ProjectionExec: expr=[c@0 as c, b@1 as b, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] +--------AggregateExec: mode=Single, gby=[c@1 as c, b@0 as b], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0]) +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[b, c, d], output_ordering=[c@1 ASC NULLS LAST], has_header=true + +query TT +EXPLAIN SELECT lhs.c, rhs.c, lhs.sum1, rhs.sum1 + FROM + (SELECT c, b, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c) as lhs + CROSS JOIN + (SELECT c, b, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c) as rhs; +---- +logical_plan +Projection: lhs.c, rhs.c, lhs.sum1, rhs.sum1 +--CrossJoin: +----SubqueryAlias: lhs +------Projection: multiple_ordered_table_with_pk.c, SUM(multiple_ordered_table_with_pk.d) AS sum1 +--------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +----------TableScan: multiple_ordered_table_with_pk projection=[c, d] +----SubqueryAlias: rhs +------Projection: multiple_ordered_table_with_pk.c, SUM(multiple_ordered_table_with_pk.d) AS sum1 +--------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +----------TableScan: multiple_ordered_table_with_pk projection=[c, d] +physical_plan +ProjectionExec: expr=[c@0 as c, c@2 as c, sum1@1 as sum1, sum1@3 as sum1] +--CrossJoinExec +----ProjectionExec: expr=[c@0 as c, SUM(multiple_ordered_table_with_pk.d)@1 as sum1] +------AggregateExec: mode=Single, gby=[c@0 as c], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c, d], output_ordering=[c@0 ASC NULLS LAST], has_header=true +----ProjectionExec: expr=[c@0 as c, SUM(multiple_ordered_table_with_pk.d)@1 as sum1] +------AggregateExec: mode=Single, gby=[c@0 as c], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c, d], output_ordering=[c@0 ASC NULLS LAST], has_header=true + +# we do not generate physical plan for Repartition yet (e.g Distribute By queries). +query TT +EXPLAIN SELECT a, b, sum1 +FROM (SELECT c, b, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c) +DISTRIBUTE BY a +---- +logical_plan +Repartition: DistributeBy(a) +--Projection: multiple_ordered_table_with_pk.a, multiple_ordered_table_with_pk.b, SUM(multiple_ordered_table_with_pk.d) AS sum1 +----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, multiple_ordered_table_with_pk.b]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +------TableScan: multiple_ordered_table_with_pk projection=[a, b, c, d] + +# union with aggregate +query TT +EXPLAIN SELECT c, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c +UNION ALL + SELECT c, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c +---- +logical_plan +Union +--Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, SUM(multiple_ordered_table_with_pk.d) AS sum1 +----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +------TableScan: multiple_ordered_table_with_pk projection=[a, c, d] +--Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, SUM(multiple_ordered_table_with_pk.d) AS sum1 +----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +------TableScan: multiple_ordered_table_with_pk projection=[a, c, d] +physical_plan +UnionExec +--ProjectionExec: expr=[c@0 as c, a@1 as a, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] +----AggregateExec: mode=Single, gby=[c@1 as c, a@0 as a], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true +--ProjectionExec: expr=[c@0 as c, a@1 as a, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] +----AggregateExec: mode=Single, gby=[c@1 as c, a@0 as a], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true + +# table scan should be simplified. +query TT +EXPLAIN SELECT c, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c +---- +logical_plan +Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, SUM(multiple_ordered_table_with_pk.d) AS sum1 +--Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +----TableScan: multiple_ordered_table_with_pk projection=[a, c, d] +physical_plan +ProjectionExec: expr=[c@0 as c, a@1 as a, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] +--AggregateExec: mode=Single, gby=[c@1 as c, a@0 as a], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true + +# limit should be simplified +query TT +EXPLAIN SELECT * + FROM (SELECT c, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c + LIMIT 5) +---- +logical_plan +Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, SUM(multiple_ordered_table_with_pk.d) AS sum1 +--Limit: skip=0, fetch=5 +----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +------TableScan: multiple_ordered_table_with_pk projection=[a, c, d] +physical_plan +ProjectionExec: expr=[c@0 as c, a@1 as a, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] +--GlobalLimitExec: skip=0, fetch=5 +----AggregateExec: mode=Single, gby=[c@1 as c, a@0 as a], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true + +statement ok +set datafusion.execution.target_partitions = 8; + +# Tests for single distinct to group by optimization rule +statement ok +CREATE TABLE t(x int) AS VALUES (1), (2), (1); + +statement ok +create table t1(x bigint,y int) as values (9223372036854775807,2), (9223372036854775806,2); + +query II +SELECT SUM(DISTINCT x), MAX(DISTINCT x) from t GROUP BY x ORDER BY x; +---- +1 1 +2 2 + +query II +SELECT MAX(DISTINCT x), SUM(DISTINCT x) from t GROUP BY x ORDER BY x; +---- +1 1 +2 2 + +query TT +EXPLAIN SELECT SUM(DISTINCT CAST(x AS DOUBLE)), MAX(DISTINCT x) FROM t1 GROUP BY y; +---- +logical_plan +Projection: SUM(DISTINCT t1.x), MAX(DISTINCT t1.x) +--Aggregate: groupBy=[[t1.y]], aggr=[[SUM(DISTINCT CAST(t1.x AS Float64)), MAX(DISTINCT t1.x)]] +----TableScan: t1 projection=[x, y] +physical_plan +ProjectionExec: expr=[SUM(DISTINCT t1.x)@1 as SUM(DISTINCT t1.x), MAX(DISTINCT t1.x)@2 as MAX(DISTINCT t1.x)] +--AggregateExec: mode=FinalPartitioned, gby=[y@0 as y], aggr=[SUM(DISTINCT t1.x), MAX(DISTINCT t1.x)] +----CoalesceBatchesExec: target_batch_size=2 +------RepartitionExec: partitioning=Hash([y@0], 8), input_partitions=8 +--------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +----------AggregateExec: mode=Partial, gby=[y@1 as y], aggr=[SUM(DISTINCT t1.x), MAX(DISTINCT t1.x)] +------------MemoryExec: partitions=1, partition_sizes=[1] + +query TT +EXPLAIN SELECT SUM(DISTINCT CAST(x AS DOUBLE)), MAX(DISTINCT CAST(x AS DOUBLE)) FROM t1 GROUP BY y; +---- +logical_plan +Projection: SUM(alias1) AS SUM(DISTINCT t1.x), MAX(alias1) AS MAX(DISTINCT t1.x) +--Aggregate: groupBy=[[t1.y]], aggr=[[SUM(alias1), MAX(alias1)]] +----Aggregate: groupBy=[[t1.y, CAST(t1.x AS Float64)t1.x AS t1.x AS alias1]], aggr=[[]] +------Projection: CAST(t1.x AS Float64) AS CAST(t1.x AS Float64)t1.x, t1.y +--------TableScan: t1 projection=[x, y] +physical_plan +ProjectionExec: expr=[SUM(alias1)@1 as SUM(DISTINCT t1.x), MAX(alias1)@2 as MAX(DISTINCT t1.x)] +--AggregateExec: mode=FinalPartitioned, gby=[y@0 as y], aggr=[SUM(alias1), MAX(alias1)] +----CoalesceBatchesExec: target_batch_size=2 +------RepartitionExec: partitioning=Hash([y@0], 8), input_partitions=8 +--------AggregateExec: mode=Partial, gby=[y@0 as y], aggr=[SUM(alias1), MAX(alias1)] +----------AggregateExec: mode=FinalPartitioned, gby=[y@0 as y, alias1@1 as alias1], aggr=[] +------------CoalesceBatchesExec: target_batch_size=2 +--------------RepartitionExec: partitioning=Hash([y@0, alias1@1], 8), input_partitions=8 +----------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +------------------AggregateExec: mode=Partial, gby=[y@1 as y, CAST(t1.x AS Float64)t1.x@0 as alias1], aggr=[] +--------------------ProjectionExec: expr=[CAST(x@0 AS Float64) as CAST(t1.x AS Float64)t1.x, y@1 as y] +----------------------MemoryExec: partitions=1, partition_sizes=[1] + +# create an unbounded table that contains ordered timestamp. +statement ok +CREATE UNBOUNDED EXTERNAL TABLE unbounded_csv_with_timestamps ( + name VARCHAR, + ts TIMESTAMP +) +STORED AS CSV +WITH ORDER (ts DESC) +LOCATION '../core/tests/data/timestamps.csv' + +# below query should work in streaming mode. +query TT +EXPLAIN SELECT date_bin('15 minutes', ts) as time_chunks + FROM unbounded_csv_with_timestamps + GROUP BY date_bin('15 minutes', ts) + ORDER BY time_chunks DESC + LIMIT 5; +---- +logical_plan +Limit: skip=0, fetch=5 +--Sort: time_chunks DESC NULLS FIRST, fetch=5 +----Projection: date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts) AS time_chunks +------Aggregate: groupBy=[[date_bin(IntervalMonthDayNano("900000000000"), unbounded_csv_with_timestamps.ts) AS date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)]], aggr=[[]] +--------TableScan: unbounded_csv_with_timestamps projection=[ts] +physical_plan +GlobalLimitExec: skip=0, fetch=5 +--SortPreservingMergeExec: [time_chunks@0 DESC], fetch=5 +----ProjectionExec: expr=[date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)@0 as time_chunks] +------AggregateExec: mode=FinalPartitioned, gby=[date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)@0 as date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)], aggr=[], ordering_mode=Sorted +--------CoalesceBatchesExec: target_batch_size=2 +----------RepartitionExec: partitioning=Hash([date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)@0], 8), input_partitions=8, preserve_order=true, sort_exprs=date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)@0 DESC +------------AggregateExec: mode=Partial, gby=[date_bin(900000000000, ts@0) as date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)], aggr=[], ordering_mode=Sorted +--------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +----------------StreamingTableExec: partition_sizes=1, projection=[ts], infinite_source=true, output_ordering=[ts@0 DESC] + +query P +SELECT date_bin('15 minutes', ts) as time_chunks + FROM unbounded_csv_with_timestamps + GROUP BY date_bin('15 minutes', ts) + ORDER BY time_chunks DESC + LIMIT 5; +---- +2018-12-13T12:00:00 +2018-11-13T17:00:00 + +# Since extract is not a monotonic function, below query should not run. +# when source is unbounded. +query error +SELECT extract(month from ts) as months + FROM unbounded_csv_with_timestamps + GROUP BY extract(month from ts) + ORDER BY months DESC + LIMIT 5; + +# Create a table where timestamp is ordered +statement ok +CREATE EXTERNAL TABLE csv_with_timestamps ( + name VARCHAR, + ts TIMESTAMP +) +STORED AS CSV +WITH ORDER (ts DESC) +LOCATION '../core/tests/data/timestamps.csv'; + +# below query should run since it operates on a bounded source and have a sort +# at the top of its plan. +query TT +EXPLAIN SELECT extract(month from ts) as months + FROM csv_with_timestamps + GROUP BY extract(month from ts) + ORDER BY months DESC + LIMIT 5; +---- +logical_plan +Limit: skip=0, fetch=5 +--Sort: months DESC NULLS FIRST, fetch=5 +----Projection: date_part(Utf8("MONTH"),csv_with_timestamps.ts) AS months +------Aggregate: groupBy=[[date_part(Utf8("MONTH"), csv_with_timestamps.ts)]], aggr=[[]] +--------TableScan: csv_with_timestamps projection=[ts] +physical_plan +GlobalLimitExec: skip=0, fetch=5 +--SortPreservingMergeExec: [months@0 DESC], fetch=5 +----SortExec: TopK(fetch=5), expr=[months@0 DESC] +------ProjectionExec: expr=[date_part(Utf8("MONTH"),csv_with_timestamps.ts)@0 as months] +--------AggregateExec: mode=FinalPartitioned, gby=[date_part(Utf8("MONTH"),csv_with_timestamps.ts)@0 as date_part(Utf8("MONTH"),csv_with_timestamps.ts)], aggr=[] +----------CoalesceBatchesExec: target_batch_size=2 +------------RepartitionExec: partitioning=Hash([date_part(Utf8("MONTH"),csv_with_timestamps.ts)@0], 8), input_partitions=8 +--------------AggregateExec: mode=Partial, gby=[date_part(MONTH, ts@0) as date_part(Utf8("MONTH"),csv_with_timestamps.ts)], aggr=[] +----------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/timestamps.csv]]}, projection=[ts], output_ordering=[ts@0 DESC], has_header=false + +query R +SELECT extract(month from ts) as months + FROM csv_with_timestamps + GROUP BY extract(month from ts) + ORDER BY months DESC + LIMIT 5; +---- +12 +11 + +statement ok +drop table t1 + +# Reproducer for https://github.com/apache/arrow-datafusion/issues/8175 + +statement ok +create table t1(state string, city string, min_temp float, area int, time timestamp) as values + ('MA', 'Boston', 70.4, 1, 50), + ('MA', 'Bedford', 71.59, 2, 150); + +query RI +select date_part('year', time) as bla, count(distinct state) as count from t1 group by bla; +---- +1970 1 + +query PI +select date_bin(interval '1 year', time) as bla, count(distinct state) as count from t1 group by bla; +---- +1970-01-01T00:00:00 1 + +statement ok +drop table t1 + +statement ok +CREATE EXTERNAL TABLE aggregate_test_100 ( + c1 VARCHAR NOT NULL, + c2 TINYINT NOT NULL, + c3 SMALLINT NOT NULL, + c4 SMALLINT, + c5 INT, + c6 BIGINT NOT NULL, + c7 SMALLINT NOT NULL, + c8 INT NOT NULL, + c9 INT UNSIGNED NOT NULL, + c10 BIGINT UNSIGNED NOT NULL, + c11 FLOAT NOT NULL, + c12 DOUBLE NOT NULL, + c13 VARCHAR NOT NULL +) +STORED AS CSV +WITH HEADER ROW +LOCATION '../../testing/data/csv/aggregate_test_100.csv' + +query TIIII +SELECT c1, count(distinct c2), min(distinct c2), min(c3), max(c4) FROM aggregate_test_100 GROUP BY c1 ORDER BY c1; +---- +a 5 1 -101 32064 +b 5 1 -117 25286 +c 5 1 -117 29106 +d 5 1 -99 31106 +e 5 1 -95 32514 + +query TT +EXPLAIN SELECT c1, count(distinct c2), min(distinct c2), sum(c3), max(c4) FROM aggregate_test_100 GROUP BY c1 ORDER BY c1; +---- +logical_plan +Sort: aggregate_test_100.c1 ASC NULLS LAST +--Projection: aggregate_test_100.c1, COUNT(alias1) AS COUNT(DISTINCT aggregate_test_100.c2), MIN(alias1) AS MIN(DISTINCT aggregate_test_100.c2), SUM(alias2) AS SUM(aggregate_test_100.c3), MAX(alias3) AS MAX(aggregate_test_100.c4) +----Aggregate: groupBy=[[aggregate_test_100.c1]], aggr=[[COUNT(alias1), MIN(alias1), SUM(alias2), MAX(alias3)]] +------Aggregate: groupBy=[[aggregate_test_100.c1, aggregate_test_100.c2 AS alias1]], aggr=[[SUM(CAST(aggregate_test_100.c3 AS Int64)) AS alias2, MAX(aggregate_test_100.c4) AS alias3]] +--------TableScan: aggregate_test_100 projection=[c1, c2, c3, c4] +physical_plan +SortPreservingMergeExec: [c1@0 ASC NULLS LAST] +--SortExec: expr=[c1@0 ASC NULLS LAST] +----ProjectionExec: expr=[c1@0 as c1, COUNT(alias1)@1 as COUNT(DISTINCT aggregate_test_100.c2), MIN(alias1)@2 as MIN(DISTINCT aggregate_test_100.c2), SUM(alias2)@3 as SUM(aggregate_test_100.c3), MAX(alias3)@4 as MAX(aggregate_test_100.c4)] +------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[COUNT(alias1), MIN(alias1), SUM(alias2), MAX(alias3)] +--------CoalesceBatchesExec: target_batch_size=2 +----------RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8 +------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[COUNT(alias1), MIN(alias1), SUM(alias2), MAX(alias3)] +--------------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1, alias1@1 as alias1], aggr=[alias2, alias3] +----------------CoalesceBatchesExec: target_batch_size=2 +------------------RepartitionExec: partitioning=Hash([c1@0, alias1@1], 8), input_partitions=8 +--------------------AggregateExec: mode=Partial, gby=[c1@0 as c1, c2@1 as alias1], aggr=[alias2, alias3] +----------------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3, c4], has_header=true + +# Use PostgreSQL dialect +statement ok +set datafusion.sql_parser.dialect = 'Postgres'; + +query II +SELECT c2, count(distinct c3) FILTER (WHERE c1 != 'a') FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; +---- +1 17 +2 17 +3 13 +4 19 +5 11 + +query III +SELECT c2, count(distinct c3) FILTER (WHERE c1 != 'a'), count(c5) FILTER (WHERE c1 != 'b') FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; +---- +1 17 19 +2 17 18 +3 13 17 +4 19 18 +5 11 9 + +# Restore the default dialect +statement ok +set datafusion.sql_parser.dialect = 'Generic'; + +statement ok +drop table aggregate_test_100; + + +# Create an unbounded external table with primary key +# column c +statement ok +CREATE EXTERNAL TABLE unbounded_multiple_ordered_table_with_pk ( + a0 INTEGER, + a INTEGER, + b INTEGER, + c INTEGER primary key, + d INTEGER +) +STORED AS CSV +WITH HEADER ROW +WITH ORDER (a ASC, b ASC) +WITH ORDER (c ASC) +LOCATION '../core/tests/data/window_2.csv'; + +# Query below can be executed, since c is primary key. +query III rowsort +SELECT c, a, SUM(d) +FROM unbounded_multiple_ordered_table_with_pk +GROUP BY c +ORDER BY c +LIMIT 5 +---- +0 0 0 +1 0 2 +2 0 0 +3 0 0 +4 0 1 + + +query ITIPTR rowsort +SELECT r.* +FROM sales_global_with_pk as l, sales_global_with_pk as r +LIMIT 5 +---- +0 GRC 0 2022-01-01T06:00:00 EUR 30 +1 FRA 1 2022-01-01T08:00:00 EUR 50 +1 FRA 3 2022-01-02T12:00:00 EUR 200 +1 TUR 2 2022-01-01T11:30:00 TRY 75 +1 TUR 4 2022-01-03T10:00:00 TRY 100 diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index f90901021637..1b5ad86546a3 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -150,13 +150,16 @@ datafusion.execution.aggregate.scalar_update_factor 10 datafusion.execution.batch_size 8192 datafusion.execution.coalesce_batches true datafusion.execution.collect_statistics false +datafusion.execution.listing_table_ignore_subdirectory true +datafusion.execution.max_buffered_batches_per_output_file 2 datafusion.execution.meta_fetch_concurrency 32 -datafusion.execution.parquet.allow_single_file_parallelism false +datafusion.execution.minimum_parallel_output_files 4 +datafusion.execution.parquet.allow_single_file_parallelism true datafusion.execution.parquet.bloom_filter_enabled false datafusion.execution.parquet.bloom_filter_fpp NULL datafusion.execution.parquet.bloom_filter_ndv NULL datafusion.execution.parquet.column_index_truncate_length NULL -datafusion.execution.parquet.compression NULL +datafusion.execution.parquet.compression zstd(3) datafusion.execution.parquet.created_by datafusion datafusion.execution.parquet.data_page_row_count_limit 18446744073709551615 datafusion.execution.parquet.data_pagesize_limit 1048576 @@ -166,6 +169,8 @@ datafusion.execution.parquet.enable_page_index true datafusion.execution.parquet.encoding NULL datafusion.execution.parquet.max_row_group_size 1048576 datafusion.execution.parquet.max_statistics_size NULL +datafusion.execution.parquet.maximum_buffered_record_batches_per_stream 2 +datafusion.execution.parquet.maximum_parallel_row_group_writers 1 datafusion.execution.parquet.metadata_size_hint NULL datafusion.execution.parquet.pruning true datafusion.execution.parquet.pushdown_filters false @@ -175,6 +180,7 @@ datafusion.execution.parquet.statistics_enabled NULL datafusion.execution.parquet.write_batch_size 1024 datafusion.execution.parquet.writer_version 1.0 datafusion.execution.planning_concurrency 13 +datafusion.execution.soft_max_rows_per_output_file 50000000 datafusion.execution.sort_in_place_threshold_bytes 1048576 datafusion.execution.sort_spill_reservation_bytes 10485760 datafusion.execution.target_partitions 7 @@ -183,12 +189,14 @@ datafusion.explain.logical_plan_only false datafusion.explain.physical_plan_only false datafusion.explain.show_statistics false datafusion.optimizer.allow_symmetric_joins_without_pruning true -datafusion.optimizer.bounded_order_preserving_variants false +datafusion.optimizer.default_filter_selectivity 20 +datafusion.optimizer.enable_distinct_aggregation_soft_limit true datafusion.optimizer.enable_round_robin_repartition true datafusion.optimizer.enable_topk_aggregation true datafusion.optimizer.filter_null_join_keys false datafusion.optimizer.hash_join_single_partition_threshold 1048576 datafusion.optimizer.max_passes 3 +datafusion.optimizer.prefer_existing_sort false datafusion.optimizer.prefer_hash_join true datafusion.optimizer.repartition_aggregations true datafusion.optimizer.repartition_file_min_size 10485760 @@ -202,12 +210,93 @@ datafusion.sql_parser.dialect generic datafusion.sql_parser.enable_ident_normalization true datafusion.sql_parser.parse_float_as_decimal false +# show all variables with verbose +query TTT rowsort +SHOW ALL VERBOSE +---- +datafusion.catalog.create_default_catalog_and_schema true Whether the default catalog and schema should be created automatically. +datafusion.catalog.default_catalog datafusion The default catalog name - this impacts what SQL queries use if not specified +datafusion.catalog.default_schema public The default schema name - this impacts what SQL queries use if not specified +datafusion.catalog.format NULL Type of `TableProvider` to use when loading `default` schema +datafusion.catalog.has_header false If the file has a header +datafusion.catalog.information_schema true Should DataFusion provide access to `information_schema` virtual tables for displaying schema information +datafusion.catalog.location NULL Location scanned to load tables for `default` schema +datafusion.execution.aggregate.scalar_update_factor 10 Specifies the threshold for using `ScalarValue`s to update accumulators during high-cardinality aggregations for each input batch. The aggregation is considered high-cardinality if the number of affected groups is greater than or equal to `batch_size / scalar_update_factor`. In such cases, `ScalarValue`s are utilized for updating accumulators, rather than the default batch-slice approach. This can lead to performance improvements. By adjusting the `scalar_update_factor`, you can balance the trade-off between more efficient accumulator updates and the number of groups affected. +datafusion.execution.batch_size 8192 Default batch size while creating new batches, it's especially useful for buffer-in-memory batches since creating tiny batches would result in too much metadata memory consumption +datafusion.execution.coalesce_batches true When set to true, record batches will be examined between each operator and small batches will be coalesced into larger batches. This is helpful when there are highly selective filters or joins that could produce tiny output batches. The target batch size is determined by the configuration setting +datafusion.execution.collect_statistics false Should DataFusion collect statistics after listing files +datafusion.execution.listing_table_ignore_subdirectory true Should sub directories be ignored when scanning directories for data files. Defaults to true (ignores subdirectories), consistent with Hive. Note that this setting does not affect reading partitioned tables (e.g. `/table/year=2021/month=01/data.parquet`). +datafusion.execution.max_buffered_batches_per_output_file 2 This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption +datafusion.execution.meta_fetch_concurrency 32 Number of files to read in parallel when inferring schema and statistics +datafusion.execution.minimum_parallel_output_files 4 Guarantees a minimum level of output files running in parallel. RecordBatches will be distributed in round robin fashion to each parallel writer. Each writer is closed and a new file opened once soft_max_rows_per_output_file is reached. +datafusion.execution.parquet.allow_single_file_parallelism true Controls whether DataFusion will attempt to speed up writing parquet files by serializing them in parallel. Each column in each row group in each output file are serialized in parallel leveraging a maximum possible core count of n_files*n_row_groups*n_columns. +datafusion.execution.parquet.bloom_filter_enabled false Sets if bloom filter is enabled for any column +datafusion.execution.parquet.bloom_filter_fpp NULL Sets bloom filter false positive probability. If NULL, uses default parquet writer setting +datafusion.execution.parquet.bloom_filter_ndv NULL Sets bloom filter number of distinct values. If NULL, uses default parquet writer setting +datafusion.execution.parquet.column_index_truncate_length NULL Sets column index truncate length +datafusion.execution.parquet.compression zstd(3) Sets default parquet compression codec Valid values are: uncompressed, snappy, gzip(level), lzo, brotli(level), lz4, zstd(level), and lz4_raw. These values are not case sensitive. If NULL, uses default parquet writer setting +datafusion.execution.parquet.created_by datafusion Sets "created by" property +datafusion.execution.parquet.data_page_row_count_limit 18446744073709551615 Sets best effort maximum number of rows in data page +datafusion.execution.parquet.data_pagesize_limit 1048576 Sets best effort maximum size of data page in bytes +datafusion.execution.parquet.dictionary_enabled NULL Sets if dictionary encoding is enabled. If NULL, uses default parquet writer setting +datafusion.execution.parquet.dictionary_page_size_limit 1048576 Sets best effort maximum dictionary page size, in bytes +datafusion.execution.parquet.enable_page_index true If true, reads the Parquet data page level metadata (the Page Index), if present, to reduce the I/O and number of rows decoded. +datafusion.execution.parquet.encoding NULL Sets default encoding for any column Valid values are: plain, plain_dictionary, rle, bit_packed, delta_binary_packed, delta_length_byte_array, delta_byte_array, rle_dictionary, and byte_stream_split. These values are not case sensitive. If NULL, uses default parquet writer setting +datafusion.execution.parquet.max_row_group_size 1048576 Sets maximum number of rows in a row group +datafusion.execution.parquet.max_statistics_size NULL Sets max statistics size for any column. If NULL, uses default parquet writer setting +datafusion.execution.parquet.maximum_buffered_record_batches_per_stream 2 By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. +datafusion.execution.parquet.maximum_parallel_row_group_writers 1 By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. +datafusion.execution.parquet.metadata_size_hint NULL If specified, the parquet reader will try and fetch the last `size_hint` bytes of the parquet file optimistically. If not specified, two reads are required: One read to fetch the 8-byte parquet footer and another to fetch the metadata length encoded in the footer +datafusion.execution.parquet.pruning true If true, the parquet reader attempts to skip entire row groups based on the predicate in the query and the metadata (min/max values) stored in the parquet file +datafusion.execution.parquet.pushdown_filters false If true, filter expressions are be applied during the parquet decoding operation to reduce the number of rows decoded +datafusion.execution.parquet.reorder_filters false If true, filter expressions evaluated during the parquet decoding operation will be reordered heuristically to minimize the cost of evaluation. If false, the filters are applied in the same order as written in the query +datafusion.execution.parquet.skip_metadata true If true, the parquet reader skip the optional embedded metadata that may be in the file Schema. This setting can help avoid schema conflicts when querying multiple parquet files with schemas containing compatible types but different metadata +datafusion.execution.parquet.statistics_enabled NULL Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting +datafusion.execution.parquet.write_batch_size 1024 Sets write_batch_size in bytes +datafusion.execution.parquet.writer_version 1.0 Sets parquet writer version valid values are "1.0" and "2.0" +datafusion.execution.planning_concurrency 13 Fan-out during initial physical planning. This is mostly use to plan `UNION` children in parallel. Defaults to the number of CPU cores on the system +datafusion.execution.soft_max_rows_per_output_file 50000000 Target number of rows in output files when writing multiple. This is a soft max, so it can be exceeded slightly. There also will be one file smaller than the limit if the total number of rows written is not roughly divisible by the soft max +datafusion.execution.sort_in_place_threshold_bytes 1048576 When sorting, below what size should data be concatenated and sorted in a single RecordBatch rather than sorted in batches and merged. +datafusion.execution.sort_spill_reservation_bytes 10485760 Specifies the reserved memory for each spillable sort operation to facilitate an in-memory merge. When a sort operation spills to disk, the in-memory data must be sorted and merged before being written to a file. This setting reserves a specific amount of memory for that in-memory sort/merge process. Note: This setting is irrelevant if the sort operation cannot spill (i.e., if there's no `DiskManager` configured). +datafusion.execution.target_partitions 7 Number of partitions for query execution. Increasing partitions can increase concurrency. Defaults to the number of CPU cores on the system +datafusion.execution.time_zone +00:00 The default time zone Some functions, e.g. `EXTRACT(HOUR from SOME_TIME)`, shift the underlying datetime according to this time zone, and then extract the hour +datafusion.explain.logical_plan_only false When set to true, the explain statement will only print logical plans +datafusion.explain.physical_plan_only false When set to true, the explain statement will only print physical plans +datafusion.explain.show_statistics false When set to true, the explain statement will print operator statistics for physical plans +datafusion.optimizer.allow_symmetric_joins_without_pruning true Should DataFusion allow symmetric hash joins for unbounded data sources even when its inputs do not have any ordering or filtering If the flag is not enabled, the SymmetricHashJoin operator will be unable to prune its internal buffers, resulting in certain join types - such as Full, Left, LeftAnti, LeftSemi, Right, RightAnti, and RightSemi - being produced only at the end of the execution. This is not typical in stream processing. Additionally, without proper design for long runner execution, all types of joins may encounter out-of-memory errors. +datafusion.optimizer.default_filter_selectivity 20 The default filter selectivity used by Filter Statistics when an exact selectivity cannot be determined. Valid values are between 0 (no selectivity) and 100 (all rows are selected). +datafusion.optimizer.enable_distinct_aggregation_soft_limit true When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. +datafusion.optimizer.enable_round_robin_repartition true When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores +datafusion.optimizer.enable_topk_aggregation true When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible +datafusion.optimizer.filter_null_join_keys false When set to true, the optimizer will insert filters before a join between a nullable and non-nullable column to filter out nulls on the nullable side. This filter can add additional overhead when the file format does not fully support predicate push down. +datafusion.optimizer.hash_join_single_partition_threshold 1048576 The maximum estimated size in bytes for one input side of a HashJoin will be collected into a single partition +datafusion.optimizer.max_passes 3 Number of times that the optimizer will attempt to optimize the plan +datafusion.optimizer.prefer_existing_sort false When true, DataFusion will opportunistically remove sorts when the data is already sorted, (i.e. setting `preserve_order` to true on `RepartitionExec` and using `SortPreservingMergeExec`) When false, DataFusion will maximize plan parallelism using `RepartitionExec` even if this requires subsequently resorting data using a `SortExec`. +datafusion.optimizer.prefer_hash_join true When set to true, the physical plan optimizer will prefer HashJoin over SortMergeJoin. HashJoin can work more efficiently than SortMergeJoin but consumes more memory +datafusion.optimizer.repartition_aggregations true Should DataFusion repartition data using the aggregate keys to execute aggregates in parallel using the provided `target_partitions` level +datafusion.optimizer.repartition_file_min_size 10485760 Minimum total files size in bytes to perform file scan repartitioning. +datafusion.optimizer.repartition_file_scans true When set to `true`, file groups will be repartitioned to achieve maximum parallelism. Currently Parquet and CSV formats are supported. If set to `true`, all files will be repartitioned evenly (i.e., a single large file might be partitioned into smaller chunks) for parallel scanning. If set to `false`, different files will be read in parallel, but repartitioning won't happen within a single file. +datafusion.optimizer.repartition_joins true Should DataFusion repartition data using the join keys to execute joins in parallel using the provided `target_partitions` level +datafusion.optimizer.repartition_sorts true Should DataFusion execute sorts in a per-partition fashion and merge afterwards instead of coalescing first and sorting globally. With this flag is enabled, plans in the form below ```text "SortExec: [a@0 ASC]", " CoalescePartitionsExec", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", ``` would turn into the plan below which performs better in multithreaded environments ```text "SortPreservingMergeExec: [a@0 ASC]", " SortExec: [a@0 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", ``` +datafusion.optimizer.repartition_windows true Should DataFusion repartition data using the partitions keys to execute window functions in parallel using the provided `target_partitions` level +datafusion.optimizer.skip_failed_rules false When set to true, the logical plan optimizer will produce warning messages if any optimization rules produce errors and then proceed to the next rule. When set to false, any rules that produce errors will cause the query to fail +datafusion.optimizer.top_down_join_key_reordering true When set to true, the physical plan optimizer will run a top down process to reorder the join keys +datafusion.sql_parser.dialect generic Configure the SQL dialect used by DataFusion's parser; supported values include: Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, and Ansi. +datafusion.sql_parser.enable_ident_normalization true When set to true, SQL parser will normalize ident (convert ident to lowercase when not quoted) +datafusion.sql_parser.parse_float_as_decimal false When set to true, SQL parser will parse float as decimal type + # show_variable_in_config_options query TT SHOW datafusion.execution.batch_size ---- datafusion.execution.batch_size 8192 +# show_variable_in_config_options_verbose +query TTT +SHOW datafusion.execution.batch_size VERBOSE +---- +datafusion.execution.batch_size 8192 Default batch size while creating new batches, it's especially useful for buffer-in-memory batches since creating tiny batches would result in too much metadata memory consumption + # show_time_zone_default_utc # https://github.com/apache/arrow-datafusion/issues/3255 query TT @@ -223,6 +312,26 @@ SHOW TIMEZONE datafusion.execution.time_zone +00:00 +# show_time_zone_default_utc_verbose +# https://github.com/apache/arrow-datafusion/issues/3255 +query TTT +SHOW TIME ZONE VERBOSE +---- +datafusion.execution.time_zone +00:00 The default time zone Some functions, e.g. `EXTRACT(HOUR from SOME_TIME)`, shift the underlying datetime according to this time zone, and then extract the hour + +# show_timezone_default_utc +# https://github.com/apache/arrow-datafusion/issues/3255 +query TTT +SHOW TIMEZONE VERBOSE +---- +datafusion.execution.time_zone +00:00 The default time zone Some functions, e.g. `EXTRACT(HOUR from SOME_TIME)`, shift the underlying datetime according to this time zone, and then extract the hour + + +# show empty verbose +query TTT +SHOW VERBOSE +---- + # information_schema_describe_table ## some_table @@ -372,6 +481,9 @@ set datafusion.catalog.information_schema = false; statement error Error during planning: SHOW \[VARIABLE\] is not supported unless information_schema is enabled SHOW SOMETHING +statement error Error during planning: SHOW \[VARIABLE\] is not supported unless information_schema is enabled +SHOW SOMETHING VERBOSE + statement ok set datafusion.catalog.information_schema = true; diff --git a/datafusion/sqllogictest/test_files/insert.slt b/datafusion/sqllogictest/test_files/insert.slt index 74968bb089d7..e20b3779459b 100644 --- a/datafusion/sqllogictest/test_files/insert.slt +++ b/datafusion/sqllogictest/test_files/insert.slt @@ -64,7 +64,7 @@ Dml: op=[Insert Into] table=[table_without_values] --------WindowAggr: windowExpr=[[SUM(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] ----------TableScan: aggregate_test_100 projection=[c1, c4, c9] physical_plan -InsertExec: sink=MemoryTable (partitions=1) +FileSinkExec: sink=MemoryTable (partitions=1) --ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@0 as field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@1 as field2] ----SortPreservingMergeExec: [c1@2 ASC NULLS LAST] ------ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, c1@0 as c1] @@ -125,14 +125,15 @@ Dml: op=[Insert Into] table=[table_without_values] ----WindowAggr: windowExpr=[[SUM(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] ------TableScan: aggregate_test_100 projection=[c1, c4, c9] physical_plan -InsertExec: sink=MemoryTable (partitions=1) ---ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as field2] -----BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }], mode=[Sorted] -------SortExec: expr=[c1@0 ASC NULLS LAST,c9@2 ASC NULLS LAST] ---------CoalesceBatchesExec: target_batch_size=8192 -----------RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8 -------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 ---------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c4, c9], has_header=true +FileSinkExec: sink=MemoryTable (partitions=1) +--CoalescePartitionsExec +----ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as field2] +------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }], mode=[Sorted] +--------SortExec: expr=[c1@0 ASC NULLS LAST,c9@2 ASC NULLS LAST] +----------CoalesceBatchesExec: target_batch_size=8192 +------------RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8 +--------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c4, c9], has_header=true @@ -174,7 +175,7 @@ Dml: op=[Insert Into] table=[table_without_values] --------WindowAggr: windowExpr=[[SUM(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] ----------TableScan: aggregate_test_100 projection=[c1, c4, c9] physical_plan -InsertExec: sink=MemoryTable (partitions=8) +FileSinkExec: sink=MemoryTable (partitions=8) --ProjectionExec: expr=[a1@0 as a1, a2@1 as a2] ----SortPreservingMergeExec: [c1@2 ASC NULLS LAST] ------ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as a1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as a2, c1@0 as c1] @@ -216,10 +217,9 @@ Dml: op=[Insert Into] table=[table_without_values] ----Sort: aggregate_test_100.c1 ASC NULLS LAST ------TableScan: aggregate_test_100 projection=[c1] physical_plan -InsertExec: sink=MemoryTable (partitions=1) ---ProjectionExec: expr=[c1@0 as c1] -----SortExec: expr=[c1@0 ASC NULLS LAST] -------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1], has_header=true +FileSinkExec: sink=MemoryTable (partitions=1) +--SortExec: expr=[c1@0 ASC NULLS LAST] +----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1], has_header=true query T insert into table_without_values select c1 from aggregate_test_100 order by c1; @@ -258,14 +258,18 @@ insert into table_without_values(name, id) values(4, 'zoo'); statement error Error during planning: Column count doesn't match insert query! insert into table_without_values(id) values(4, 'zoo'); -statement error Error during planning: Inserting query must have the same schema with the table. +# insert NULL values for the missing column (name) +query IT insert into table_without_values(id) values(4); +---- +1 query IT rowsort select * from table_without_values; ---- 1 foo 2 bar +4 NULL statement ok drop table table_without_values; @@ -285,6 +289,16 @@ insert into table_without_values values(2, NULL); ---- 1 +# insert NULL values for the missing column (field2) +query II +insert into table_without_values(field1) values(3); +---- +1 + +# insert NULL values for the missing column (field1), but column is non-nullable +statement error Execution error: Invalid batch column at '0' has null but schema specifies non-nullable +insert into table_without_values(field2) values(300); + statement error Execution error: Invalid batch column at '0' has null but schema specifies non-nullable insert into table_without_values values(NULL, 300); @@ -296,6 +310,126 @@ select * from table_without_values; ---- 1 100 2 NULL +3 NULL statement ok drop table table_without_values; + + +### Test for creating tables into directories that do not already exist +# note use of `scratch` directory (which is cleared between runs) + +statement ok +create external table new_empty_table(x int) stored as parquet location 'test_files/scratch/insert/new_empty_table/'; -- needs trailing slash + +# should start empty +query I +select * from new_empty_table; +---- + +# should succeed and the table should create the direectory +statement ok +insert into new_empty_table values (1); + +# Now has values +query I +select * from new_empty_table; +---- +1 + +statement ok +drop table new_empty_table; + +## test we get an error if the path doesn't end in slash +statement ok +create external table bad_new_empty_table(x int) stored as parquet location 'test_files/scratch/insert/bad_new_empty_table'; -- no trailing slash + +# should fail +query error DataFusion error: Error during planning: Inserting into a ListingTable backed by a single file is not supported, URL is possibly missing a trailing `/`\. To append to an existing file use StreamTable, e\.g\. by using CREATE UNBOUNDED EXTERNAL TABLE +insert into bad_new_empty_table values (1); + +statement ok +drop table bad_new_empty_table; + + +### Test for specifying column's default value + +statement ok +create table test_column_defaults( + a int, + b int not null default null, + c int default 100*2+300, + d text default lower('DEFAULT_TEXT'), + e timestamp default now() +) + +query IIITP +insert into test_column_defaults values(1, 10, 100, 'ABC', now()) +---- +1 + +statement error DataFusion error: Execution error: Invalid batch column at '1' has null but schema specifies non-nullable +insert into test_column_defaults(a) values(2) + +query IIITP +insert into test_column_defaults(b) values(20) +---- +1 + +query IIIT rowsort +select a,b,c,d from test_column_defaults +---- +1 10 100 ABC +NULL 20 500 default_text + +# fill the timestamp column with default value `now()` again, it should be different from the previous one +query IIITP +insert into test_column_defaults(a, b, c, d) values(2, 20, 200, 'DEF') +---- +1 + +# Ensure that the default expression `now()` is evaluated during insertion, not optimized away. +# Rows are inserted during different time, so their timestamp values should be different. +query I rowsort +select count(distinct e) from test_column_defaults +---- +3 + +# Expect all rows to be true as now() was inserted into the table +query B rowsort +select e < now() from test_column_defaults +---- +true +true +true + +statement ok +drop table test_column_defaults + + +# test create table as +statement ok +create table test_column_defaults( + a int, + b int not null default null, + c int default 100*2+300, + d text default lower('DEFAULT_TEXT'), + e timestamp default now() +) as values(1, 10, 100, 'ABC', now()) + +query IIITP +insert into test_column_defaults(b) values(20) +---- +1 + +query IIIT rowsort +select a,b,c,d from test_column_defaults +---- +1 10 100 ABC +NULL 20 500 default_text + +statement ok +drop table test_column_defaults + +statement error DataFusion error: Error during planning: Column reference is not allowed in the DEFAULT expression : Schema error: No field named a. +create table test_column_defaults(a int, b int default a+1) diff --git a/datafusion/sqllogictest/test_files/insert_to_external.slt b/datafusion/sqllogictest/test_files/insert_to_external.slt index a29c230a466e..e73778ad44e5 100644 --- a/datafusion/sqllogictest/test_files/insert_to_external.slt +++ b/datafusion/sqllogictest/test_files/insert_to_external.slt @@ -40,13 +40,267 @@ STORED AS CSV WITH HEADER ROW LOCATION '../../testing/data/csv/aggregate_test_100.csv' -# test_insert_into +statement ok +create table dictionary_encoded_values as values +('a', arrow_cast('foo', 'Dictionary(Int32, Utf8)')), ('b', arrow_cast('bar', 'Dictionary(Int32, Utf8)')); + +query TTT +describe dictionary_encoded_values; +---- +column1 Utf8 YES +column2 Dictionary(Int32, Utf8) YES + +statement ok +CREATE EXTERNAL TABLE dictionary_encoded_parquet_partitioned( + a varchar, + b varchar, +) +STORED AS parquet +LOCATION 'test_files/scratch/insert_to_external/parquet_types_partitioned/' +PARTITIONED BY (b) +OPTIONS( +create_local_path 'true', +insert_mode 'append_new_files', +); + +query TT +insert into dictionary_encoded_parquet_partitioned +select * from dictionary_encoded_values +---- +2 + +query TT +select * from dictionary_encoded_parquet_partitioned order by (a); +---- +a foo +b bar + +statement ok +CREATE EXTERNAL TABLE dictionary_encoded_arrow_partitioned( + a varchar, + b varchar, +) +STORED AS arrow +LOCATION 'test_files/scratch/insert_to_external/arrow_dict_partitioned/' +PARTITIONED BY (b) +OPTIONS( +create_local_path 'true', +insert_mode 'append_new_files', +); + +query TT +insert into dictionary_encoded_arrow_partitioned +select * from dictionary_encoded_values +---- +2 + +statement ok +CREATE EXTERNAL TABLE dictionary_encoded_arrow_test_readback( + a varchar, +) +STORED AS arrow +LOCATION 'test_files/scratch/insert_to_external/arrow_dict_partitioned/b=bar/' +OPTIONS( +create_local_path 'true', +insert_mode 'append_new_files', +); + +query T +select * from dictionary_encoded_arrow_test_readback; +---- +b + +# https://github.com/apache/arrow-datafusion/issues/7816 +query error DataFusion error: Arrow error: Schema error: project index 1 out of bounds, max field 1 +select * from dictionary_encoded_arrow_partitioned order by (a); + + +# test_insert_into statement ok set datafusion.execution.target_partitions = 8; statement ok CREATE EXTERNAL TABLE +ordered_insert_test(a bigint, b bigint) +STORED AS csv +LOCATION 'test_files/scratch/insert_to_external/insert_to_ordered/' +WITH ORDER (a ASC, B DESC) +OPTIONS( +create_local_path 'true', +insert_mode 'append_new_files', +); + +query TT +EXPLAIN INSERT INTO ordered_insert_test values (5, 1), (4, 2), (7,7), (7,8), (7,9), (7,10), (3, 3), (2, 4), (1, 5); +---- +logical_plan +Dml: op=[Insert Into] table=[ordered_insert_test] +--Projection: column1 AS a, column2 AS b +----Values: (Int64(5), Int64(1)), (Int64(4), Int64(2)), (Int64(7), Int64(7)), (Int64(7), Int64(8)), (Int64(7), Int64(9))... +physical_plan +FileSinkExec: sink=CsvSink(file_groups=[]) +--SortExec: expr=[a@0 ASC NULLS LAST,b@1 DESC] +----ProjectionExec: expr=[column1@0 as a, column2@1 as b] +------ValuesExec + +query II +INSERT INTO ordered_insert_test values (5, 1), (4, 2), (7,7), (7,8), (7,9), (7,10), (3, 3), (2, 4), (1, 5); +---- +9 + +query II +SELECT * from ordered_insert_test; +---- +1 5 +2 4 +3 3 +4 2 +5 1 +7 10 +7 9 +7 8 +7 7 + +# test partitioned insert + +statement ok +CREATE EXTERNAL TABLE +partitioned_insert_test(a string, b string, c bigint) +STORED AS csv +LOCATION 'test_files/scratch/insert_to_external/insert_to_partitioned/' +PARTITIONED BY (a, b) +OPTIONS( +create_local_path 'true', +insert_mode 'append_new_files', +); + +#note that partitioned cols are moved to the end so value tuples are (c, a, b) +query ITT +INSERT INTO partitioned_insert_test values (1, 10, 100), (1, 10, 200), (1, 20, 100), (1, 20, 200), (2, 20, 100), (2, 20, 200); +---- +6 + +query ITT +select * from partitioned_insert_test order by a,b,c +---- +1 10 100 +1 10 200 +1 20 100 +2 20 100 +1 20 200 +2 20 200 + +statement ok +CREATE EXTERNAL TABLE +partitioned_insert_test_verify(c bigint) +STORED AS csv +LOCATION 'test_files/scratch/insert_to_external/insert_to_partitioned/a=20/b=100/' +OPTIONS( +insert_mode 'append_new_files', +); + +query I +select * from partitioned_insert_test_verify; +---- +1 +2 + +statement ok +CREATE EXTERNAL TABLE +partitioned_insert_test_json(a string, b string) +STORED AS json +LOCATION 'test_files/scratch/insert_to_external/insert_to_partitioned_json/' +PARTITIONED BY (a) +OPTIONS( +create_local_path 'true', +insert_mode 'append_new_files', +); + +query TT +INSERT INTO partitioned_insert_test_json values (1, 2), (3, 4), (5, 6), (1, 2), (3, 4), (5, 6); +---- +6 + +# Issue open for this error: https://github.com/apache/arrow-datafusion/issues/7816 +query error DataFusion error: Arrow error: Json error: Encountered unmasked nulls in non\-nullable StructArray child: Field \{ name: "a", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: \{\} \} +select * from partitioned_insert_test_json order by a,b + +statement ok +CREATE EXTERNAL TABLE +partitioned_insert_test_verify_json(b string) +STORED AS json +LOCATION 'test_files/scratch/insert_to_external/insert_to_partitioned_json/a=2/' +OPTIONS( +insert_mode 'append_new_files', +); + +query T +select * from partitioned_insert_test_verify_json; +---- +1 +1 + +statement ok +CREATE EXTERNAL TABLE +partitioned_insert_test_pq(a string, b bigint) +STORED AS parquet +LOCATION 'test_files/scratch/insert_to_external/insert_to_partitioned_pq/' +PARTITIONED BY (a) +OPTIONS( +create_local_path 'true', +insert_mode 'append_new_files', +); + +query IT +INSERT INTO partitioned_insert_test_pq values (1, 2), (3, 4), (5, 6), (1, 2), (3, 4), (5, 6); +---- +6 + +query IT +select * from partitioned_insert_test_pq order by a ASC, b ASC +---- +1 2 +1 2 +3 4 +3 4 +5 6 +5 6 + +statement ok +CREATE EXTERNAL TABLE +partitioned_insert_test_verify_pq(b bigint) +STORED AS parquet +LOCATION 'test_files/scratch/insert_to_external/insert_to_partitioned_pq/a=2/' +OPTIONS( +insert_mode 'append_new_files', +); + +query I +select * from partitioned_insert_test_verify_pq; +---- +1 +1 + + +statement ok +CREATE EXTERNAL TABLE +single_file_test(a bigint, b bigint) +STORED AS csv +LOCATION 'test_files/scratch/insert_to_external/single_csv_table.csv' +OPTIONS( +create_local_path 'true', +single_file 'true', +); + +query error DataFusion error: Error during planning: Inserting into a ListingTable backed by a single file is not supported, URL is possibly missing a trailing `/`\. To append to an existing file use StreamTable, e\.g\. by using CREATE UNBOUNDED EXTERNAL TABLE +INSERT INTO single_file_test values (1, 2), (3, 4); + +statement ok +drop table single_file_test; + +statement ok +CREATE UNBOUNDED EXTERNAL TABLE single_file_test(a bigint, b bigint) STORED AS csv LOCATION 'test_files/scratch/insert_to_external/single_csv_table.csv' @@ -60,17 +314,24 @@ INSERT INTO single_file_test values (1, 2), (3, 4); ---- 2 +query II +INSERT INTO single_file_test values (4, 5), (6, 7); +---- +2 + query II select * from single_file_test; ---- 1 2 3 4 +4 5 +6 7 statement ok CREATE EXTERNAL TABLE directory_test(a bigint, b bigint) STORED AS parquet -LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q0' +LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q0/' OPTIONS( create_local_path 'true', ); @@ -90,7 +351,7 @@ statement ok CREATE EXTERNAL TABLE table_without_values(field1 BIGINT NULL, field2 BIGINT NULL) STORED AS parquet -LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q1' +LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q1/' OPTIONS (create_local_path 'true'); query TT @@ -109,7 +370,7 @@ Dml: op=[Insert Into] table=[table_without_values] --------WindowAggr: windowExpr=[[SUM(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] ----------TableScan: aggregate_test_100 projection=[c1, c4, c9] physical_plan -InsertExec: sink=ParquetSink(writer_mode=PutMultipart, file_groups=[]) +FileSinkExec: sink=ParquetSink(file_groups=[]) --ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@0 as field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@1 as field2] ----SortPreservingMergeExec: [c1@2 ASC NULLS LAST] ------ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, c1@0 as c1] @@ -156,7 +417,7 @@ statement ok CREATE EXTERNAL TABLE table_without_values(field1 BIGINT NULL, field2 BIGINT NULL) STORED AS parquet -LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q2' +LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q2/' OPTIONS (create_local_path 'true'); query TT @@ -172,14 +433,15 @@ Dml: op=[Insert Into] table=[table_without_values] ----WindowAggr: windowExpr=[[SUM(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] ------TableScan: aggregate_test_100 projection=[c1, c4, c9] physical_plan -InsertExec: sink=ParquetSink(writer_mode=PutMultipart, file_groups=[]) ---ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as field2] -----BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }], mode=[Sorted] -------SortExec: expr=[c1@0 ASC NULLS LAST,c9@2 ASC NULLS LAST] ---------CoalesceBatchesExec: target_batch_size=8192 -----------RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8 -------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 ---------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c4, c9], has_header=true +FileSinkExec: sink=ParquetSink(file_groups=[]) +--CoalescePartitionsExec +----ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as field2] +------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }], mode=[Sorted] +--------SortExec: expr=[c1@0 ASC NULLS LAST,c9@2 ASC NULLS LAST] +----------CoalesceBatchesExec: target_batch_size=8192 +------------RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8 +--------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c4, c9], has_header=true @@ -200,7 +462,7 @@ statement ok CREATE EXTERNAL TABLE table_without_values(c1 varchar NULL) STORED AS parquet -LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q3' +LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q3/' OPTIONS (create_local_path 'true'); # verify that the sort order of the insert query is maintained into the @@ -215,10 +477,9 @@ Dml: op=[Insert Into] table=[table_without_values] ----Sort: aggregate_test_100.c1 ASC NULLS LAST ------TableScan: aggregate_test_100 projection=[c1] physical_plan -InsertExec: sink=ParquetSink(writer_mode=PutMultipart, file_groups=[]) ---ProjectionExec: expr=[c1@0 as c1] -----SortExec: expr=[c1@0 ASC NULLS LAST] -------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1], has_header=true +FileSinkExec: sink=ParquetSink(file_groups=[]) +--SortExec: expr=[c1@0 ASC NULLS LAST] +----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1], has_header=true query T insert into table_without_values select c1 from aggregate_test_100 order by c1; @@ -240,7 +501,7 @@ statement ok CREATE EXTERNAL TABLE table_without_values(id BIGINT, name varchar) STORED AS parquet -LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q4' +LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q4/' OPTIONS (create_local_path 'true'); query IT @@ -262,14 +523,18 @@ insert into table_without_values(name, id) values(4, 'zoo'); statement error Error during planning: Column count doesn't match insert query! insert into table_without_values(id) values(4, 'zoo'); -statement error Error during planning: Inserting query must have the same schema with the table. +# insert NULL values for the missing column (name) +query IT insert into table_without_values(id) values(4); +---- +1 query IT rowsort select * from table_without_values; ---- 1 foo 2 bar +4 NULL statement ok drop table table_without_values; @@ -279,7 +544,7 @@ statement ok CREATE EXTERNAL TABLE table_without_values(field1 BIGINT NOT NULL, field2 BIGINT NULL) STORED AS parquet -LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q5' +LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q5/' OPTIONS (create_local_path 'true'); query II @@ -292,6 +557,16 @@ insert into table_without_values values(2, NULL); ---- 1 +# insert NULL values for the missing column (field2) +query II +insert into table_without_values(field1) values(3); +---- +1 + +# insert NULL values for the missing column (field1), but column is non-nullable +statement error Execution error: Invalid batch column at '0' has null but schema specifies non-nullable +insert into table_without_values(field2) values(300); + statement error Execution error: Invalid batch column at '0' has null but schema specifies non-nullable insert into table_without_values values(NULL, 300); @@ -303,6 +578,74 @@ select * from table_without_values; ---- 1 100 2 NULL +3 NULL statement ok drop table table_without_values; + + +### Test for specifying column's default value + +statement ok +CREATE EXTERNAL TABLE test_column_defaults( + a int, + b int not null default null, + c int default 100*2+300, + d text default lower('DEFAULT_TEXT'), + e timestamp default now() +) STORED AS parquet +LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q6/' +OPTIONS (create_local_path 'true'); + +# fill in all column values +query IIITP +insert into test_column_defaults values(1, 10, 100, 'ABC', now()) +---- +1 + +statement error DataFusion error: Execution error: Invalid batch column at '1' has null but schema specifies non-nullable +insert into test_column_defaults(a) values(2) + +query IIITP +insert into test_column_defaults(b) values(20) +---- +1 + +query IIIT rowsort +select a,b,c,d from test_column_defaults +---- +1 10 100 ABC +NULL 20 500 default_text + +# fill the timestamp column with default value `now()` again, it should be different from the previous one +query IIITP +insert into test_column_defaults(a, b, c, d) values(2, 20, 200, 'DEF') +---- +1 + +# Ensure that the default expression `now()` is evaluated during insertion, not optimized away. +# Rows are inserted during different time, so their timestamp values should be different. +query I rowsort +select count(distinct e) from test_column_defaults +---- +3 + +# Expect all rows to be true as now() was inserted into the table +query B rowsort +select e < now() from test_column_defaults +---- +true +true +true + +statement ok +drop table test_column_defaults + +# test invalid default value +statement error DataFusion error: Error during planning: Column reference is not allowed in the DEFAULT expression : Schema error: No field named a. +CREATE EXTERNAL TABLE test_column_defaults( + a int, + b int default a+1 +) STORED AS parquet +LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q7/' +OPTIONS (create_local_path 'true'); diff --git a/datafusion/sqllogictest/test_files/interval.slt b/datafusion/sqllogictest/test_files/interval.slt index 500876f76221..f2ae2984f07b 100644 --- a/datafusion/sqllogictest/test_files/interval.slt +++ b/datafusion/sqllogictest/test_files/interval.slt @@ -126,6 +126,86 @@ select interval '5' nanoseconds ---- 0 years 0 mons 0 days 0 hours 0 mins 0.000000005 secs +query ? +select interval '5 YEAR' +---- +0 years 60 mons 0 days 0 hours 0 mins 0.000000000 secs + +query ? +select interval '5 MONTH' +---- +0 years 5 mons 0 days 0 hours 0 mins 0.000000000 secs + +query ? +select interval '5 WEEK' +---- +0 years 0 mons 35 days 0 hours 0 mins 0.000000000 secs + +query ? +select interval '5 DAY' +---- +0 years 0 mons 5 days 0 hours 0 mins 0.000000000 secs + +query ? +select interval '5 HOUR' +---- +0 years 0 mons 0 days 5 hours 0 mins 0.000000000 secs + +query ? +select interval '5 HOURS' +---- +0 years 0 mons 0 days 5 hours 0 mins 0.000000000 secs + +query ? +select interval '5 MINUTE' +---- +0 years 0 mons 0 days 0 hours 5 mins 0.000000000 secs + +query ? +select interval '5 SECOND' +---- +0 years 0 mons 0 days 0 hours 0 mins 5.000000000 secs + +query ? +select interval '5 SECONDS' +---- +0 years 0 mons 0 days 0 hours 0 mins 5.000000000 secs + +query ? +select interval '5 MILLISECOND' +---- +0 years 0 mons 0 days 0 hours 0 mins 0.005000000 secs + +query ? +select interval '5 MILLISECONDS' +---- +0 years 0 mons 0 days 0 hours 0 mins 0.005000000 secs + +query ? +select interval '5 MICROSECOND' +---- +0 years 0 mons 0 days 0 hours 0 mins 0.000005000 secs + +query ? +select interval '5 MICROSECONDS' +---- +0 years 0 mons 0 days 0 hours 0 mins 0.000005000 secs + +query ? +select interval '5 NANOSECOND' +---- +0 years 0 mons 0 days 0 hours 0 mins 0.000000005 secs + +query ? +select interval '5 NANOSECONDS' +---- +0 years 0 mons 0 days 0 hours 0 mins 0.000000005 secs + +query ? +select interval '5 YEAR 5 MONTH 5 DAY 5 HOUR 5 MINUTE 5 SECOND 5 MILLISECOND 5 MICROSECOND 5 NANOSECOND' +---- +0 years 65 mons 5 days 5 hours 5 mins 5.005005005 secs + # Interval with string literal addition query ? select interval '1 month' + '1 month' diff --git a/datafusion/sqllogictest/test_files/join.slt b/datafusion/sqllogictest/test_files/join.slt index 283ff57a984c..ca9b918ff3ee 100644 --- a/datafusion/sqllogictest/test_files/join.slt +++ b/datafusion/sqllogictest/test_files/join.slt @@ -556,7 +556,11 @@ query TT explain select * from t1 join t2 on false; ---- logical_plan EmptyRelation -physical_plan EmptyExec: produce_one_row=false +physical_plan EmptyExec + +# Make batch size smaller than table row number. to introduce parallelism to the plan. +statement ok +set datafusion.execution.batch_size = 1; # test covert inner join to cross join when condition is true query TT @@ -568,9 +572,9 @@ CrossJoin: --TableScan: t2 projection=[t2_id, t2_name, t2_int] physical_plan CrossJoinExec ---CoalescePartitionsExec -----MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] ---MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +--MemoryExec: partitions=1, partition_sizes=[1] +--RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----MemoryExec: partitions=1, partition_sizes=[1] statement ok drop table IF EXISTS t1; @@ -590,3 +594,72 @@ drop table IF EXISTS full_join_test; # batch size statement ok set datafusion.execution.batch_size = 8192; + +# related to: https://github.com/apache/arrow-datafusion/issues/8374 +statement ok +CREATE TABLE t1(a text, b int) AS VALUES ('Alice', 50), ('Alice', 100); + +statement ok +CREATE TABLE t2(a text, b int) AS VALUES ('Alice', 2), ('Alice', 1); + +# the current query results are incorrect, becuase the query was incorrectly rewritten as: +# SELECT t1.a, t1.b FROM t1 JOIN t2 ON t1.a = t2.a ORDER BY t1.a, t1.b; +# the difference is ORDER BY clause rewrite from t2.b to t1.b, it is incorrect. +# after https://github.com/apache/arrow-datafusion/issues/8374 fixed, the correct result should be: +# Alice 50 +# Alice 100 +# Alice 50 +# Alice 100 +query TI +SELECT t1.a, t1.b FROM t1 JOIN t2 ON t1.a = t2.a ORDER BY t1.a, t2.b; +---- +Alice 50 +Alice 50 +Alice 100 +Alice 100 + +query TITI +SELECT t1.a, t1.b, t2.a, t2.b FROM t1 JOIN t2 ON t1.a = t2.a ORDER BY t1.a, t2.b; +---- +Alice 50 Alice 1 +Alice 100 Alice 1 +Alice 50 Alice 2 +Alice 100 Alice 2 + +statement ok +set datafusion.execution.target_partitions = 1; + +statement ok +set datafusion.optimizer.repartition_joins = true; + +# make sure when target partition is 1, hash repartition is not added +# to the final plan. +query TT +EXPLAIN SELECT * +FROM t1, +t1 as t2 +WHERE t1.a=t2.a; +---- +logical_plan +Inner Join: t1.a = t2.a +--TableScan: t1 projection=[a, b] +--SubqueryAlias: t2 +----TableScan: t1 projection=[a, b] +physical_plan +CoalesceBatchesExec: target_batch_size=8192 +--HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0)] +----MemoryExec: partitions=1, partition_sizes=[1] +----MemoryExec: partitions=1, partition_sizes=[1] + +# Reset the configs to old values +statement ok +set datafusion.execution.target_partitions = 4; + +statement ok +set datafusion.optimizer.repartition_joins = false; + +statement ok +DROP TABLE t1; + +statement ok +DROP TABLE t2; diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index 01c0131fdb62..a7146a5a91c4 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -24,7 +24,7 @@ statement ok set datafusion.execution.target_partitions = 2; statement ok -set datafusion.execution.batch_size = 4096; +set datafusion.execution.batch_size = 2; statement ok set datafusion.explain.logical_plan_only = true; @@ -140,6 +140,17 @@ SELECT FROM test_timestamps_table_source; +# create a table of timestamps with time zone +statement ok +CREATE TABLE test_timestamps_tz_table as +SELECT + arrow_cast(ts::timestamp::bigint, 'Timestamp(Nanosecond, Some("UTC"))') as nanos, + arrow_cast(ts::timestamp::bigint / 1000, 'Timestamp(Microsecond, Some("UTC"))') as micros, + arrow_cast(ts::timestamp::bigint / 1000000, 'Timestamp(Millisecond, Some("UTC"))') as millis, + arrow_cast(ts::timestamp::bigint / 1000000000, 'Timestamp(Second, Some("UTC"))') as secs, + names +FROM + test_timestamps_table_source; statement ok @@ -185,6 +196,10 @@ FROM statement ok set datafusion.execution.target_partitions = 2; +# make sure to a batch size smaller than row number of the table. +statement ok +set datafusion.execution.batch_size = 2; + ########## ## Joins Tests ########## @@ -1311,13 +1326,13 @@ Aggregate: groupBy=[[join_t1.t1_id]], aggr=[[]] physical_plan AggregateExec: mode=SinglePartitioned, gby=[t1_id@0 as t1_id], aggr=[] --ProjectionExec: expr=[t1_id@0 as t1_id] -----CoalesceBatchesExec: target_batch_size=4096 +----CoalesceBatchesExec: target_batch_size=2 ------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(t1_id@0, t2_id@0)] ---------CoalesceBatchesExec: target_batch_size=4096 +--------CoalesceBatchesExec: target_batch_size=2 ----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 ------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 --------------MemoryExec: partitions=1, partition_sizes=[1] ---------CoalesceBatchesExec: target_batch_size=4096 +--------CoalesceBatchesExec: target_batch_size=2 ----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 ------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 --------------MemoryExec: partitions=1, partition_sizes=[1] @@ -1339,13 +1354,13 @@ physical_plan ProjectionExec: expr=[COUNT(*)@1 as COUNT(*)] --AggregateExec: mode=SinglePartitioned, gby=[t1_id@0 as t1_id], aggr=[COUNT(*)] ----ProjectionExec: expr=[t1_id@0 as t1_id] -------CoalesceBatchesExec: target_batch_size=4096 +------CoalesceBatchesExec: target_batch_size=2 --------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(t1_id@0, t2_id@0)] -----------CoalesceBatchesExec: target_batch_size=4096 +----------CoalesceBatchesExec: target_batch_size=2 ------------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 --------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ----------------MemoryExec: partitions=1, partition_sizes=[1] -----------CoalesceBatchesExec: target_batch_size=4096 +----------CoalesceBatchesExec: target_batch_size=2 ------------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 --------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ----------------MemoryExec: partitions=1, partition_sizes=[1] @@ -1372,13 +1387,13 @@ ProjectionExec: expr=[COUNT(alias1)@0 as COUNT(DISTINCT join_t1.t1_id)] --------AggregateExec: mode=FinalPartitioned, gby=[alias1@0 as alias1], aggr=[] ----------AggregateExec: mode=Partial, gby=[t1_id@0 as alias1], aggr=[] ------------ProjectionExec: expr=[t1_id@0 as t1_id] ---------------CoalesceBatchesExec: target_batch_size=4096 +--------------CoalesceBatchesExec: target_batch_size=2 ----------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(t1_id@0, t2_id@0)] -------------------CoalesceBatchesExec: target_batch_size=4096 +------------------CoalesceBatchesExec: target_batch_size=2 --------------------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 ----------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ------------------------MemoryExec: partitions=1, partition_sizes=[1] -------------------CoalesceBatchesExec: target_batch_size=4096 +------------------CoalesceBatchesExec: target_batch_size=2 --------------------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 ----------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ------------------------MemoryExec: partitions=1, partition_sizes=[1] @@ -1439,17 +1454,16 @@ Projection: join_t1.t1_id, join_t1.t1_name, join_t1.t1_int, join_t2.t2_id, join_ ----TableScan: join_t1 projection=[t1_id, t1_name, t1_int] ----TableScan: join_t2 projection=[t2_id, t2_name, t2_int] physical_plan -ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, t2_id@3 as t2_id, t2_name@4 as t2_name, t2_int@5 as t2_int, CAST(t1_id@0 AS Int64) + 11 as join_t1.t1_id + Int64(11)] ---ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, t2_id@4 as t2_id, t2_name@5 as t2_name, t2_int@6 as t2_int] -----CoalesceBatchesExec: target_batch_size=4096 -------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(join_t1.t1_id + Int64(11)@3, CAST(join_t2.t2_id AS Int64)@3)] ---------CoalescePartitionsExec -----------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, CAST(t1_id@0 AS Int64) + 11 as join_t1.t1_id + Int64(11)] -------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ---------------MemoryExec: partitions=1, partition_sizes=[1] ---------ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as t2_name, t2_int@2 as t2_int, CAST(t2_id@0 AS Int64) as CAST(join_t2.t2_id AS Int64)] +ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, t2_id@4 as t2_id, t2_name@5 as t2_name, t2_int@6 as t2_int, CAST(t1_id@0 AS Int64) + 11 as join_t1.t1_id + Int64(11)] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(join_t1.t1_id + Int64(11)@3, CAST(join_t2.t2_id AS Int64)@3)] +------CoalescePartitionsExec +--------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, CAST(t1_id@0 AS Int64) + 11 as join_t1.t1_id + Int64(11)] ----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ------------MemoryExec: partitions=1, partition_sizes=[1] +------ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as t2_name, t2_int@2 as t2_int, CAST(t2_id@0 AS Int64) as CAST(join_t2.t2_id AS Int64)] +--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----------MemoryExec: partitions=1, partition_sizes=[1] statement ok set datafusion.optimizer.repartition_joins = true; @@ -1466,20 +1480,19 @@ Projection: join_t1.t1_id, join_t1.t1_name, join_t1.t1_int, join_t2.t2_id, join_ ----TableScan: join_t1 projection=[t1_id, t1_name, t1_int] ----TableScan: join_t2 projection=[t2_id, t2_name, t2_int] physical_plan -ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, t2_id@3 as t2_id, t2_name@4 as t2_name, t2_int@5 as t2_int, CAST(t1_id@0 AS Int64) + 11 as join_t1.t1_id + Int64(11)] ---ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, t2_id@4 as t2_id, t2_name@5 as t2_name, t2_int@6 as t2_int] -----CoalesceBatchesExec: target_batch_size=4096 -------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(join_t1.t1_id + Int64(11)@3, CAST(join_t2.t2_id AS Int64)@3)] ---------CoalesceBatchesExec: target_batch_size=4096 -----------RepartitionExec: partitioning=Hash([join_t1.t1_id + Int64(11)@3], 2), input_partitions=2 -------------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, CAST(t1_id@0 AS Int64) + 11 as join_t1.t1_id + Int64(11)] ---------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -----------------MemoryExec: partitions=1, partition_sizes=[1] ---------CoalesceBatchesExec: target_batch_size=4096 -----------RepartitionExec: partitioning=Hash([CAST(join_t2.t2_id AS Int64)@3], 2), input_partitions=2 -------------ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as t2_name, t2_int@2 as t2_int, CAST(t2_id@0 AS Int64) as CAST(join_t2.t2_id AS Int64)] ---------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -----------------MemoryExec: partitions=1, partition_sizes=[1] +ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, t2_id@4 as t2_id, t2_name@5 as t2_name, t2_int@6 as t2_int, CAST(t1_id@0 AS Int64) + 11 as join_t1.t1_id + Int64(11)] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=Partitioned, join_type=Inner, on=[(join_t1.t1_id + Int64(11)@3, CAST(join_t2.t2_id AS Int64)@3)] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([join_t1.t1_id + Int64(11)@3], 2), input_partitions=2 +----------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, CAST(t1_id@0 AS Int64) + 11 as join_t1.t1_id + Int64(11)] +------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([CAST(join_t2.t2_id AS Int64)@3], 2), input_partitions=2 +----------ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as t2_name, t2_int@2 as t2_int, CAST(t2_id@0 AS Int64) as CAST(join_t2.t2_id AS Int64)] +------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] # Both side expr key inner join @@ -1498,18 +1511,16 @@ Projection: join_t1.t1_id, join_t2.t2_id, join_t1.t1_name ----TableScan: join_t1 projection=[t1_id, t1_name] ----TableScan: join_t2 projection=[t2_id] physical_plan -ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id, t1_name@1 as t1_name] ---ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t2_id@3 as t2_id] -----ProjectionExec: expr=[t1_id@2 as t1_id, t1_name@3 as t1_name, join_t1.t1_id + UInt32(12)@4 as join_t1.t1_id + UInt32(12), t2_id@0 as t2_id, join_t2.t2_id + UInt32(1)@1 as join_t2.t2_id + UInt32(1)] -------CoalesceBatchesExec: target_batch_size=4096 ---------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(join_t2.t2_id + UInt32(1)@1, join_t1.t1_id + UInt32(12)@2)] -----------CoalescePartitionsExec -------------ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 + 1 as join_t2.t2_id + UInt32(1)] ---------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -----------------MemoryExec: partitions=1, partition_sizes=[1] -----------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 12 as join_t1.t1_id + UInt32(12)] -------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ---------------MemoryExec: partitions=1, partition_sizes=[1] +ProjectionExec: expr=[t1_id@2 as t1_id, t2_id@0 as t2_id, t1_name@3 as t1_name] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(join_t2.t2_id + UInt32(1)@1, join_t1.t1_id + UInt32(12)@2)] +------CoalescePartitionsExec +--------ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 + 1 as join_t2.t2_id + UInt32(1)] +----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------MemoryExec: partitions=1, partition_sizes=[1] +------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 12 as join_t1.t1_id + UInt32(12)] +--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----------MemoryExec: partitions=1, partition_sizes=[1] statement ok set datafusion.optimizer.repartition_joins = true; @@ -1526,21 +1537,19 @@ Projection: join_t1.t1_id, join_t2.t2_id, join_t1.t1_name ----TableScan: join_t1 projection=[t1_id, t1_name] ----TableScan: join_t2 projection=[t2_id] physical_plan -ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id, t1_name@1 as t1_name] ---ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t2_id@3 as t2_id] -----ProjectionExec: expr=[t1_id@2 as t1_id, t1_name@3 as t1_name, join_t1.t1_id + UInt32(12)@4 as join_t1.t1_id + UInt32(12), t2_id@0 as t2_id, join_t2.t2_id + UInt32(1)@1 as join_t2.t2_id + UInt32(1)] -------CoalesceBatchesExec: target_batch_size=4096 ---------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(join_t2.t2_id + UInt32(1)@1, join_t1.t1_id + UInt32(12)@2)] -----------CoalesceBatchesExec: target_batch_size=4096 -------------RepartitionExec: partitioning=Hash([join_t2.t2_id + UInt32(1)@1], 2), input_partitions=2 ---------------ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 + 1 as join_t2.t2_id + UInt32(1)] -----------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -------------------MemoryExec: partitions=1, partition_sizes=[1] -----------CoalesceBatchesExec: target_batch_size=4096 -------------RepartitionExec: partitioning=Hash([join_t1.t1_id + UInt32(12)@2], 2), input_partitions=2 ---------------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 12 as join_t1.t1_id + UInt32(12)] -----------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -------------------MemoryExec: partitions=1, partition_sizes=[1] +ProjectionExec: expr=[t1_id@2 as t1_id, t2_id@0 as t2_id, t1_name@3 as t1_name] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=Partitioned, join_type=Inner, on=[(join_t2.t2_id + UInt32(1)@1, join_t1.t1_id + UInt32(12)@2)] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([join_t2.t2_id + UInt32(1)@1], 2), input_partitions=2 +----------ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 + 1 as join_t2.t2_id + UInt32(1)] +------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([join_t1.t1_id + UInt32(12)@2], 2), input_partitions=2 +----------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 12 as join_t1.t1_id + UInt32(12)] +------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] # Left side expr key inner join @@ -1560,14 +1569,11 @@ Projection: join_t1.t1_id, join_t2.t2_id, join_t1.t1_name ----TableScan: join_t1 projection=[t1_id, t1_name] ----TableScan: join_t2 projection=[t2_id] physical_plan -ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id, t1_name@1 as t1_name] ---ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t2_id@3 as t2_id] -----CoalesceBatchesExec: target_batch_size=4096 -------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(join_t1.t1_id + UInt32(11)@2, t2_id@0)] ---------CoalescePartitionsExec -----------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 11 as join_t1.t1_id + UInt32(11)] -------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ---------------MemoryExec: partitions=1, partition_sizes=[1] +ProjectionExec: expr=[t1_id@1 as t1_id, t2_id@0 as t2_id, t1_name@2 as t1_name] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(t2_id@0, join_t1.t1_id + UInt32(11)@2)] +------MemoryExec: partitions=1, partition_sizes=[1] +------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 11 as join_t1.t1_id + UInt32(11)] --------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ----------MemoryExec: partitions=1, partition_sizes=[1] @@ -1587,17 +1593,16 @@ Projection: join_t1.t1_id, join_t2.t2_id, join_t1.t1_name ----TableScan: join_t1 projection=[t1_id, t1_name] ----TableScan: join_t2 projection=[t2_id] physical_plan -ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id, t1_name@1 as t1_name] ---ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t2_id@3 as t2_id] -----CoalesceBatchesExec: target_batch_size=4096 -------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(join_t1.t1_id + UInt32(11)@2, t2_id@0)] ---------CoalesceBatchesExec: target_batch_size=4096 -----------RepartitionExec: partitioning=Hash([join_t1.t1_id + UInt32(11)@2], 2), input_partitions=2 -------------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 11 as join_t1.t1_id + UInt32(11)] ---------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -----------------MemoryExec: partitions=1, partition_sizes=[1] ---------CoalesceBatchesExec: target_batch_size=4096 -----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +ProjectionExec: expr=[t1_id@1 as t1_id, t2_id@0 as t2_id, t1_name@2 as t1_name] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=Partitioned, join_type=Inner, on=[(t2_id@0, join_t1.t1_id + UInt32(11)@2)] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------MemoryExec: partitions=1, partition_sizes=[1] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([join_t1.t1_id + UInt32(11)@2], 2), input_partitions=2 +----------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 11 as join_t1.t1_id + UInt32(11)] ------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 --------------MemoryExec: partitions=1, partition_sizes=[1] @@ -1619,17 +1624,15 @@ Projection: join_t1.t1_id, join_t2.t2_id, join_t1.t1_name ----TableScan: join_t1 projection=[t1_id, t1_name] ----TableScan: join_t2 projection=[t2_id] physical_plan -ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id, t1_name@1 as t1_name] ---ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t2_id@2 as t2_id] -----ProjectionExec: expr=[t1_id@2 as t1_id, t1_name@3 as t1_name, t2_id@0 as t2_id, join_t2.t2_id - UInt32(11)@1 as join_t2.t2_id - UInt32(11)] -------CoalesceBatchesExec: target_batch_size=4096 ---------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(join_t2.t2_id - UInt32(11)@1, t1_id@0)] -----------CoalescePartitionsExec -------------ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 - 11 as join_t2.t2_id - UInt32(11)] ---------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -----------------MemoryExec: partitions=1, partition_sizes=[1] +ProjectionExec: expr=[t1_id@2 as t1_id, t2_id@0 as t2_id, t1_name@3 as t1_name] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(join_t2.t2_id - UInt32(11)@1, t1_id@0)] +------CoalescePartitionsExec +--------ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 - 11 as join_t2.t2_id - UInt32(11)] ----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ------------MemoryExec: partitions=1, partition_sizes=[1] +------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------MemoryExec: partitions=1, partition_sizes=[1] statement ok set datafusion.optimizer.repartition_joins = true; @@ -1647,20 +1650,18 @@ Projection: join_t1.t1_id, join_t2.t2_id, join_t1.t1_name ----TableScan: join_t1 projection=[t1_id, t1_name] ----TableScan: join_t2 projection=[t2_id] physical_plan -ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id, t1_name@1 as t1_name] ---ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t2_id@2 as t2_id] -----ProjectionExec: expr=[t1_id@2 as t1_id, t1_name@3 as t1_name, t2_id@0 as t2_id, join_t2.t2_id - UInt32(11)@1 as join_t2.t2_id - UInt32(11)] -------CoalesceBatchesExec: target_batch_size=4096 ---------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(join_t2.t2_id - UInt32(11)@1, t1_id@0)] -----------CoalesceBatchesExec: target_batch_size=4096 -------------RepartitionExec: partitioning=Hash([join_t2.t2_id - UInt32(11)@1], 2), input_partitions=2 ---------------ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 - 11 as join_t2.t2_id - UInt32(11)] -----------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -------------------MemoryExec: partitions=1, partition_sizes=[1] -----------CoalesceBatchesExec: target_batch_size=4096 -------------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 ---------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -----------------MemoryExec: partitions=1, partition_sizes=[1] +ProjectionExec: expr=[t1_id@2 as t1_id, t2_id@0 as t2_id, t1_name@3 as t1_name] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=Partitioned, join_type=Inner, on=[(join_t2.t2_id - UInt32(11)@1, t1_id@0)] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([join_t2.t2_id - UInt32(11)@1], 2), input_partitions=2 +----------ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 - 11 as join_t2.t2_id - UInt32(11)] +------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 +----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------MemoryExec: partitions=1, partition_sizes=[1] # Select wildcard with expr key inner join @@ -1680,7 +1681,7 @@ Inner Join: join_t1.t1_id = join_t2.t2_id - UInt32(11) --TableScan: join_t2 projection=[t2_id, t2_name, t2_int] physical_plan ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, t2_id@3 as t2_id, t2_name@4 as t2_name, t2_int@5 as t2_int] ---CoalesceBatchesExec: target_batch_size=4096 +--CoalesceBatchesExec: target_batch_size=2 ----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(t1_id@0, join_t2.t2_id - UInt32(11)@3)] ------MemoryExec: partitions=1, partition_sizes=[1] ------ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as t2_name, t2_int@2 as t2_int, t2_id@0 - 11 as join_t2.t2_id - UInt32(11)] @@ -1703,13 +1704,13 @@ Inner Join: join_t1.t1_id = join_t2.t2_id - UInt32(11) --TableScan: join_t2 projection=[t2_id, t2_name, t2_int] physical_plan ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, t2_id@3 as t2_id, t2_name@4 as t2_name, t2_int@5 as t2_int] ---CoalesceBatchesExec: target_batch_size=4096 +--CoalesceBatchesExec: target_batch_size=2 ----HashJoinExec: mode=Partitioned, join_type=Inner, on=[(t1_id@0, join_t2.t2_id - UInt32(11)@3)] -------CoalesceBatchesExec: target_batch_size=4096 +------CoalesceBatchesExec: target_batch_size=2 --------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 ----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ------------MemoryExec: partitions=1, partition_sizes=[1] -------CoalesceBatchesExec: target_batch_size=4096 +------CoalesceBatchesExec: target_batch_size=2 --------RepartitionExec: partitioning=Hash([join_t2.t2_id - UInt32(11)@3], 2), input_partitions=2 ----------ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as t2_name, t2_int@2 as t2_int, t2_id@0 - 11 as join_t2.t2_id - UInt32(11)] ------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 @@ -2032,13 +2033,13 @@ Inner Join: Filter: join_t1.t1_id > join_t2.t2_id ------TableScan: join_t2 projection=[t2_id, t2_int] physical_plan NestedLoopJoinExec: join_type=Inner, filter=t1_id@0 > t2_id@1 ---CoalesceBatchesExec: target_batch_size=4096 +--CoalesceBatchesExec: target_batch_size=2 ----FilterExec: t1_id@0 > 10 ------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 --------MemoryExec: partitions=1, partition_sizes=[1] --CoalescePartitionsExec ----ProjectionExec: expr=[t2_id@0 as t2_id] -------CoalesceBatchesExec: target_batch_size=4096 +------CoalesceBatchesExec: target_batch_size=2 --------FilterExec: t2_int@1 > 1 ----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ------------MemoryExec: partitions=1, partition_sizes=[1] @@ -2073,11 +2074,11 @@ Right Join: Filter: join_t1.t1_id < join_t2.t2_id physical_plan NestedLoopJoinExec: join_type=Right, filter=t1_id@0 < t2_id@1 --CoalescePartitionsExec -----CoalesceBatchesExec: target_batch_size=4096 +----CoalesceBatchesExec: target_batch_size=2 ------FilterExec: t1_id@0 > 22 --------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ----------MemoryExec: partitions=1, partition_sizes=[1] ---CoalesceBatchesExec: target_batch_size=4096 +--CoalesceBatchesExec: target_batch_size=2 ----FilterExec: t2_id@0 > 11 ------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 --------MemoryExec: partitions=1, partition_sizes=[1] @@ -2470,6 +2471,16 @@ test_timestamps_table NULL NULL NULL NULL Row 2 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 Row 3 +# show the contents of the timestamp with timezone table +query PPPPT +select * from +test_timestamps_tz_table +---- +2018-11-13T17:11:10.011375885Z 2018-11-13T17:11:10.011375Z 2018-11-13T17:11:10.011Z 2018-11-13T17:11:10Z Row 0 +2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123Z 2011-12-13T11:13:10Z Row 1 +NULL NULL NULL NULL Row 2 +2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10Z Row 3 + # test timestamp join on nanos datatype query PPPPTPPPPT rowsort SELECT * FROM test_timestamps_table as t1 JOIN (SELECT * FROM test_timestamps_table ) as t2 ON t1.nanos = t2.nanos; @@ -2478,6 +2489,14 @@ SELECT * FROM test_timestamps_table as t1 JOIN (SELECT * FROM test_timestamps_ta 2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 Row 0 2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 Row 0 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 Row 3 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 Row 3 +# test timestamp with timezone join on nanos datatype +query PPPPTPPPPT rowsort +SELECT * FROM test_timestamps_tz_table as t1 JOIN (SELECT * FROM test_timestamps_tz_table ) as t2 ON t1.nanos = t2.nanos; +---- +2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123Z 2011-12-13T11:13:10Z Row 1 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123Z 2011-12-13T11:13:10Z Row 1 +2018-11-13T17:11:10.011375885Z 2018-11-13T17:11:10.011375Z 2018-11-13T17:11:10.011Z 2018-11-13T17:11:10Z Row 0 2018-11-13T17:11:10.011375885Z 2018-11-13T17:11:10.011375Z 2018-11-13T17:11:10.011Z 2018-11-13T17:11:10Z Row 0 +2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10Z Row 3 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10Z Row 3 + # test timestamp join on micros datatype query PPPPTPPPPT rowsort SELECT * FROM test_timestamps_table as t1 JOIN (SELECT * FROM test_timestamps_table ) as t2 ON t1.micros = t2.micros @@ -2486,6 +2505,14 @@ SELECT * FROM test_timestamps_table as t1 JOIN (SELECT * FROM test_timestamps_ta 2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 Row 0 2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 Row 0 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 Row 3 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 Row 3 +# test timestamp with timezone join on micros datatype +query PPPPTPPPPT rowsort +SELECT * FROM test_timestamps_tz_table as t1 JOIN (SELECT * FROM test_timestamps_tz_table ) as t2 ON t1.micros = t2.micros +---- +2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123Z 2011-12-13T11:13:10Z Row 1 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123Z 2011-12-13T11:13:10Z Row 1 +2018-11-13T17:11:10.011375885Z 2018-11-13T17:11:10.011375Z 2018-11-13T17:11:10.011Z 2018-11-13T17:11:10Z Row 0 2018-11-13T17:11:10.011375885Z 2018-11-13T17:11:10.011375Z 2018-11-13T17:11:10.011Z 2018-11-13T17:11:10Z Row 0 +2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10Z Row 3 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10Z Row 3 + # test timestamp join on millis datatype query PPPPTPPPPT rowsort SELECT * FROM test_timestamps_table as t1 JOIN (SELECT * FROM test_timestamps_table ) as t2 ON t1.millis = t2.millis @@ -2494,6 +2521,46 @@ SELECT * FROM test_timestamps_table as t1 JOIN (SELECT * FROM test_timestamps_ta 2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 Row 0 2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 Row 0 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 Row 3 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 Row 3 +# test timestamp with timezone join on millis datatype +query PPPPTPPPPT rowsort +SELECT * FROM test_timestamps_tz_table as t1 JOIN (SELECT * FROM test_timestamps_tz_table ) as t2 ON t1.millis = t2.millis +---- +2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123Z 2011-12-13T11:13:10Z Row 1 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123Z 2011-12-13T11:13:10Z Row 1 +2018-11-13T17:11:10.011375885Z 2018-11-13T17:11:10.011375Z 2018-11-13T17:11:10.011Z 2018-11-13T17:11:10Z Row 0 2018-11-13T17:11:10.011375885Z 2018-11-13T17:11:10.011375Z 2018-11-13T17:11:10.011Z 2018-11-13T17:11:10Z Row 0 +2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10Z Row 3 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10Z Row 3 + +#### +# Config setup +#### + +statement ok +set datafusion.explain.logical_plan_only = false; + +statement ok +set datafusion.optimizer.prefer_hash_join = true; + +# explain hash join on timestamp with timezone type +query TT +EXPLAIN SELECT * FROM test_timestamps_tz_table as t1 JOIN test_timestamps_tz_table as t2 ON t1.millis = t2.millis +---- +logical_plan +Inner Join: t1.millis = t2.millis +--SubqueryAlias: t1 +----TableScan: test_timestamps_tz_table projection=[nanos, micros, millis, secs, names] +--SubqueryAlias: t2 +----TableScan: test_timestamps_tz_table projection=[nanos, micros, millis, secs, names] +physical_plan +CoalesceBatchesExec: target_batch_size=2 +--HashJoinExec: mode=Partitioned, join_type=Inner, on=[(millis@2, millis@2)] +----CoalesceBatchesExec: target_batch_size=2 +------RepartitionExec: partitioning=Hash([millis@2], 2), input_partitions=2 +--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----------MemoryExec: partitions=1, partition_sizes=[1] +----CoalesceBatchesExec: target_batch_size=2 +------RepartitionExec: partitioning=Hash([millis@2], 2), input_partitions=2 +--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----------MemoryExec: partitions=1, partition_sizes=[1] + # left_join_using_2 query II SELECT t1.c1, t2.c2 FROM test_partition_table t1 JOIN test_partition_table t2 USING (c2) ORDER BY t2.c2; @@ -2643,7 +2710,7 @@ statement ok set datafusion.execution.target_partitions = 2; statement ok -set datafusion.execution.batch_size = 4096; +set datafusion.execution.batch_size = 2; # explain sort_merge_join_on_date32 inner sort merge join on data type (Date32) query TT @@ -2658,12 +2725,12 @@ Inner Join: t1.c1 = t2.c1 physical_plan SortMergeJoin: join_type=Inner, on=[(c1@0, c1@0)] --SortExec: expr=[c1@0 ASC] -----CoalesceBatchesExec: target_batch_size=4096 +----CoalesceBatchesExec: target_batch_size=2 ------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 --------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ----------MemoryExec: partitions=1, partition_sizes=[1] --SortExec: expr=[c1@0 ASC] -----CoalesceBatchesExec: target_batch_size=4096 +----CoalesceBatchesExec: target_batch_size=2 ------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 --------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ----------MemoryExec: partitions=1, partition_sizes=[1] @@ -2689,13 +2756,13 @@ physical_plan ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3, c4@3 as c4, c1@5 as c1, c2@6 as c2, c3@7 as c3, c4@8 as c4] --SortMergeJoin: join_type=Right, on=[(CAST(t1.c3 AS Decimal128(10, 2))@4, c3@2)] ----SortExec: expr=[CAST(t1.c3 AS Decimal128(10, 2))@4 ASC] -------CoalesceBatchesExec: target_batch_size=4096 +------CoalesceBatchesExec: target_batch_size=2 --------RepartitionExec: partitioning=Hash([CAST(t1.c3 AS Decimal128(10, 2))@4], 2), input_partitions=2 ----------ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3, c4@3 as c4, CAST(c3@2 AS Decimal128(10, 2)) as CAST(t1.c3 AS Decimal128(10, 2))] ------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 --------------MemoryExec: partitions=1, partition_sizes=[1] ----SortExec: expr=[c3@2 ASC] -------CoalesceBatchesExec: target_batch_size=4096 +------CoalesceBatchesExec: target_batch_size=2 --------RepartitionExec: partitioning=Hash([c3@2], 2), input_partitions=2 ----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ------------MemoryExec: partitions=1, partition_sizes=[1] @@ -2722,7 +2789,7 @@ statement ok set datafusion.execution.target_partitions = 2; statement ok -set datafusion.execution.batch_size = 4096; +set datafusion.execution.batch_size = 2; @@ -2743,7 +2810,7 @@ statement ok set datafusion.execution.target_partitions = 2; statement ok -set datafusion.execution.batch_size = 4096; +set datafusion.execution.batch_size = 2; query TT explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 WHERE t1_id IN (SELECT t2_id FROM left_semi_anti_join_table_t2 t2) ORDER BY t1_id @@ -2751,14 +2818,14 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 WHERE t1_id I physical_plan SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] --SortExec: expr=[t1_id@0 ASC NULLS LAST] -----CoalesceBatchesExec: target_batch_size=4096 -------HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(t1_id@0, t2_id@0)] ---------CoalesceBatchesExec: target_batch_size=4096 -----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 +----CoalesceBatchesExec: target_batch_size=2 +------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] +--------CoalesceBatchesExec: target_batch_size=2 +----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 ------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 --------------MemoryExec: partitions=1, partition_sizes=[1] ---------CoalesceBatchesExec: target_batch_size=4096 -----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +--------CoalesceBatchesExec: target_batch_size=2 +----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 ------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 --------------MemoryExec: partitions=1, partition_sizes=[1] @@ -2792,14 +2859,14 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 LEFT SEMI JOI physical_plan SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] --SortExec: expr=[t1_id@0 ASC NULLS LAST] -----CoalesceBatchesExec: target_batch_size=4096 -------HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(t1_id@0, t2_id@0)] ---------CoalesceBatchesExec: target_batch_size=4096 -----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 +----CoalesceBatchesExec: target_batch_size=2 +------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] +--------CoalesceBatchesExec: target_batch_size=2 +----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 ------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 --------------MemoryExec: partitions=1, partition_sizes=[1] ---------CoalesceBatchesExec: target_batch_size=4096 -----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +--------CoalesceBatchesExec: target_batch_size=2 +----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 ------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 --------------MemoryExec: partitions=1, partition_sizes=[1] @@ -2827,7 +2894,7 @@ statement ok set datafusion.execution.target_partitions = 2; statement ok -set datafusion.execution.batch_size = 4096; +set datafusion.execution.batch_size = 2; #Test the left_semi_join scenarios where the current repartition_joins parameter is set to false . #### @@ -2846,7 +2913,7 @@ statement ok set datafusion.execution.target_partitions = 2; statement ok -set datafusion.execution.batch_size = 4096; +set datafusion.execution.batch_size = 2; query TT explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 WHERE t1_id IN (SELECT t2_id FROM left_semi_anti_join_table_t2 t2) ORDER BY t1_id @@ -2854,8 +2921,8 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 WHERE t1_id I physical_plan SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] --SortExec: expr=[t1_id@0 ASC NULLS LAST] -----CoalesceBatchesExec: target_batch_size=4096 -------HashJoinExec: mode=CollectLeft, join_type=LeftSemi, on=[(t1_id@0, t2_id@0)] +----CoalesceBatchesExec: target_batch_size=2 +------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] --------MemoryExec: partitions=1, partition_sizes=[1] --------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ----------MemoryExec: partitions=1, partition_sizes=[1] @@ -2890,8 +2957,8 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 LEFT SEMI JOI physical_plan SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] --SortExec: expr=[t1_id@0 ASC NULLS LAST] -----CoalesceBatchesExec: target_batch_size=4096 -------HashJoinExec: mode=CollectLeft, join_type=LeftSemi, on=[(t1_id@0, t2_id@0)] +----CoalesceBatchesExec: target_batch_size=2 +------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] --------MemoryExec: partitions=1, partition_sizes=[1] --------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ----------MemoryExec: partitions=1, partition_sizes=[1] @@ -2920,7 +2987,7 @@ statement ok set datafusion.execution.target_partitions = 2; statement ok -set datafusion.execution.batch_size = 4096; +set datafusion.execution.batch_size = 2; #Test the right_semi_join scenarios where the current repartition_joins parameter is set to true . @@ -2940,7 +3007,7 @@ statement ok set datafusion.execution.target_partitions = 2; statement ok -set datafusion.execution.batch_size = 4096; +set datafusion.execution.batch_size = 2; query TT explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t1 t1 WHERE EXISTS (SELECT * FROM right_semi_anti_join_table_t2 t2 where t2.t2_id = t1.t1_id and t2.t2_name <> t1.t1_name) ORDER BY t1_id @@ -2948,13 +3015,13 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t1 t1 WHER physical_plan SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] --SortExec: expr=[t1_id@0 ASC NULLS LAST] -----CoalesceBatchesExec: target_batch_size=4096 +----CoalesceBatchesExec: target_batch_size=2 ------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@1 != t1_name@0 ---------CoalesceBatchesExec: target_batch_size=4096 +--------CoalesceBatchesExec: target_batch_size=2 ----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 ------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 --------------MemoryExec: partitions=1, partition_sizes=[1] ---------CoalesceBatchesExec: target_batch_size=4096 +--------CoalesceBatchesExec: target_batch_size=2 ----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 ------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 --------------MemoryExec: partitions=1, partition_sizes=[1] @@ -2970,13 +3037,13 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t2 t2 RIGH physical_plan SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] --SortExec: expr=[t1_id@0 ASC NULLS LAST] -----CoalesceBatchesExec: target_batch_size=4096 +----CoalesceBatchesExec: target_batch_size=2 ------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1 ---------CoalesceBatchesExec: target_batch_size=4096 +--------CoalesceBatchesExec: target_batch_size=2 ----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 ------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 --------------MemoryExec: partitions=1, partition_sizes=[1] ---------CoalesceBatchesExec: target_batch_size=4096 +--------CoalesceBatchesExec: target_batch_size=2 ----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 ------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 --------------MemoryExec: partitions=1, partition_sizes=[1] @@ -3002,7 +3069,7 @@ statement ok set datafusion.execution.target_partitions = 2; statement ok -set datafusion.execution.batch_size = 4096; +set datafusion.execution.batch_size = 2; #Test the right_semi_join scenarios where the current repartition_joins parameter is set to false . @@ -3022,7 +3089,7 @@ statement ok set datafusion.execution.target_partitions = 2; statement ok -set datafusion.execution.batch_size = 4096; +set datafusion.execution.batch_size = 2; query TT explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t1 t1 WHERE EXISTS (SELECT * FROM right_semi_anti_join_table_t2 t2 where t2.t2_id = t1.t1_id and t2.t2_name <> t1.t1_name) ORDER BY t1_id @@ -3030,7 +3097,7 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t1 t1 WHER physical_plan SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] --SortExec: expr=[t1_id@0 ASC NULLS LAST] -----CoalesceBatchesExec: target_batch_size=4096 +----CoalesceBatchesExec: target_batch_size=2 ------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@1 != t1_name@0 --------MemoryExec: partitions=1, partition_sizes=[1] --------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 @@ -3047,7 +3114,7 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t2 t2 RIGH physical_plan SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] --SortExec: expr=[t1_id@0 ASC NULLS LAST] -----CoalesceBatchesExec: target_batch_size=4096 +----CoalesceBatchesExec: target_batch_size=2 ------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1 --------MemoryExec: partitions=1, partition_sizes=[1] --------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 @@ -3074,7 +3141,7 @@ statement ok set datafusion.execution.target_partitions = 2; statement ok -set datafusion.execution.batch_size = 4096; +set datafusion.execution.batch_size = 2; #### @@ -3126,14 +3193,14 @@ physical_plan SortPreservingMergeExec: [rn1@5 ASC NULLS LAST] --SortMergeJoin: join_type=Inner, on=[(a@1, a@1)] ----SortExec: expr=[rn1@5 ASC NULLS LAST] -------CoalesceBatchesExec: target_batch_size=4096 +------CoalesceBatchesExec: target_batch_size=2 --------RepartitionExec: partitioning=Hash([a@1], 2), input_partitions=2 ----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ------------ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@5 as rn1] --------------BoundedWindowAggExec: wdw=[ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)) }], mode=[Sorted] ----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true ----SortExec: expr=[a@1 ASC] -------CoalesceBatchesExec: target_batch_size=4096 +------CoalesceBatchesExec: target_batch_size=2 --------RepartitionExec: partitioning=Hash([a@1], 2), input_partitions=2 ----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true @@ -3162,12 +3229,12 @@ physical_plan SortPreservingMergeExec: [rn1@10 ASC NULLS LAST] --SortMergeJoin: join_type=Right, on=[(a@1, a@1)] ----SortExec: expr=[a@1 ASC] -------CoalesceBatchesExec: target_batch_size=4096 +------CoalesceBatchesExec: target_batch_size=2 --------RepartitionExec: partitioning=Hash([a@1], 2), input_partitions=2 ----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true ----SortExec: expr=[rn1@5 ASC NULLS LAST] -------CoalesceBatchesExec: target_batch_size=4096 +------CoalesceBatchesExec: target_batch_size=2 --------RepartitionExec: partitioning=Hash([a@1], 2), input_partitions=2 ----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ------------ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@5 as rn1] @@ -3203,14 +3270,14 @@ SortPreservingMergeExec: [a@1 ASC,b@2 ASC NULLS LAST,c@3 ASC NULLS LAST,rn1@11 A --SortExec: expr=[a@1 ASC,b@2 ASC NULLS LAST,c@3 ASC NULLS LAST,rn1@11 ASC NULLS LAST] ----SortMergeJoin: join_type=Inner, on=[(a@1, a@1)] ------SortExec: expr=[a@1 ASC] ---------CoalesceBatchesExec: target_batch_size=4096 +--------CoalesceBatchesExec: target_batch_size=2 ----------RepartitionExec: partitioning=Hash([a@1], 2), input_partitions=2 ------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 --------------ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@5 as rn1] ----------------BoundedWindowAggExec: wdw=[ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)) }], mode=[Sorted] ------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true ------SortExec: expr=[a@1 ASC] ---------CoalesceBatchesExec: target_batch_size=4096 +--------CoalesceBatchesExec: target_batch_size=2 ----------RepartitionExec: partitioning=Hash([a@1], 2), input_partitions=2 ------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 --------------ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@5 as rn1] @@ -3245,7 +3312,7 @@ Sort: r_table.rn1 ASC NULLS LAST --------WindowAggr: windowExpr=[[ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] ----------TableScan: annotated_data projection=[a0, a, b, c, d] physical_plan -CoalesceBatchesExec: target_batch_size=4096 +CoalesceBatchesExec: target_batch_size=2 --HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@1, a@1)] ----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true ----ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@5 as rn1] @@ -3272,13 +3339,172 @@ Sort: r_table.rn1 ASC NULLS LAST --------WindowAggr: windowExpr=[[ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] ----------TableScan: annotated_data projection=[a0, a, b, c, d] physical_plan -CoalesceBatchesExec: target_batch_size=4096 +CoalesceBatchesExec: target_batch_size=2 --HashJoinExec: mode=CollectLeft, join_type=RightAnti, on=[(a@0, a@1)] ----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a], output_ordering=[a@0 ASC], has_header=true ----ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@5 as rn1] ------BoundedWindowAggExec: wdw=[ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)) }], mode=[Sorted] --------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true +query TT +EXPLAIN SELECT l.a, LAST_VALUE(r.b ORDER BY r.a ASC NULLS FIRST) as last_col1 +FROM annotated_data as l +JOIN annotated_data as r +ON l.a = r.a +GROUP BY l.a, l.b, l.c +ORDER BY l.a ASC NULLS FIRST; +---- +logical_plan +Sort: l.a ASC NULLS FIRST +--Projection: l.a, LAST_VALUE(r.b) ORDER BY [r.a ASC NULLS FIRST] AS last_col1 +----Aggregate: groupBy=[[l.a, l.b, l.c]], aggr=[[LAST_VALUE(r.b) ORDER BY [r.a ASC NULLS FIRST]]] +------Inner Join: l.a = r.a +--------SubqueryAlias: l +----------TableScan: annotated_data projection=[a, b, c] +--------SubqueryAlias: r +----------TableScan: annotated_data projection=[a, b] +physical_plan +ProjectionExec: expr=[a@0 as a, LAST_VALUE(r.b) ORDER BY [r.a ASC NULLS FIRST]@3 as last_col1] +--AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b, c@2 as c], aggr=[LAST_VALUE(r.b)], ordering_mode=PartiallySorted([0]) +----CoalesceBatchesExec: target_batch_size=2 +------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0)] +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], output_ordering=[a@0 ASC, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b], output_ordering=[a@0 ASC, b@1 ASC NULLS LAST], has_header=true + +# create a table where there more than one valid ordering +# that describes table. +statement ok +CREATE EXTERNAL TABLE multiple_ordered_table ( + a0 INTEGER, + a INTEGER, + b INTEGER, + c INTEGER, + d INTEGER +) +STORED AS CSV +WITH HEADER ROW +WITH ORDER (a ASC, b ASC) +WITH ORDER (c ASC) +LOCATION '../core/tests/data/window_2.csv'; + +query TT +EXPLAIN SELECT LAST_VALUE(l.d ORDER BY l.a) AS amount_usd +FROM multiple_ordered_table AS l +INNER JOIN ( + SELECT *, ROW_NUMBER() OVER (ORDER BY r.a) as row_n FROM multiple_ordered_table AS r +) +ON l.d = r.d AND + l.a >= r.a - 10 +GROUP BY row_n +ORDER BY row_n +---- +logical_plan +Projection: amount_usd +--Sort: row_n ASC NULLS LAST +----Projection: LAST_VALUE(l.d) ORDER BY [l.a ASC NULLS LAST] AS amount_usd, row_n +------Aggregate: groupBy=[[row_n]], aggr=[[LAST_VALUE(l.d) ORDER BY [l.a ASC NULLS LAST]]] +--------Projection: l.a, l.d, row_n +----------Inner Join: l.d = r.d Filter: CAST(l.a AS Int64) >= CAST(r.a AS Int64) - Int64(10) +------------SubqueryAlias: l +--------------TableScan: multiple_ordered_table projection=[a, d] +------------Projection: r.a, r.d, ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS row_n +--------------WindowAggr: windowExpr=[[ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +----------------SubqueryAlias: r +------------------TableScan: multiple_ordered_table projection=[a, d] +physical_plan +ProjectionExec: expr=[LAST_VALUE(l.d) ORDER BY [l.a ASC NULLS LAST]@1 as amount_usd] +--AggregateExec: mode=Single, gby=[row_n@2 as row_n], aggr=[LAST_VALUE(l.d)], ordering_mode=Sorted +----ProjectionExec: expr=[a@0 as a, d@1 as d, row_n@4 as row_n] +------CoalesceBatchesExec: target_batch_size=2 +--------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(d@1, d@1)], filter=CAST(a@0 AS Int64) >= CAST(a@1 AS Int64) - 10 +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true +----------ProjectionExec: expr=[a@0 as a, d@1 as d, ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as row_n] +------------BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] +--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true + +# run query above in multiple partitions +statement ok +set datafusion.execution.target_partitions = 2; + +# use bounded variants +statement ok +set datafusion.optimizer.prefer_existing_sort = true; + +query TT +EXPLAIN SELECT l.a, LAST_VALUE(r.b ORDER BY r.a ASC NULLS FIRST) as last_col1 +FROM annotated_data as l +JOIN annotated_data as r +ON l.a = r.a +GROUP BY l.a, l.b, l.c +ORDER BY l.a ASC NULLS FIRST; +---- +logical_plan +Sort: l.a ASC NULLS FIRST +--Projection: l.a, LAST_VALUE(r.b) ORDER BY [r.a ASC NULLS FIRST] AS last_col1 +----Aggregate: groupBy=[[l.a, l.b, l.c]], aggr=[[LAST_VALUE(r.b) ORDER BY [r.a ASC NULLS FIRST]]] +------Inner Join: l.a = r.a +--------SubqueryAlias: l +----------TableScan: annotated_data projection=[a, b, c] +--------SubqueryAlias: r +----------TableScan: annotated_data projection=[a, b] +physical_plan +SortPreservingMergeExec: [a@0 ASC] +--SortExec: expr=[a@0 ASC] +----ProjectionExec: expr=[a@0 as a, LAST_VALUE(r.b) ORDER BY [r.a ASC NULLS FIRST]@3 as last_col1] +------AggregateExec: mode=FinalPartitioned, gby=[a@0 as a, b@1 as b, c@2 as c], aggr=[LAST_VALUE(r.b)] +--------CoalesceBatchesExec: target_batch_size=2 +----------RepartitionExec: partitioning=Hash([a@0, b@1, c@2], 2), input_partitions=2 +------------AggregateExec: mode=Partial, gby=[a@0 as a, b@1 as b, c@2 as c], aggr=[LAST_VALUE(r.b)] +--------------CoalesceBatchesExec: target_batch_size=2 +----------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0)] +------------------CoalesceBatchesExec: target_batch_size=2 +--------------------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 +----------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], output_ordering=[a@0 ASC, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true +------------------CoalesceBatchesExec: target_batch_size=2 +--------------------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 +----------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b], output_ordering=[a@0 ASC, b@1 ASC NULLS LAST], has_header=true + +query TT +EXPLAIN SELECT * +FROM annotated_data as l, annotated_data as r +WHERE l.a > r.a +---- +logical_plan +Inner Join: Filter: l.a > r.a +--SubqueryAlias: l +----TableScan: annotated_data projection=[a0, a, b, c, d] +--SubqueryAlias: r +----TableScan: annotated_data projection=[a0, a, b, c, d] +physical_plan +NestedLoopJoinExec: join_type=Inner, filter=a@0 > a@1 +--RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true +--CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true + +# Currently datafusion cannot pushdown filter conditions with scalar UDF into +# cross join. +query TT +EXPLAIN SELECT * +FROM annotated_data as t1, annotated_data as t2 +WHERE EXAMPLE(t1.a, t2.a) > 3 +---- +logical_plan +Filter: example(CAST(t1.a AS Float64), CAST(t2.a AS Float64)) > Float64(3) +--CrossJoin: +----SubqueryAlias: t1 +------TableScan: annotated_data projection=[a0, a, b, c, d] +----SubqueryAlias: t2 +------TableScan: annotated_data projection=[a0, a, b, c, d] +physical_plan +CoalesceBatchesExec: target_batch_size=2 +--FilterExec: example(CAST(a@1 AS Float64), CAST(a@6 AS Float64)) > 3 +----CrossJoinExec +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true +------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true + #### # Config teardown #### @@ -3292,5 +3518,8 @@ set datafusion.optimizer.prefer_hash_join = true; statement ok set datafusion.execution.target_partitions = 2; +statement ok +set datafusion.optimizer.prefer_existing_sort = false; + statement ok drop table annotated_data; diff --git a/datafusion/sqllogictest/test_files/json.slt b/datafusion/sqllogictest/test_files/json.slt index 69902f2982dc..c0d5e895f0f2 100644 --- a/datafusion/sqllogictest/test_files/json.slt +++ b/datafusion/sqllogictest/test_files/json.slt @@ -50,16 +50,18 @@ EXPLAIN SELECT count(*) from json_test ---- logical_plan Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] ---TableScan: json_test projection=[a] +--TableScan: json_test projection=[] physical_plan AggregateExec: mode=Final, gby=[], aggr=[COUNT(*)] --CoalescePartitionsExec ----AggregateExec: mode=Partial, gby=[], aggr=[COUNT(*)] ------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------JsonExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/2.json]]}, projection=[a] +--------JsonExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/2.json]]} -query error DataFusion error: Schema error: No field named mycol\. +query ? SELECT mycol FROM single_nan +---- +NULL statement ok DROP TABLE json_test diff --git a/datafusion/sqllogictest/test_files/limit.slt b/datafusion/sqllogictest/test_files/limit.slt index 253ca8f335af..e063d6e8960a 100644 --- a/datafusion/sqllogictest/test_files/limit.slt +++ b/datafusion/sqllogictest/test_files/limit.slt @@ -294,6 +294,212 @@ query T SELECT c1 FROM aggregate_test_100 LIMIT 1 OFFSET 101 ---- +# +# global limit statistics test +# + +statement ok +CREATE TABLE IF NOT EXISTS t1 (a INT) AS VALUES(1),(2),(3),(4),(5),(6),(7),(8),(9),(10); + +# The aggregate does not need to be computed because the input statistics are exact and +# the number of rows is less than the skip value (OFFSET). +query TT +EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 11); +---- +logical_plan +Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] +--Limit: skip=11, fetch=3 +----TableScan: t1 projection=[], fetch=14 +physical_plan +ProjectionExec: expr=[0 as COUNT(*)] +--PlaceholderRowExec + +query I +SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 11); +---- +0 + +# The aggregate does not need to be computed because the input statistics are exact and +# the number of rows is less than or equal to the the "fetch+skip" value (LIMIT+OFFSET). +query TT +EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 8); +---- +logical_plan +Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] +--Limit: skip=8, fetch=3 +----TableScan: t1 projection=[], fetch=11 +physical_plan +ProjectionExec: expr=[2 as COUNT(*)] +--PlaceholderRowExec + +query I +SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 8); +---- +2 + +# The aggregate does not need to be computed because the input statistics are exact and +# an OFFSET, but no LIMIT, is specified. +query TT +EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 OFFSET 8); +---- +logical_plan +Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] +--Limit: skip=8, fetch=None +----TableScan: t1 projection=[] +physical_plan +ProjectionExec: expr=[2 as COUNT(*)] +--PlaceholderRowExec + +query I +SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 8); +---- +2 + +# The aggregate needs to be computed because the input statistics are inexact. +query TT +EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 WHERE a > 3 LIMIT 3 OFFSET 6); +---- +logical_plan +Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] +--Projection: +----Limit: skip=6, fetch=3 +------Filter: t1.a > Int32(3) +--------TableScan: t1 projection=[a] +physical_plan +AggregateExec: mode=Final, gby=[], aggr=[COUNT(*)] +--CoalescePartitionsExec +----AggregateExec: mode=Partial, gby=[], aggr=[COUNT(*)] +------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------ProjectionExec: expr=[] +----------GlobalLimitExec: skip=6, fetch=3 +------------CoalesceBatchesExec: target_batch_size=8192 +--------------FilterExec: a@0 > 3 +----------------MemoryExec: partitions=1, partition_sizes=[1] + +query I +SELECT COUNT(*) FROM (SELECT a FROM t1 WHERE a > 3 LIMIT 3 OFFSET 6); +---- +1 + +# generate BIGINT data from 1 to 1000 in multiple partitions +statement ok +CREATE TABLE t1000 (i BIGINT) AS +WITH t AS (VALUES (0), (0), (0), (0), (0), (0), (0), (0), (0), (0)) +SELECT ROW_NUMBER() OVER (PARTITION BY t1.column1) FROM t t1, t t2, t t3; + +# verify that there are multiple partitions in the input (i.e. MemoryExec says +# there are 4 partitions) so that this tests multi-partition limit. +query TT +EXPLAIN SELECT DISTINCT i FROM t1000; +---- +logical_plan +Aggregate: groupBy=[[t1000.i]], aggr=[[]] +--TableScan: t1000 projection=[i] +physical_plan +AggregateExec: mode=FinalPartitioned, gby=[i@0 as i], aggr=[] +--CoalesceBatchesExec: target_batch_size=8192 +----RepartitionExec: partitioning=Hash([i@0], 4), input_partitions=4 +------AggregateExec: mode=Partial, gby=[i@0 as i], aggr=[] +--------MemoryExec: partitions=4, partition_sizes=[1, 1, 2, 1] + +query I +SELECT i FROM t1000 ORDER BY i DESC LIMIT 3; +---- +1000 +999 +998 + +query I +SELECT i FROM t1000 ORDER BY i LIMIT 3; +---- +1 +2 +3 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t1000 LIMIT 3); +---- +3 + +# limit_multi_partitions +statement ok +CREATE TABLE t15 (i BIGINT); + +query I +INSERT INTO t15 VALUES (1); +---- +1 + +query I +INSERT INTO t15 VALUES (1), (2); +---- +2 + +query I +INSERT INTO t15 VALUES (1), (2), (3); +---- +3 + +query I +INSERT INTO t15 VALUES (1), (2), (3), (4); +---- +4 + +query I +INSERT INTO t15 VALUES (1), (2), (3), (4), (5); +---- +5 + +query I +SELECT COUNT(*) FROM t15; +---- +15 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t15 LIMIT 1); +---- +1 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t15 LIMIT 2); +---- +2 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t15 LIMIT 3); +---- +3 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t15 LIMIT 4); +---- +4 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t15 LIMIT 5); +---- +5 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t15 LIMIT 6); +---- +6 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t15 LIMIT 7); +---- +7 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t15 LIMIT 8); +---- +8 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t15 LIMIT 9); +---- +9 + ######## # Clean up after the test ######## diff --git a/datafusion/sqllogictest/test_files/map.slt b/datafusion/sqllogictest/test_files/map.slt new file mode 100644 index 000000000000..7863bf445499 --- /dev/null +++ b/datafusion/sqllogictest/test_files/map.slt @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +statement ok +CREATE EXTERNAL TABLE data +STORED AS PARQUET +LOCATION '../core/tests/data/parquet_map.parquet'; + +query I +SELECT SUM(ints['bytes']) FROM data; +---- +5636785 + +query I +SELECT SUM(ints['bytes']) FROM data WHERE strings['method'] == 'GET'; +---- +649668 + +query TI +SELECT strings['method'] AS method, COUNT(*) as count FROM data GROUP BY method ORDER BY count DESC; +---- +POST 41 +HEAD 33 +PATCH 30 +OPTION 29 +GET 27 +PUT 25 +DELETE 24 + +query T +SELECT strings['not_found'] FROM data LIMIT 1; +---- + +statement ok +drop table data; + + +# Testing explain on a table with a map filter, registered in test_context.rs. +query TT +explain select * from table_with_map where int_field > 0; +---- +logical_plan +Filter: table_with_map.int_field > Int64(0) +--TableScan: table_with_map projection=[int_field, map_field] +physical_plan +CoalesceBatchesExec: target_batch_size=8192 +--FilterExec: int_field@0 > 0 +----MemoryExec: partitions=1, partition_sizes=[0] + +statement ok +drop table table_with_map; diff --git a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt index a3ee307f4940..0fa7ff9c2051 100644 --- a/datafusion/sqllogictest/test_files/math.slt +++ b/datafusion/sqllogictest/test_files/math.slt @@ -293,53 +293,52 @@ select c1*0, c2*0, c3*0, c4*0, c5*0, c6*0, c7*0, c8*0 from test_non_nullable_int ---- 0 0 0 0 0 0 0 0 -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c1/0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c2/0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c3/0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c4/0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c5/0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c6/0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c7/0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c8/0 FROM test_non_nullable_integer - -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c1%0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c2%0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c3%0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c4%0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c5%0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c6%0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c7%0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c8%0 FROM test_non_nullable_integer statement ok @@ -395,7 +394,7 @@ NaN NaN # abs: return type query TT rowsort -SELECT arrow_typeof(c1), arrow_typeof(c2) FROM test_nullable_float limit 1 +SELECT arrow_typeof(abs(c1)), arrow_typeof(abs(c2)) FROM test_nullable_float limit 1 ---- Float32 Float64 @@ -466,34 +465,48 @@ drop table test_non_nullable_float statement ok CREATE TABLE test_nullable_decimal( - c1 DECIMAL(10, 2), - c2 DECIMAL(38, 10) - ) AS VALUES (0, 0), (NULL, NULL); - -query RR + c1 DECIMAL(10, 2), /* Decimal128 */ + c2 DECIMAL(38, 10), /* Decimal128 with max precision */ + c3 DECIMAL(40, 2), /* Decimal256 */ + c4 DECIMAL(76, 10) /* Decimal256 with max precision */ + ) AS VALUES + (0, 0, 0, 0), + (NULL, NULL, NULL, NULL); + +query RRRR INSERT into test_nullable_decimal values - (-99999999.99, '-9999999999999999999999999999.9999999999'), - (99999999.99, '9999999999999999999999999999.9999999999'); + ( + -99999999.99, + '-9999999999999999999999999999.9999999999', + '-99999999999999999999999999999999999999.99', + '-999999999999999999999999999999999999999999999999999999999999999999.9999999999' + ), + ( + 99999999.99, + '9999999999999999999999999999.9999999999', + '99999999999999999999999999999999999999.99', + '999999999999999999999999999999999999999999999999999999999999999999.9999999999' + ) ---- 2 -query R rowsort +query R SELECT c1*0 FROM test_nullable_decimal WHERE c1 IS NULL; ---- NULL -query R rowsort +query R SELECT c1/0 FROM test_nullable_decimal WHERE c1 IS NULL; ---- NULL -query R rowsort +query R SELECT c1%0 FROM test_nullable_decimal WHERE c1 IS NULL; ---- NULL -query R rowsort +query R SELECT c1*0 FROM test_nullable_decimal WHERE c1 IS NOT NULL; ---- 0 @@ -507,19 +520,24 @@ query error DataFusion error: Arrow error: Divide by zero error SELECT c1%0 FROM test_nullable_decimal WHERE c1 IS NOT NULL; # abs: return type -query TT rowsort -SELECT arrow_typeof(c1), arrow_typeof(c2) FROM test_nullable_decimal limit 1 +query TTTT +SELECT + arrow_typeof(abs(c1)), + arrow_typeof(abs(c2)), + arrow_typeof(abs(c3)), + arrow_typeof(abs(c4)) +FROM test_nullable_decimal limit 1 ---- -Decimal128(10, 2) Decimal128(38, 10) +Decimal128(10, 2) Decimal128(38, 10) Decimal256(40, 2) Decimal256(76, 10) -# abs: Decimal128 -query RR rowsort -SELECT abs(c1), abs(c2) FROM test_nullable_decimal +# abs: decimals +query RRRR rowsort +SELECT abs(c1), abs(c2), abs(c3), abs(c4) FROM test_nullable_decimal ---- -0 0 -99999999.99 9999999999999999999999999999.9999999999 -99999999.99 9999999999999999999999999999.9999999999 -NULL NULL +0 0 0 0 +99999999.99 9999999999999999999999999999.9999999999 99999999999999999999999999999999999999.99 999999999999999999999999999999999999999999999999999999999999999999.9999999999 +99999999.99 9999999999999999999999999999.9999999999 99999999999999999999999999999999999999.99 999999999999999999999999999999999999999999999999999999999999999999.9999999999 +NULL NULL NULL NULL statement ok drop table test_nullable_decimal @@ -538,10 +556,10 @@ SELECT c1*0 FROM test_non_nullable_decimal ---- 0 -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c1/0 FROM test_non_nullable_decimal -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c1%0 FROM test_non_nullable_decimal statement ok diff --git a/datafusion/sqllogictest/test_files/metadata.slt b/datafusion/sqllogictest/test_files/metadata.slt new file mode 100644 index 000000000000..3b2b219244f5 --- /dev/null +++ b/datafusion/sqllogictest/test_files/metadata.slt @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +## Tests for tables that has both metadata on each field as well as metadata on +## the schema itself. +########## + +## Note that table_with_metadata is defined using Rust code +## in the test harness as there is no way to define schema +## with metadata in SQL. + +query IT +select * from table_with_metadata; +---- +1 NULL +NULL bar +3 baz + +query I rowsort +SELECT ( + SELECT id FROM table_with_metadata + ) UNION ( + SELECT id FROM table_with_metadata + ); +---- +1 +3 +NULL + +query I rowsort +SELECT "data"."id" +FROM + ( + (SELECT "id" FROM "table_with_metadata") + UNION + (SELECT "id" FROM "table_with_metadata") + ) as "data", + ( + SELECT "id" FROM "table_with_metadata" + ) as "samples" +WHERE "data"."id" = "samples"."id"; +---- +1 +3 + +statement ok +drop table table_with_metadata; diff --git a/datafusion/sqllogictest/test_files/options.slt b/datafusion/sqllogictest/test_files/options.slt index 5fbb2102f4bf..9366a9b3b3c8 100644 --- a/datafusion/sqllogictest/test_files/options.slt +++ b/datafusion/sqllogictest/test_files/options.slt @@ -33,7 +33,7 @@ Filter: a.c0 < Int32(1) physical_plan CoalesceBatchesExec: target_batch_size=8192 --FilterExec: c0@0 < 1 -----MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +----MemoryExec: partitions=1, partition_sizes=[1] ## # test_disable_coalesce @@ -51,7 +51,7 @@ Filter: a.c0 < Int32(1) --TableScan: a projection=[c0] physical_plan FilterExec: c0@0 < 1 ---MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +--MemoryExec: partitions=1, partition_sizes=[1] statement ok set datafusion.execution.coalesce_batches = true @@ -74,7 +74,7 @@ Filter: a.c0 < Int32(1) physical_plan CoalesceBatchesExec: target_batch_size=1234 --FilterExec: c0@0 < 1 -----MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +----MemoryExec: partitions=1, partition_sizes=[1] statement ok @@ -84,7 +84,7 @@ statement ok drop table a # test datafusion.sql_parser.parse_float_as_decimal -# +# # default option value is false query RR select 10000000000000000000.01, -10000000000000000000.01 @@ -209,5 +209,3 @@ select -123456789.0123456789012345678901234567890 # Restore option to default value statement ok set datafusion.sql_parser.parse_float_as_decimal = false; - - diff --git a/datafusion/sqllogictest/test_files/order.slt b/datafusion/sqllogictest/test_files/order.slt index 8148f1c4c7c9..77df9e0bb493 100644 --- a/datafusion/sqllogictest/test_files/order.slt +++ b/datafusion/sqllogictest/test_files/order.slt @@ -441,13 +441,13 @@ physical_plan SortPreservingMergeExec: [result@0 ASC NULLS LAST] --ProjectionExec: expr=[b@1 + a@0 + c@2 as result] ----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], output_ordering=[a@0 ASC NULLS LAST], has_header=true +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], output_orderings=[[a@0 ASC NULLS LAST], [b@1 ASC NULLS LAST], [c@2 ASC NULLS LAST]], has_header=true statement ok drop table multiple_ordered_table; # Create tables having some ordered columns. In the next step, we will expect to observe that scalar -# functions, such as mathematical functions like atan(), ceil(), sqrt(), or date_time functions +# functions, such as mathematical functions like atan(), ceil(), sqrt(), or date_time functions # like date_bin() and date_trunc(), will maintain the order of its argument columns. statement ok CREATE EXTERNAL TABLE csv_with_timestamps ( @@ -559,7 +559,7 @@ physical_plan SortPreservingMergeExec: [log_c11_base_c12@0 ASC NULLS LAST] --ProjectionExec: expr=[log(CAST(c11@0 AS Float64), c12@1) as log_c11_base_c12] ----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c11, c12], output_ordering=[c11@0 ASC NULLS LAST], has_header=true +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c11, c12], output_orderings=[[c11@0 ASC NULLS LAST], [c12@1 DESC]], has_header=true query TT EXPLAIN SELECT LOG(c12, c11) as log_c12_base_c11 @@ -574,7 +574,7 @@ physical_plan SortPreservingMergeExec: [log_c12_base_c11@0 DESC] --ProjectionExec: expr=[log(c12@1, CAST(c11@0 AS Float64)) as log_c12_base_c11] ----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c11, c12], output_ordering=[c11@0 ASC NULLS LAST], has_header=true +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c11, c12], output_orderings=[[c11@0 ASC NULLS LAST], [c12@1 DESC]], has_header=true statement ok drop table aggregate_test_100; diff --git a/datafusion/sqllogictest/test_files/parquet.slt b/datafusion/sqllogictest/test_files/parquet.slt new file mode 100644 index 000000000000..0f26c14f0017 --- /dev/null +++ b/datafusion/sqllogictest/test_files/parquet.slt @@ -0,0 +1,357 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# TESTS FOR PARQUET FILES + +# Set 2 partitions for deterministic output plans +statement ok +set datafusion.execution.target_partitions = 2; + +# Create a table as a data source +statement ok +CREATE TABLE src_table ( + int_col INT, + string_col TEXT, + bigint_col BIGINT, + date_col DATE +) AS VALUES +(1, 'aaa', 100, 1), +(2, 'bbb', 200, 2), +(3, 'ccc', 300, 3), +(4, 'ddd', 400, 4), +(5, 'eee', 500, 5), +(6, 'fff', 600, 6), +(7, 'ggg', 700, 7), +(8, 'hhh', 800, 8), +(9, 'iii', 900, 9); + +# Setup 2 files, i.e., as many as there are partitions: + +# File 1: +query ITID +COPY (SELECT * FROM src_table LIMIT 3) +TO 'test_files/scratch/parquet/test_table/0.parquet' +(FORMAT PARQUET, SINGLE_FILE_OUTPUT true); +---- +3 + +# File 2: +query ITID +COPY (SELECT * FROM src_table WHERE int_col > 3 LIMIT 3) +TO 'test_files/scratch/parquet/test_table/1.parquet' +(FORMAT PARQUET, SINGLE_FILE_OUTPUT true); +---- +3 + +# Create a table from generated parquet files, without ordering: +statement ok +CREATE EXTERNAL TABLE test_table ( + int_col INT, + string_col TEXT, + bigint_col BIGINT, + date_col DATE +) +STORED AS PARQUET +WITH HEADER ROW +LOCATION 'test_files/scratch/parquet/test_table'; + +# Basic query: +query ITID +SELECT * FROM test_table ORDER BY int_col; +---- +1 aaa 100 1970-01-02 +2 bbb 200 1970-01-03 +3 ccc 300 1970-01-04 +4 ddd 400 1970-01-05 +5 eee 500 1970-01-06 +6 fff 600 1970-01-07 + +# Check output plan, expect no "output_ordering" clause in the physical_plan -> ParquetExec: +query TT +EXPLAIN SELECT int_col, string_col +FROM test_table +ORDER BY string_col, int_col; +---- +logical_plan +Sort: test_table.string_col ASC NULLS LAST, test_table.int_col ASC NULLS LAST +--TableScan: test_table projection=[int_col, string_col] +physical_plan +SortPreservingMergeExec: [string_col@1 ASC NULLS LAST,int_col@0 ASC NULLS LAST] +--SortExec: expr=[string_col@1 ASC NULLS LAST,int_col@0 ASC NULLS LAST] +----ParquetExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/0.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/1.parquet]]}, projection=[int_col, string_col] + +# Tear down test_table: +statement ok +DROP TABLE test_table; + +# Create test_table again, but with ordering: +statement ok +CREATE EXTERNAL TABLE test_table ( + int_col INT, + string_col TEXT, + bigint_col BIGINT, + date_col DATE +) +STORED AS PARQUET +WITH HEADER ROW +WITH ORDER (string_col ASC NULLS LAST, int_col ASC NULLS LAST) +LOCATION 'test_files/scratch/parquet/test_table'; + +# Check output plan, expect an "output_ordering" clause in the physical_plan -> ParquetExec: +query TT +EXPLAIN SELECT int_col, string_col +FROM test_table +ORDER BY string_col, int_col; +---- +logical_plan +Sort: test_table.string_col ASC NULLS LAST, test_table.int_col ASC NULLS LAST +--TableScan: test_table projection=[int_col, string_col] +physical_plan +SortPreservingMergeExec: [string_col@1 ASC NULLS LAST,int_col@0 ASC NULLS LAST] +--ParquetExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/0.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/1.parquet]]}, projection=[int_col, string_col], output_ordering=[string_col@1 ASC NULLS LAST, int_col@0 ASC NULLS LAST] + +# Add another file to the directory underlying test_table +query ITID +COPY (SELECT * FROM src_table WHERE int_col > 6 LIMIT 3) +TO 'test_files/scratch/parquet/test_table/2.parquet' +(FORMAT PARQUET, SINGLE_FILE_OUTPUT true); +---- +3 + +# Check output plan again, expect no "output_ordering" clause in the physical_plan -> ParquetExec, +# due to there being more files than partitions: +query TT +EXPLAIN SELECT int_col, string_col +FROM test_table +ORDER BY string_col, int_col; +---- +logical_plan +Sort: test_table.string_col ASC NULLS LAST, test_table.int_col ASC NULLS LAST +--TableScan: test_table projection=[int_col, string_col] +physical_plan +SortPreservingMergeExec: [string_col@1 ASC NULLS LAST,int_col@0 ASC NULLS LAST] +--SortExec: expr=[string_col@1 ASC NULLS LAST,int_col@0 ASC NULLS LAST] +----ParquetExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/0.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/2.parquet]]}, projection=[int_col, string_col] + + +# Perform queries using MIN and MAX +query I +SELECT max(int_col) FROM test_table; +---- +9 + +query T +SELECT min(string_col) FROM test_table; +---- +aaa + +query I +SELECT max(bigint_col) FROM test_table; +---- +900 + +query D +SELECT min(date_col) FROM test_table; +---- +1970-01-02 + +# Clean up +statement ok +DROP TABLE test_table; + +# Setup alltypes_plain table: +statement ok +CREATE EXTERNAL TABLE alltypes_plain ( + id INT NOT NULL, + bool_col BOOLEAN NOT NULL, + tinyint_col TINYINT NOT NULL, + smallint_col SMALLINT NOT NULL, + int_col INT NOT NULL, + bigint_col BIGINT NOT NULL, + float_col FLOAT NOT NULL, + double_col DOUBLE NOT NULL, + date_string_col BYTEA NOT NULL, + string_col VARCHAR NOT NULL, + timestamp_col TIMESTAMP NOT NULL, +) +STORED AS PARQUET +WITH HEADER ROW +LOCATION '../../parquet-testing/data/alltypes_plain.parquet' + +# Test a basic query with a CAST: +query IT +SELECT id, CAST(string_col AS varchar) FROM alltypes_plain +---- +4 0 +5 1 +6 0 +7 1 +2 0 +3 1 +0 0 +1 1 + +# Clean up +statement ok +DROP TABLE alltypes_plain; + +# Perform SELECT on table with fixed sized binary columns + +statement ok +CREATE EXTERNAL TABLE test_binary +STORED AS PARQUET +WITH HEADER ROW +LOCATION '../core/tests/data/test_binary.parquet'; + +# Check size of table: +query I +SELECT count(ids) FROM test_binary; +---- +466 + +# Do the SELECT query: +query ? +SELECT ids FROM test_binary ORDER BY ids LIMIT 10; +---- +008c7196f68089ab692e4739c5fd16b5 +00a51a7bc5ff8eb1627f8f3dc959dce8 +0166ce1d46129ad104fa4990c6057c91 +03a4893f3285b422820b4cd74c9b9786 +04999ac861e14682cd339eae2cc74359 +04b86bf8f228739fde391f850636a77d +050fb9cf722a709eb94b70b3ee7dc342 +052578a65e8e91b8526b182d40e846e8 +05408e6a403e4296526006e20cc4a45a +0592e6fb7d7169b888a4029b53abb701 + +# Clean up +statement ok +DROP TABLE test_binary; + +# Perform a query with a window function and timestamp data: + +statement ok +CREATE EXTERNAL TABLE timestamp_with_tz +STORED AS PARQUET +WITH HEADER ROW +LOCATION '../core/tests/data/timestamp_with_tz.parquet'; + +# Check size of table: +query I +SELECT COUNT(*) FROM timestamp_with_tz; +---- +131072 + +# Perform the query: +query IPT +SELECT + count, + LAG(timestamp, 1) OVER (ORDER BY timestamp), + arrow_typeof(LAG(timestamp, 1) OVER (ORDER BY timestamp)) +FROM timestamp_with_tz +LIMIT 10; +---- +0 NULL Timestamp(Millisecond, Some("UTC")) +0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) +0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) +4 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) +0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) +0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) +0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) +14 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) +0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) +0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) + +# Test config listing_table_ignore_subdirectory: + +query ITID +COPY (SELECT * FROM src_table WHERE int_col > 6 LIMIT 3) +TO 'test_files/scratch/parquet/test_table/subdir/3.parquet' +(FORMAT PARQUET, SINGLE_FILE_OUTPUT true); +---- +3 + +statement ok +CREATE EXTERNAL TABLE listing_table +STORED AS PARQUET +WITH HEADER ROW +LOCATION 'test_files/scratch/parquet/test_table/*.parquet'; + +statement ok +set datafusion.execution.listing_table_ignore_subdirectory = true; + +# scan file: 0.parquet 1.parquet 2.parquet +query I +select count(*) from listing_table; +---- +9 + +statement ok +set datafusion.execution.listing_table_ignore_subdirectory = false; + +# scan file: 0.parquet 1.parquet 2.parquet 3.parquet +query I +select count(*) from listing_table; +---- +12 + +# Clean up +statement ok +DROP TABLE timestamp_with_tz; + +# Test a query from the single_nan data set: +statement ok +CREATE EXTERNAL TABLE single_nan +STORED AS PARQUET +WITH HEADER ROW +LOCATION '../../parquet-testing/data/single_nan.parquet'; + +# Check table size: +query I +SELECT COUNT(*) FROM single_nan; +---- +1 + +# Query for the single NULL: +query R +SELECT mycol FROM single_nan; +---- +NULL + +# Clean up +statement ok +DROP TABLE single_nan; + +statement ok +CREATE EXTERNAL TABLE list_columns +STORED AS PARQUET +WITH HEADER ROW +LOCATION '../../parquet-testing/data/list_columns.parquet'; + +query ?? +SELECT int64_list, utf8_list FROM list_columns +---- +[1, 2, 3] [abc, efg, hij] +[, 1] NULL +[4] [efg, , hij, xyz] + +statement ok +DROP TABLE list_columns; + +# Clean up +statement ok +DROP TABLE listing_table; diff --git a/datafusion/sqllogictest/test_files/predicates.slt b/datafusion/sqllogictest/test_files/predicates.slt index 937b4c2eccf6..e992a440d0a2 100644 --- a/datafusion/sqllogictest/test_files/predicates.slt +++ b/datafusion/sqllogictest/test_files/predicates.slt @@ -480,3 +480,44 @@ select * from t where (i & 3) = 1; ######## statement ok DROP TABLE t; + + +######## +# Test query with bloom filter +# Refer to https://github.com/apache/arrow-datafusion/pull/7821#pullrequestreview-1688062599 +######## + +statement ok +CREATE EXTERNAL TABLE data_index_bloom_encoding_stats STORED AS PARQUET LOCATION '../../parquet-testing/data/data_index_bloom_encoding_stats.parquet'; + +statement ok +set datafusion.execution.parquet.bloom_filter_enabled=true; + +query T +SELECT * FROM data_index_bloom_encoding_stats WHERE "String" = 'foo'; +---- + +query T +SELECT * FROM data_index_bloom_encoding_stats WHERE "String" = 'test'; +---- +test + +query T +SELECT * FROM data_index_bloom_encoding_stats WHERE "String" like '%e%'; +---- +Hello +test +are you +the quick +over +the lazy + +statement ok +set datafusion.execution.parquet.bloom_filter_enabled=false; + + +######## +# Clean up after the test +######## +statement ok +DROP TABLE data_index_bloom_encoding_stats; diff --git a/datafusion/sqllogictest/test_files/projection.slt b/datafusion/sqllogictest/test_files/projection.slt new file mode 100644 index 000000000000..b752f5644b7f --- /dev/null +++ b/datafusion/sqllogictest/test_files/projection.slt @@ -0,0 +1,235 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +## Projection Statement Tests +########## + +# prepare data +statement ok +CREATE EXTERNAL TABLE aggregate_test_100 ( + c1 VARCHAR NOT NULL, + c2 TINYINT NOT NULL, + c3 SMALLINT NOT NULL, + c4 SMALLINT, + c5 INT, + c6 BIGINT NOT NULL, + c7 SMALLINT NOT NULL, + c8 INT NOT NULL, + c9 BIGINT UNSIGNED NOT NULL, + c10 VARCHAR NOT NULL, + c11 FLOAT NOT NULL, + c12 DOUBLE NOT NULL, + c13 VARCHAR NOT NULL +) +STORED AS CSV +WITH HEADER ROW +LOCATION '../../testing/data/csv/aggregate_test_100.csv' + +statement ok +CREATE EXTERNAL TABLE aggregate_simple ( + c1 FLOAT NOT NULL, + c2 DOUBLE NOT NULL, + c3 BOOLEAN NOT NULL +) +STORED AS CSV +WITH HEADER ROW +LOCATION '../core/tests/data/aggregate_simple.csv' + +statement ok +CREATE TABLE memory_table(a INT NOT NULL, b INT NOT NULL, c INT NOT NULL) AS VALUES +(1, 2, 3), +(10, 12, 12), +(10, 12, 12), +(100, 120, 120); + +statement ok +CREATE TABLE cpu_load_short(host STRING NOT NULL) AS VALUES +('host1'), +('host2'); + +statement ok +CREATE EXTERNAL TABLE test (c1 int, c2 bigint, c3 boolean) +STORED AS CSV LOCATION '../core/tests/data/partitioned_csv'; + +statement ok +CREATE EXTERNAL TABLE test_simple (c1 int, c2 bigint, c3 boolean) +STORED AS CSV LOCATION '../core/tests/data/partitioned_csv/partition-0.csv'; + +# projection same fields +query I rowsort +select (1+1) as a from (select 1 as a) as b; +---- +2 + +# projection type alias +query R rowsort +SELECT c1 as c3 FROM aggregate_simple ORDER BY c3 LIMIT 2; +---- +0.00001 +0.00002 + +# csv query group by avg with projection +query RT rowsort +SELECT avg(c12), c1 FROM aggregate_test_100 GROUP BY c1; +---- +0.410407092638 b +0.486006692713 e +0.487545174661 a +0.488553793875 d +0.660045653644 c + +# parallel projection +query II +SELECT c1, c2 FROM test ORDER BY c1 DESC, c2 ASC +---- +3 0 +3 1 +3 2 +3 3 +3 4 +3 5 +3 6 +3 7 +3 8 +3 9 +3 10 +2 0 +2 1 +2 2 +2 3 +2 4 +2 5 +2 6 +2 7 +2 8 +2 9 +2 10 +1 0 +1 1 +1 2 +1 3 +1 4 +1 5 +1 6 +1 7 +1 8 +1 9 +1 10 +0 0 +0 1 +0 2 +0 3 +0 4 +0 5 +0 6 +0 7 +0 8 +0 9 +0 10 + +# subquery alias case insensitive +query II +SELECT V1.c1, v1.C2 FROM (SELECT test_simple.C1, TEST_SIMPLE.c2 FROM test_simple) V1 ORDER BY v1.c1, V1.C2 LIMIT 1; +---- +0 0 + +# projection on table scan +statement ok +set datafusion.explain.logical_plan_only = true + +query TT +EXPLAIN SELECT c2 FROM test; +---- +logical_plan TableScan: test projection=[c2] + +statement count 44 +select c2 from test; + +statement ok +set datafusion.explain.logical_plan_only = false + +# project cast dictionary +query T +SELECT + CASE + WHEN cpu_load_short.host IS NULL THEN '' + ELSE cpu_load_short.host + END AS host +FROM + cpu_load_short; +---- +host1 +host2 + +# projection on memory scan +query TT +explain select b from memory_table; +---- +logical_plan TableScan: memory_table projection=[b] +physical_plan MemoryExec: partitions=1, partition_sizes=[1] + +query I +select b from memory_table; +---- +2 +12 +12 +120 + +# project column with same name as relation +query I +select a.a from (select 1 as a) as a; +---- +1 + +# project column with filters that cant pushed down always false +query I +select * from (select 1 as a) f where f.a=2; +---- + + +# project column with filters that cant pushed down always true +query I +select * from (select 1 as a) f where f.a=1; +---- +1 + +# project columns in memory without propagation +query I +SELECT column1 as a from (values (1), (2)) f where f.column1 = 2; +---- +2 + +# clean data +statement ok +DROP TABLE aggregate_simple; + +statement ok +DROP TABLE aggregate_test_100; + +statement ok +DROP TABLE memory_table; + +statement ok +DROP TABLE cpu_load_short; + +statement ok +DROP TABLE test; + +statement ok +DROP TABLE test_simple; diff --git a/datafusion/sqllogictest/test_files/repartition_scan.slt b/datafusion/sqllogictest/test_files/repartition_scan.slt new file mode 100644 index 000000000000..9d4951c7ecac --- /dev/null +++ b/datafusion/sqllogictest/test_files/repartition_scan.slt @@ -0,0 +1,284 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +# Tests for automatically reading files in parallel during scan +########## + +# Set 4 partitions for deterministic output plans +statement ok +set datafusion.execution.target_partitions = 4; + +# automatically partition all files over 1 byte +statement ok +set datafusion.optimizer.repartition_file_min_size = 1; + +################### +### Parquet tests +################### + +# create a single parquet file +# Note filename 2.parquet to test sorting (on local file systems it is often listed before 1.parquet) +statement ok +COPY (VALUES (1), (2), (3), (4), (5)) TO 'test_files/scratch/repartition_scan/parquet_table/2.parquet' +(FORMAT PARQUET, SINGLE_FILE_OUTPUT true); + +statement ok +CREATE EXTERNAL TABLE parquet_table(column1 int) +STORED AS PARQUET +LOCATION 'test_files/scratch/repartition_scan/parquet_table/'; + +query I +select * from parquet_table; +---- +1 +2 +3 +4 +5 + +## Expect to see the scan read the file as "4" groups with even sizes (offsets) +query TT +EXPLAIN SELECT column1 FROM parquet_table WHERE column1 <> 42; +---- +logical_plan +Filter: parquet_table.column1 != Int32(42) +--TableScan: parquet_table projection=[column1], partial_filters=[parquet_table.column1 != Int32(42)] +physical_plan +CoalesceBatchesExec: target_batch_size=8192 +--FilterExec: column1@0 != 42 +----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..101], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:101..202], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:202..303], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:303..403]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1 + +# disable round robin repartitioning +statement ok +set datafusion.optimizer.enable_round_robin_repartition = false; + +## Expect to see the scan read the file as "4" groups with even sizes (offsets) again +query TT +EXPLAIN SELECT column1 FROM parquet_table WHERE column1 <> 42; +---- +logical_plan +Filter: parquet_table.column1 != Int32(42) +--TableScan: parquet_table projection=[column1], partial_filters=[parquet_table.column1 != Int32(42)] +physical_plan +CoalesceBatchesExec: target_batch_size=8192 +--FilterExec: column1@0 != 42 +----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..101], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:101..202], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:202..303], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:303..403]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1 + +# enable round robin repartitioning again +statement ok +set datafusion.optimizer.enable_round_robin_repartition = true; + +# create a second parquet file +statement ok +COPY (VALUES (100), (200)) TO 'test_files/scratch/repartition_scan/parquet_table/1.parquet' +(FORMAT PARQUET, SINGLE_FILE_OUTPUT true); + +## Still expect to see the scan read the file as "4" groups with even sizes. One group should read +## parts of both files. +query TT +EXPLAIN SELECT column1 FROM parquet_table WHERE column1 <> 42 ORDER BY column1; +---- +logical_plan +Sort: parquet_table.column1 ASC NULLS LAST +--Filter: parquet_table.column1 != Int32(42) +----TableScan: parquet_table projection=[column1], partial_filters=[parquet_table.column1 != Int32(42)] +physical_plan +SortPreservingMergeExec: [column1@0 ASC NULLS LAST] +--SortExec: expr=[column1@0 ASC NULLS LAST] +----CoalesceBatchesExec: target_batch_size=8192 +------FilterExec: column1@0 != 42 +--------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..200], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:200..394, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..6], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:6..206], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:206..403]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1 + + +## Read the files as though they are ordered + +statement ok +CREATE EXTERNAL TABLE parquet_table_with_order(column1 int) +STORED AS PARQUET +LOCATION 'test_files/scratch/repartition_scan/parquet_table' +WITH ORDER (column1 ASC); + +# output should be ordered +query I +SELECT column1 FROM parquet_table_with_order WHERE column1 <> 42 ORDER BY column1; +---- +1 +2 +3 +4 +5 +100 +200 + +# explain should not have any groups with more than one file +# https://github.com/apache/arrow-datafusion/issues/8451 +query TT +EXPLAIN SELECT column1 FROM parquet_table_with_order WHERE column1 <> 42 ORDER BY column1; +---- +logical_plan +Sort: parquet_table_with_order.column1 ASC NULLS LAST +--Filter: parquet_table_with_order.column1 != Int32(42) +----TableScan: parquet_table_with_order projection=[column1], partial_filters=[parquet_table_with_order.column1 != Int32(42)] +physical_plan +SortPreservingMergeExec: [column1@0 ASC NULLS LAST] +--CoalesceBatchesExec: target_batch_size=8192 +----FilterExec: column1@0 != 42 +------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..197], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..201], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:201..403], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:197..394]]}, projection=[column1], output_ordering=[column1@0 ASC NULLS LAST], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1 + +# Cleanup +statement ok +DROP TABLE parquet_table; + +statement ok +DROP TABLE parquet_table_with_order; + + +################### +### CSV tests +################### + +# Since parquet and CSV share most of the same implementation, this test checks +# that the basics are connected properly + +# create a single csv file +statement ok +COPY (VALUES (1), (2), (3), (4), (5)) TO 'test_files/scratch/repartition_scan/csv_table/1.csv' +(FORMAT csv, SINGLE_FILE_OUTPUT true, HEADER true); + +statement ok +CREATE EXTERNAL TABLE csv_table(column1 int) +STORED AS csv +WITH HEADER ROW +LOCATION 'test_files/scratch/repartition_scan/csv_table/'; + +query I +select * from csv_table ORDER BY column1; +---- +1 +2 +3 +4 +5 + +## Expect to see the scan read the file as "4" groups with even sizes (offsets) +query TT +EXPLAIN SELECT column1 FROM csv_table WHERE column1 <> 42; +---- +logical_plan +Filter: csv_table.column1 != Int32(42) +--TableScan: csv_table projection=[column1], partial_filters=[csv_table.column1 != Int32(42)] +physical_plan +CoalesceBatchesExec: target_batch_size=8192 +--FilterExec: column1@0 != 42 +----CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/csv_table/1.csv:0..5], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/csv_table/1.csv:5..10], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/csv_table/1.csv:10..15], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/csv_table/1.csv:15..18]]}, projection=[column1], has_header=true + +# Cleanup +statement ok +DROP TABLE csv_table; + + +################### +### JSON tests +################### + +# Since parquet and json share most of the same implementation, this test checks +# that the basics are connected properly + +# create a single json file +statement ok +COPY (VALUES (1), (2), (3), (4), (5)) TO 'test_files/scratch/repartition_scan/json_table/1.json' +(FORMAT json, SINGLE_FILE_OUTPUT true); + +statement ok +CREATE EXTERNAL TABLE json_table (column1 int) +STORED AS json +LOCATION 'test_files/scratch/repartition_scan/json_table/'; + +query I +select * from "json_table" ORDER BY column1; +---- +1 +2 +3 +4 +5 + +## Expect to see the scan read the file as "4" groups with even sizes (offsets) +query TT +EXPLAIN SELECT column1 FROM "json_table" WHERE column1 <> 42; +---- +logical_plan +Filter: json_table.column1 != Int32(42) +--TableScan: json_table projection=[column1], partial_filters=[json_table.column1 != Int32(42)] +physical_plan +CoalesceBatchesExec: target_batch_size=8192 +--FilterExec: column1@0 != 42 +----JsonExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/json_table/1.json:0..18], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/json_table/1.json:18..36], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/json_table/1.json:36..54], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/json_table/1.json:54..70]]}, projection=[column1] + +# Cleanup +statement ok +DROP TABLE json_table; + + +################### +### Arrow File tests +################### + +## Use pre-existing files we don't have a way to create arrow files yet +## (https://github.com/apache/arrow-datafusion/issues/8504) +statement ok +CREATE EXTERNAL TABLE arrow_table +STORED AS ARROW +LOCATION '../core/tests/data/example.arrow'; + + +# It would be great to see the file read as "4" groups with even sizes (offsets) eventually +# https://github.com/apache/arrow-datafusion/issues/8503 +query TT +EXPLAIN SELECT * FROM arrow_table +---- +logical_plan TableScan: arrow_table projection=[f0, f1, f2] +physical_plan ArrowExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.arrow]]}, projection=[f0, f1, f2] + +# Cleanup +statement ok +DROP TABLE arrow_table; + +################### +### Avro File tests +################### + +## Use pre-existing files we don't have a way to create avro files yet + +statement ok +CREATE EXTERNAL TABLE avro_table +STORED AS AVRO +WITH HEADER ROW +LOCATION '../../testing/data/avro/simple_enum.avro' + + +# It would be great to see the file read as "4" groups with even sizes (offsets) eventually +query TT +EXPLAIN SELECT * FROM avro_table +---- +logical_plan TableScan: avro_table projection=[f1, f2, f3] +physical_plan AvroExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/avro/simple_enum.avro]]}, projection=[f1, f2, f3] + +# Cleanup +statement ok +DROP TABLE avro_table; diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index e5c1a828492a..9b30699e3fa3 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -1878,3 +1878,78 @@ query T SELECT CONCAT('Hello', 'World') ---- HelloWorld + +statement ok +CREATE TABLE simple_string( + letter STRING, + letter2 STRING +) as VALUES + ('A', 'APACHE'), + ('B', 'APACHE'), + ('C', 'APACHE'), + ('D', 'APACHE') +; + +query TT +EXPLAIN SELECT letter, letter = LEFT('APACHE', 1) FROM simple_string; +---- +logical_plan +Projection: simple_string.letter, simple_string.letter = Utf8("A") AS simple_string.letter = left(Utf8("APACHE"),Int64(1)) +--TableScan: simple_string projection=[letter] +physical_plan +ProjectionExec: expr=[letter@0 as letter, letter@0 = A as simple_string.letter = left(Utf8("APACHE"),Int64(1))] +--MemoryExec: partitions=1, partition_sizes=[1] + +query TB +SELECT letter, letter = LEFT('APACHE', 1) FROM simple_string; + ---- +---- +A true +B false +C false +D false + +query TT +EXPLAIN SELECT letter, letter = LEFT(letter2, 1) FROM simple_string; +---- +logical_plan +Projection: simple_string.letter, simple_string.letter = left(simple_string.letter2, Int64(1)) +--TableScan: simple_string projection=[letter, letter2] +physical_plan +ProjectionExec: expr=[letter@0 as letter, letter@0 = left(letter2@1, 1) as simple_string.letter = left(simple_string.letter2,Int64(1))] +--MemoryExec: partitions=1, partition_sizes=[1] + +query TB +SELECT letter, letter = LEFT(letter2, 1) FROM simple_string; +---- +A true +B false +C false +D false + +# test string_temporal_coercion +query BBBBBBBBBB +select + arrow_cast(to_timestamp('2020-01-01 01:01:11.1234567890Z'), 'Timestamp(Second, None)') == '2020-01-01T01:01:11', + arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Second, None)') == arrow_cast('2020-01-02T01:01:11', 'LargeUtf8'), + arrow_cast(to_timestamp('2020-01-03 01:01:11.1234567890Z'), 'Time32(Second)') == '01:01:11', + arrow_cast(to_timestamp('2020-01-04 01:01:11.1234567890Z'), 'Time32(Second)') == arrow_cast('01:01:11', 'LargeUtf8'), + arrow_cast(to_timestamp('2020-01-05 01:01:11.1234567890Z'), 'Time64(Microsecond)') == '01:01:11.123456', + arrow_cast(to_timestamp('2020-01-06 01:01:11.1234567890Z'), 'Time64(Microsecond)') == arrow_cast('01:01:11.123456', 'LargeUtf8'), + arrow_cast('2020-01-07', 'Date32') == '2020-01-07', + arrow_cast('2020-01-08', 'Date64') == '2020-01-08', + arrow_cast('2020-01-09', 'Date32') == arrow_cast('2020-01-09', 'LargeUtf8'), + arrow_cast('2020-01-10', 'Date64') == arrow_cast('2020-01-10', 'LargeUtf8') +; +---- +true true true true true true true true true true + +query I +SELECT ALL - CASE WHEN NOT - AVG ( - 41 ) IS NULL THEN 47 WHEN NULL IS NULL THEN COUNT ( * ) END + 93 + - - 44 * 91 + CASE + 44 WHEN - - 21 * 69 - 12 THEN 58 ELSE - 3 END * + + 23 * + 84 * - - 59 +---- +-337914 + +query T +SELECT CASE 3 WHEN 1+2 THEN 'first' WHEN 1+1+1 THEN 'second' END +---- +first diff --git a/datafusion/sqllogictest/test_files/schema_evolution.slt b/datafusion/sqllogictest/test_files/schema_evolution.slt new file mode 100644 index 000000000000..36d54159e24d --- /dev/null +++ b/datafusion/sqllogictest/test_files/schema_evolution.slt @@ -0,0 +1,140 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +# Tests for schema evolution -- reading +# data from different files with different schemas +########## + + +statement ok +CREATE EXTERNAL TABLE parquet_table(a varchar, b int, c float) STORED AS PARQUET +LOCATION 'test_files/scratch/schema_evolution/parquet_table/'; + +# File1 has only columns a and b +statement ok +COPY ( + SELECT column1 as a, column2 as b + FROM ( VALUES ('foo', 1), ('foo', 2), ('foo', 3) ) + ) TO 'test_files/scratch/schema_evolution/parquet_table/1.parquet' +(FORMAT PARQUET, SINGLE_FILE_OUTPUT true); + + +# File2 has only b +statement ok +COPY ( + SELECT column1 as b + FROM ( VALUES (10) ) + ) TO 'test_files/scratch/schema_evolution/parquet_table/2.parquet' +(FORMAT PARQUET, SINGLE_FILE_OUTPUT true); + +# File3 has a column from 'z' which does not appear in the table +# but also values from a which do appear in the table +statement ok +COPY ( + SELECT column1 as z, column2 as a + FROM ( VALUES ('bar', 'foo'), ('blarg', 'foo') ) + ) TO 'test_files/scratch/schema_evolution/parquet_table/3.parquet' +(FORMAT PARQUET, SINGLE_FILE_OUTPUT true); + +# File4 has data for b and a (reversed) and d +statement ok +COPY ( + SELECT column1 as b, column2 as a, column3 as c + FROM ( VALUES (100, 'foo', 10.5), (200, 'foo', 12.6), (300, 'bzz', 13.7) ) + ) TO 'test_files/scratch/schema_evolution/parquet_table/4.parquet' +(FORMAT PARQUET, SINGLE_FILE_OUTPUT true); + +# The logical distribution of `a`, `b` and `c` in the files is like this: +# +## File1: +# foo 1 NULL +# foo 2 NULL +# foo 3 NULL +# +## File2: +# NULL 10 NULL +# +## File3: +# foo NULL NULL +# foo NULL NULL +# +## File4: +# foo 100 10.5 +# foo 200 12.6 +# bzz 300 13.7 + +# Show all the data +query TIR rowsort +select * from parquet_table; +---- +NULL 10 NULL +bzz 300 13.7 +foo 1 NULL +foo 100 10.5 +foo 2 NULL +foo 200 12.6 +foo 3 NULL +foo NULL NULL +foo NULL NULL + +# Should see all 7 rows that have 'a=foo' +query TIR rowsort +select * from parquet_table where a = 'foo'; +---- +foo 1 NULL +foo 100 10.5 +foo 2 NULL +foo 200 12.6 +foo 3 NULL +foo NULL NULL +foo NULL NULL + +query TIR rowsort +select * from parquet_table where a != 'foo'; +---- +bzz 300 13.7 + +# this should produce at least one row +query TIR rowsort +select * from parquet_table where a is NULL; +---- +NULL 10 NULL + +query TIR rowsort +select * from parquet_table where b > 5; +---- +NULL 10 NULL +bzz 300 13.7 +foo 100 10.5 +foo 200 12.6 + + +query TIR rowsort +select * from parquet_table where b < 150; +---- +NULL 10 NULL +foo 1 NULL +foo 100 10.5 +foo 2 NULL +foo 3 NULL + +query TIR rowsort +select * from parquet_table where c > 11.0; +---- +bzz 300 13.7 +foo 200 12.6 diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt index b09910735809..ea570b99d4dd 100644 --- a/datafusion/sqllogictest/test_files/select.slt +++ b/datafusion/sqllogictest/test_files/select.slt @@ -485,8 +485,7 @@ Projection: select_between_data.c1 >= Int64(2) AND select_between_data.c1 <= Int --TableScan: select_between_data projection=[c1] physical_plan ProjectionExec: expr=[c1@0 >= 2 AND c1@0 <= 3 as select_between_data.c1 BETWEEN Int64(2) AND Int64(3)] ---RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -----MemoryExec: partitions=1, partition_sizes=[1] +--MemoryExec: partitions=1, partition_sizes=[1] # TODO: query_get_indexed_field @@ -849,6 +848,41 @@ statement error DataFusion error: Error during planning: EXCLUDE or EXCEPT conta SELECT * EXCLUDE(a, a) FROM table1 +# if EXCEPT all the columns, query should still succeed but return empty +statement ok +SELECT * EXCEPT(a, b, c, d) +FROM table1 + +# EXCLUDE order shouldn't matter +query II +SELECT * EXCLUDE(b, a) +FROM table1 +ORDER BY c +LIMIT 5 +---- +100 1000 +200 2000 + +# EXCLUDE with out of order but duplicate columns should error +statement error DataFusion error: Error during planning: EXCLUDE or EXCEPT contains duplicate column names +SELECT * EXCLUDE(d, b, c, a, a, b, c, d) +FROM table1 + +# avoiding adding an alias if the column name is the same +query TT +EXPLAIN select a as a FROM table1 order by a +---- +logical_plan +Sort: table1.a ASC NULLS LAST +--TableScan: table1 projection=[a] +physical_plan +SortExec: expr=[a@0 ASC NULLS LAST] +--MemoryExec: partitions=1, partition_sizes=[1] + +# ambiguous column references in on join +query error DataFusion error: Schema error: Ambiguous reference to unqualified field a +EXPLAIN select a as a FROM table1 t1 CROSS JOIN table1 t2 order by a + # run below query in multi partitions statement ok set datafusion.execution.target_partitions = 2; @@ -994,8 +1028,79 @@ SortPreservingMergeExec: [c@3 ASC NULLS LAST] --------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true +# When ordering lost during projection, we shouldn't keep the SortExec. +# in the final physical plan. +query TT +EXPLAIN SELECT c2, COUNT(*) +FROM (SELECT c2 +FROM aggregate_test_100 +ORDER BY c1, c2) +GROUP BY c2; +---- +logical_plan +Aggregate: groupBy=[[aggregate_test_100.c2]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] +--Projection: aggregate_test_100.c2 +----Sort: aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST +------Projection: aggregate_test_100.c2, aggregate_test_100.c1 +--------TableScan: aggregate_test_100 projection=[c1, c2] +physical_plan +AggregateExec: mode=FinalPartitioned, gby=[c2@0 as c2], aggr=[COUNT(*)] +--CoalesceBatchesExec: target_batch_size=8192 +----RepartitionExec: partitioning=Hash([c2@0], 2), input_partitions=2 +------AggregateExec: mode=Partial, gby=[c2@0 as c2], aggr=[COUNT(*)] +--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2], has_header=true + statement ok drop table annotated_data_finite2; statement ok drop table t; + +statement ok +create table t(x bigint, y bigint) as values (1,2), (1,3); + +query II +select z+1, y from (select x+1 as z, y from t) where y > 1; +---- +3 2 +3 3 + +query TT +EXPLAIN SELECT x/2, x/2+1 FROM t; +---- +logical_plan +Projection: t.x / Int64(2)Int64(2)t.x AS t.x / Int64(2), t.x / Int64(2)Int64(2)t.x AS t.x / Int64(2) + Int64(1) +--Projection: t.x / Int64(2) AS t.x / Int64(2)Int64(2)t.x +----TableScan: t projection=[x] +physical_plan +ProjectionExec: expr=[t.x / Int64(2)Int64(2)t.x@0 as t.x / Int64(2), t.x / Int64(2)Int64(2)t.x@0 + 1 as t.x / Int64(2) + Int64(1)] +--ProjectionExec: expr=[x@0 / 2 as t.x / Int64(2)Int64(2)t.x] +----MemoryExec: partitions=1, partition_sizes=[1] + +query II +SELECT x/2, x/2+1 FROM t; +---- +0 1 +0 1 + +query TT +EXPLAIN SELECT abs(x), abs(x) + abs(y) FROM t; +---- +logical_plan +Projection: abs(t.x)t.x AS abs(t.x), abs(t.x)t.x AS abs(t.x) + abs(t.y) +--Projection: abs(t.x) AS abs(t.x)t.x, t.y +----TableScan: t projection=[x, y] +physical_plan +ProjectionExec: expr=[abs(t.x)t.x@0 as abs(t.x), abs(t.x)t.x@0 + abs(y@1) as abs(t.x) + abs(t.y)] +--ProjectionExec: expr=[abs(x@0) as abs(t.x)t.x, y@1 as y] +----MemoryExec: partitions=1, partition_sizes=[1] + +query II +SELECT abs(x), abs(x) + abs(y) FROM t; +---- +1 3 +1 4 + +statement ok +DROP TABLE t; diff --git a/datafusion/sqllogictest/test_files/struct.slt b/datafusion/sqllogictest/test_files/struct.slt index fc14798a3bfe..936dedcc896e 100644 --- a/datafusion/sqllogictest/test_files/struct.slt +++ b/datafusion/sqllogictest/test_files/struct.slt @@ -58,5 +58,16 @@ select struct(a, b, c) from values; {c0: 2, c1: 2.2, c2: b} {c0: 3, c1: 3.3, c2: c} +# explain struct scalar function with columns #1 +query TT +explain select struct(a, b, c) from values; +---- +logical_plan +Projection: struct(values.a, values.b, values.c) +--TableScan: values projection=[a, b, c] +physical_plan +ProjectionExec: expr=[struct(a@0, b@1, c@2) as struct(values.a,values.b,values.c)] +--MemoryExec: partitions=1, partition_sizes=[1] + statement ok drop table values; diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index 2eccb60aad3e..3e0fcb7aa96e 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -15,6 +15,10 @@ # specific language governing permissions and limitations # under the License. +# make sure to a batch size smaller than row number of the table. +statement ok +set datafusion.execution.batch_size = 2; + ############# ## Subquery Tests ############# @@ -45,6 +49,13 @@ CREATE TABLE t2(t2_id INT, t2_name TEXT, t2_int INT) AS VALUES (44, 'x', 3), (55, 'w', 3); +statement ok +CREATE TABLE t3(t3_id INT PRIMARY KEY, t3_name TEXT, t3_int INT) AS VALUES +(11, 'e', 3), +(22, 'f', 1), +(44, 'g', 3), +(55, 'h', 3); + statement ok CREATE EXTERNAL TABLE IF NOT EXISTS customer ( c_custkey BIGINT, @@ -176,19 +187,18 @@ Projection: t1.t1_id, __scalar_sq_1.SUM(t2.t2_int) AS t2_sum --------Aggregate: groupBy=[[t2.t2_id]], aggr=[[SUM(CAST(t2.t2_int AS Int64))]] ----------TableScan: t2 projection=[t2_id, t2_int] physical_plan -ProjectionExec: expr=[t1_id@0 as t1_id, SUM(t2.t2_int)@1 as t2_sum] ---ProjectionExec: expr=[t1_id@2 as t1_id, SUM(t2.t2_int)@0 as SUM(t2.t2_int), t2_id@1 as t2_id] -----CoalesceBatchesExec: target_batch_size=8192 -------HashJoinExec: mode=Partitioned, join_type=Right, on=[(t2_id@1, t1_id@0)] ---------ProjectionExec: expr=[SUM(t2.t2_int)@1 as SUM(t2.t2_int), t2_id@0 as t2_id] -----------AggregateExec: mode=FinalPartitioned, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int)] -------------CoalesceBatchesExec: target_batch_size=8192 ---------------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 -----------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int)] -------------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] ---------CoalesceBatchesExec: target_batch_size=8192 -----------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 -------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +ProjectionExec: expr=[t1_id@2 as t1_id, SUM(t2.t2_int)@0 as t2_sum] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=Partitioned, join_type=Right, on=[(t2_id@1, t1_id@0)] +------ProjectionExec: expr=[SUM(t2.t2_int)@1 as SUM(t2.t2_int), t2_id@0 as t2_id] +--------AggregateExec: mode=FinalPartitioned, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int)] +----------CoalesceBatchesExec: target_batch_size=2 +------------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 +--------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int)] +----------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 +----------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] query II rowsort SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id) as t2_sum from t1 @@ -211,19 +221,18 @@ Projection: t1.t1_id, __scalar_sq_1.SUM(t2.t2_int * Float64(1)) + Int64(1) AS t2 --------Aggregate: groupBy=[[t2.t2_id]], aggr=[[SUM(CAST(t2.t2_int AS Float64)) AS SUM(t2.t2_int * Float64(1))]] ----------TableScan: t2 projection=[t2_id, t2_int] physical_plan -ProjectionExec: expr=[t1_id@0 as t1_id, SUM(t2.t2_int * Float64(1)) + Int64(1)@1 as t2_sum] ---ProjectionExec: expr=[t1_id@2 as t1_id, SUM(t2.t2_int * Float64(1)) + Int64(1)@0 as SUM(t2.t2_int * Float64(1)) + Int64(1), t2_id@1 as t2_id] -----CoalesceBatchesExec: target_batch_size=8192 -------HashJoinExec: mode=Partitioned, join_type=Right, on=[(t2_id@1, t1_id@0)] ---------ProjectionExec: expr=[SUM(t2.t2_int * Float64(1))@1 + 1 as SUM(t2.t2_int * Float64(1)) + Int64(1), t2_id@0 as t2_id] -----------AggregateExec: mode=FinalPartitioned, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int * Float64(1))] -------------CoalesceBatchesExec: target_batch_size=8192 ---------------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 -----------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int * Float64(1))] -------------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] ---------CoalesceBatchesExec: target_batch_size=8192 -----------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 -------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +ProjectionExec: expr=[t1_id@2 as t1_id, SUM(t2.t2_int * Float64(1)) + Int64(1)@0 as t2_sum] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=Partitioned, join_type=Right, on=[(t2_id@1, t1_id@0)] +------ProjectionExec: expr=[SUM(t2.t2_int * Float64(1))@1 + 1 as SUM(t2.t2_int * Float64(1)) + Int64(1), t2_id@0 as t2_id] +--------AggregateExec: mode=FinalPartitioned, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int * Float64(1))] +----------CoalesceBatchesExec: target_batch_size=2 +------------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 +--------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int * Float64(1))] +----------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 +----------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] query IR rowsort SELECT t1_id, (SELECT sum(t2_int * 1.0) + 1 FROM t2 WHERE t2.t2_id = t1.t1_id) as t2_sum from t1 @@ -247,16 +256,16 @@ Projection: t1.t1_id, __scalar_sq_1.SUM(t2.t2_int) AS t2_sum ----------TableScan: t2 projection=[t2_id, t2_int] physical_plan ProjectionExec: expr=[t1_id@0 as t1_id, SUM(t2.t2_int)@1 as t2_sum] ---CoalesceBatchesExec: target_batch_size=8192 +--CoalesceBatchesExec: target_batch_size=2 ----HashJoinExec: mode=Partitioned, join_type=Left, on=[(t1_id@0, t2_id@1)] -------CoalesceBatchesExec: target_batch_size=8192 +------CoalesceBatchesExec: target_batch_size=2 --------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 ----------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] -------CoalesceBatchesExec: target_batch_size=8192 +------CoalesceBatchesExec: target_batch_size=2 --------RepartitionExec: partitioning=Hash([t2_id@1], 4), input_partitions=4 ----------ProjectionExec: expr=[SUM(t2.t2_int)@2 as SUM(t2.t2_int), t2_id@0 as t2_id] ------------AggregateExec: mode=FinalPartitioned, gby=[t2_id@0 as t2_id, Utf8("a")@1 as Utf8("a")], aggr=[SUM(t2.t2_int)] ---------------CoalesceBatchesExec: target_batch_size=8192 +--------------CoalesceBatchesExec: target_batch_size=2 ----------------RepartitionExec: partitioning=Hash([t2_id@0, Utf8("a")@1], 4), input_partitions=4 ------------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id, a as Utf8("a")], aggr=[SUM(t2.t2_int)] --------------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] @@ -283,21 +292,20 @@ Projection: t1.t1_id, __scalar_sq_1.SUM(t2.t2_int) AS t2_sum ----------Aggregate: groupBy=[[t2.t2_id]], aggr=[[SUM(CAST(t2.t2_int AS Int64))]] ------------TableScan: t2 projection=[t2_id, t2_int] physical_plan -ProjectionExec: expr=[t1_id@0 as t1_id, SUM(t2.t2_int)@1 as t2_sum] ---ProjectionExec: expr=[t1_id@2 as t1_id, SUM(t2.t2_int)@0 as SUM(t2.t2_int), t2_id@1 as t2_id] -----CoalesceBatchesExec: target_batch_size=8192 -------HashJoinExec: mode=Partitioned, join_type=Right, on=[(t2_id@1, t1_id@0)] ---------ProjectionExec: expr=[SUM(t2.t2_int)@1 as SUM(t2.t2_int), t2_id@0 as t2_id] -----------CoalesceBatchesExec: target_batch_size=8192 -------------FilterExec: SUM(t2.t2_int)@1 < 3 ---------------AggregateExec: mode=FinalPartitioned, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int)] -----------------CoalesceBatchesExec: target_batch_size=8192 -------------------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 ---------------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int)] -----------------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] ---------CoalesceBatchesExec: target_batch_size=8192 -----------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 -------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +ProjectionExec: expr=[t1_id@2 as t1_id, SUM(t2.t2_int)@0 as t2_sum] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=Partitioned, join_type=Right, on=[(t2_id@1, t1_id@0)] +------ProjectionExec: expr=[SUM(t2.t2_int)@1 as SUM(t2.t2_int), t2_id@0 as t2_id] +--------CoalesceBatchesExec: target_batch_size=2 +----------FilterExec: SUM(t2.t2_int)@1 < 3 +------------AggregateExec: mode=FinalPartitioned, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int)] +--------------CoalesceBatchesExec: target_batch_size=2 +----------------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 +------------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int)] +--------------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 +----------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] query II rowsort SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id having sum(t2_int) < 3) as t2_sum from t1 @@ -418,6 +426,17 @@ SELECT t1_id, t1_name, t1_int FROM t1 order by t1_int in (SELECT t2_int FROM t2 statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: Correlated scalar subquery must be aggregated to return at most one row SELECT t1_id, (SELECT t2_int FROM t2 WHERE t2.t2_int = t1.t1_int) as t2_int from t1 +#non_aggregated_correlated_scalar_subquery_unique +query II rowsort +SELECT t1_id, (SELECT t3_int FROM t3 WHERE t3.t3_id = t1.t1_id) as t3_int from t1 +---- +11 3 +22 1 +33 NULL +44 3 + + +#non_aggregated_correlated_scalar_subquery statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: Correlated scalar subquery must be aggregated to return at most one row SELECT t1_id, (SELECT t2_int FROM t2 WHERE t2.t2_int = t1_int group by t2_int) as t2_int from t1 @@ -436,7 +455,7 @@ Projection: t1.t1_id, () AS t2_int ------Projection: t2.t2_int --------Filter: t2.t2_int = outer_ref(t1.t1_int) ----------TableScan: t2 ---TableScan: t1 projection=[t1_id] +--TableScan: t1 projection=[t1_id, t1_int] query TT explain SELECT t1_id from t1 where t1_int = (SELECT t2_int FROM t2 WHERE t2.t2_int = t1.t1_int limit 1) @@ -483,27 +502,29 @@ query TT explain SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT sum(t1.t1_int + t2.t2_id) FROM t2 WHERE t1.t1_name = t2.t2_name) ---- logical_plan -Filter: EXISTS () ---Subquery: -----Projection: SUM(outer_ref(t1.t1_int) + t2.t2_id) -------Aggregate: groupBy=[[]], aggr=[[SUM(CAST(outer_ref(t1.t1_int) + t2.t2_id AS Int64))]] ---------Filter: outer_ref(t1.t1_name) = t2.t2_name -----------TableScan: t2 ---TableScan: t1 projection=[t1_id, t1_name] +Projection: t1.t1_id, t1.t1_name +--Filter: EXISTS () +----Subquery: +------Projection: SUM(outer_ref(t1.t1_int) + t2.t2_id) +--------Aggregate: groupBy=[[]], aggr=[[SUM(CAST(outer_ref(t1.t1_int) + t2.t2_id AS Int64))]] +----------Filter: outer_ref(t1.t1_name) = t2.t2_name +------------TableScan: t2 +----TableScan: t1 projection=[t1_id, t1_name, t1_int] #support_agg_correlated_columns2 query TT explain SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT count(*) FROM t2 WHERE t1.t1_name = t2.t2_name having sum(t1_int + t2_id) >0) ---- logical_plan -Filter: EXISTS () ---Subquery: -----Projection: COUNT(*) -------Filter: SUM(outer_ref(t1.t1_int) + t2.t2_id) > Int64(0) ---------Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*), SUM(CAST(outer_ref(t1.t1_int) + t2.t2_id AS Int64))]] -----------Filter: outer_ref(t1.t1_name) = t2.t2_name -------------TableScan: t2 ---TableScan: t1 projection=[t1_id, t1_name] +Projection: t1.t1_id, t1.t1_name +--Filter: EXISTS () +----Subquery: +------Projection: COUNT(*) +--------Filter: SUM(outer_ref(t1.t1_int) + t2.t2_id) > Int64(0) +----------Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*), SUM(CAST(outer_ref(t1.t1_int) + t2.t2_id AS Int64))]] +------------Filter: outer_ref(t1.t1_name) = t2.t2_name +--------------TableScan: t2 +----TableScan: t1 projection=[t1_id, t1_name, t1_int] #support_join_correlated_columns query TT @@ -691,7 +712,7 @@ logical_plan Projection: __scalar_sq_1.COUNT(*) AS b --SubqueryAlias: __scalar_sq_1 ----Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] -------TableScan: t1 projection=[t1_id] +------TableScan: t1 projection=[] #simple_uncorrelated_scalar_subquery2 query TT @@ -702,10 +723,10 @@ Projection: __scalar_sq_1.COUNT(*) AS b, __scalar_sq_2.COUNT(Int64(1)) AS COUNT( --Left Join: ----SubqueryAlias: __scalar_sq_1 ------Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] ---------TableScan: t1 projection=[t1_id] +--------TableScan: t1 projection=[] ----SubqueryAlias: __scalar_sq_2 ------Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]] ---------TableScan: t2 projection=[t2_id] +--------TableScan: t2 projection=[] query II select (select count(*) from t1) as b, (select count(1) from t2) @@ -987,3 +1008,55 @@ SELECT * FROM ON (severity.cron_job_name = jobs.cron_job_name); ---- catan-prod1-daily success catan-prod1-daily high + +##correlated_scalar_subquery_sum_agg_bug +#query TT +#explain +#select t1.t1_int from t1 where +# (select sum(t2_int) is null from t2 where t1.t1_id = t2.t2_id) +#---- +#logical_plan +#Projection: t1.t1_int +#--Inner Join: t1.t1_id = __scalar_sq_1.t2_id +#----TableScan: t1 projection=[t1_id, t1_int] +#----SubqueryAlias: __scalar_sq_1 +#------Projection: t2.t2_id +#--------Filter: SUM(t2.t2_int) IS NULL +#----------Aggregate: groupBy=[[t2.t2_id]], aggr=[[SUM(t2.t2_int)]] +#------------TableScan: t2 projection=[t2_id, t2_int] + +#query I rowsort +#select t1.t1_int from t1 where +# (select sum(t2_int) is null from t2 where t1.t1_id = t2.t2_id) +#---- +#2 +#3 +#4 + +statement ok +create table t(a bigint); + +# Result of query below shouldn't depend on +# number of optimization passes +# See issue: https://github.com/apache/arrow-datafusion/issues/8296 +statement ok +set datafusion.optimizer.max_passes = 1; + +query TT +explain select a/2, a/2 + 1 from t +---- +logical_plan +Projection: t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2), t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2) + Int64(1) +--Projection: t.a / Int64(2) AS t.a / Int64(2)Int64(2)t.a +----TableScan: t projection=[a] + +statement ok +set datafusion.optimizer.max_passes = 3; + +query TT +explain select a/2, a/2 + 1 from t +---- +logical_plan +Projection: t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2), t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2) + Int64(1) +--Projection: t.a / Int64(2) AS t.a / Int64(2)Int64(2)t.a +----TableScan: t projection=[a] diff --git a/datafusion/sqllogictest/test_files/timestamps.slt b/datafusion/sqllogictest/test_files/timestamps.slt index 88a024e0f9da..8b0f50cedf05 100644 --- a/datafusion/sqllogictest/test_files/timestamps.slt +++ b/datafusion/sqllogictest/test_files/timestamps.slt @@ -46,6 +46,30 @@ statement ok create table ts_data_secs as select arrow_cast(ts / 1000000000, 'Timestamp(Second, None)') as ts, value from ts_data; +########## +## Current date Tests +########## + +query B +select cast(now() as date) = current_date(); +---- +true + +query B +select now() = current_date(); +---- +false + +query B +select current_date() = today(); +---- +true + +query B +select cast(now() as date) = today(); +---- +true + ########## ## Timestamp Handling Tests @@ -217,7 +241,7 @@ SELECT to_timestamp_micros(ts) FROM ts_data_secs LIMIT 3 # to nanos query P -SELECT to_timestamp(ts) FROM ts_data_secs LIMIT 3 +SELECT to_timestamp_nanos(ts) FROM ts_data_secs LIMIT 3 ---- 2020-09-08T13:42:29 2020-09-08T12:42:29 @@ -244,7 +268,7 @@ SELECT to_timestamp_seconds(ts) FROM ts_data_micros LIMIT 3 2020-09-08T11:42:29 -# Original column is micros, convert to nanos and check timestamp +# Original column is micros, convert to seconds and check timestamp query P SELECT to_timestamp(ts) FROM ts_data_micros LIMIT 3 @@ -266,7 +290,7 @@ SELECT from_unixtime(ts / 1000000000) FROM ts_data LIMIT 3; # to_timestamp query I -SELECT COUNT(*) FROM ts_data_nanos where ts > to_timestamp('2020-09-08T12:00:00+00:00') +SELECT COUNT(*) FROM ts_data_nanos where ts > timestamp '2020-09-08T12:00:00+00:00' ---- 2 @@ -291,6 +315,35 @@ SELECT COUNT(*) FROM ts_data_secs where ts > to_timestamp_seconds('2020-09-08T12 ---- 2 + +# to_timestamp float inputs + +query PPP +SELECT to_timestamp(1.1) as c1, cast(1.1 as timestamp) as c2, 1.1::timestamp as c3; +---- +1970-01-01T00:00:01.100 1970-01-01T00:00:01.100 1970-01-01T00:00:01.100 + +query PPP +SELECT to_timestamp(-1.1) as c1, cast(-1.1 as timestamp) as c2, (-1.1)::timestamp as c3; +---- +1969-12-31T23:59:58.900 1969-12-31T23:59:58.900 1969-12-31T23:59:58.900 + +query PPP +SELECT to_timestamp(0.0) as c1, cast(0.0 as timestamp) as c2, 0.0::timestamp as c3; +---- +1970-01-01T00:00:00 1970-01-01T00:00:00 1970-01-01T00:00:00 + +query PPP +SELECT to_timestamp(1.23456789) as c1, cast(1.23456789 as timestamp) as c2, 1.23456789::timestamp as c3; +---- +1970-01-01T00:00:01.234567890 1970-01-01T00:00:01.234567890 1970-01-01T00:00:01.234567890 + +query PPP +SELECT to_timestamp(123456789.123456789) as c1, cast(123456789.123456789 as timestamp) as c2, 123456789.123456789::timestamp as c3; +---- +1973-11-29T21:33:09.123456784 1973-11-29T21:33:09.123456784 1973-11-29T21:33:09.123456784 + + # from_unixtime # 1599566400 is '2020-09-08T12:00:00+00:00' @@ -375,7 +428,7 @@ set datafusion.optimizer.skip_failed_rules = true query P select to_timestamp(a) from (select to_timestamp(1) as a) A; ---- -1970-01-01T00:00:00.000000001 +1970-01-01T00:00:01 # cast_to_timestamp_seconds_twice query P @@ -383,7 +436,6 @@ select to_timestamp_seconds(a) from (select to_timestamp_seconds(1) as a)A ---- 1970-01-01T00:00:01 - # cast_to_timestamp_millis_twice query P select to_timestamp_millis(a) from (select to_timestamp_millis(1) as a)A; @@ -396,11 +448,17 @@ select to_timestamp_micros(a) from (select to_timestamp_micros(1) as a)A; ---- 1970-01-01T00:00:00.000001 +# cast_to_timestamp_nanos_twice +query P +select to_timestamp_nanos(a) from (select to_timestamp_nanos(1) as a)A; +---- +1970-01-01T00:00:00.000000001 + # to_timestamp_i32 query P select to_timestamp(cast (1 as int)); ---- -1970-01-01T00:00:00.000000001 +1970-01-01T00:00:01 # to_timestamp_micros_i32 query P @@ -408,6 +466,12 @@ select to_timestamp_micros(cast (1 as int)); ---- 1970-01-01T00:00:00.000001 +# to_timestamp_nanos_i32 +query P +select to_timestamp_nanos(cast (1 as int)); +---- +1970-01-01T00:00:00.000000001 + # to_timestamp_millis_i32 query P select to_timestamp_millis(cast (1 as int)); @@ -1389,6 +1453,12 @@ SELECT date_bin('1 day', TIMESTAMPTZ '2022-01-01 20:10:00Z', TIMESTAMPTZ '2020-0 ---- 2022-01-02T00:00:00+07:00 +# coerce TIMESTAMP to TIMESTAMPTZ +query P +SELECT date_bin('1 day', TIMESTAMPTZ '2022-01-01 20:10:00Z', TIMESTAMP '2020-01-01') +---- +2022-01-01T07:00:00+07:00 + # postgresql: 1 query R SELECT date_part('hour', TIMESTAMPTZ '2000-01-01T01:01:01') as part @@ -1448,6 +1518,30 @@ SELECT date_bin('1 day', TIMESTAMPTZ '2022-01-01 01:10:00+07', TIMESTAMPTZ '2020 ---- 2021-12-31T00:00:00Z +# postgresql: 2021-12-31 00:00:00+00 +query P +SELECT date_bin('1 day', TIMESTAMPTZ '2022-01-01 01:10:00+07', '2020-01-01') +---- +2021-12-31T00:00:00Z + +# postgresql: 2021-12-31 00:00:00+00 +query P +SELECT date_bin('1 day', TIMESTAMPTZ '2022-01-01 01:10:00+07', '2020-01-01T00:00:00Z') +---- +2021-12-31T00:00:00Z + +# postgresql: 2021-12-31 18:00:00+00 +query P +SELECT date_bin('2 hour', TIMESTAMPTZ '2022-01-01 01:10:00+07', '2020-01-01') +---- +2021-12-31T18:00:00Z + +# postgresql: 2021-12-31 18:00:00+00 +query P +SELECT date_bin('2 hour', TIMESTAMPTZ '2022-01-01 01:10:00+07', '2020-01-01T00:00:00Z') +---- +2021-12-31T18:00:00Z + # postgresql: 1 query R SELECT date_part('hour', TIMESTAMPTZ '2000-01-01T01:01:01') as part @@ -1636,14 +1730,11 @@ SELECT TIMESTAMPTZ '2022-01-01 01:10:00 AEST' query P rowsort SELECT TIMESTAMPTZ '2022-01-01 01:10:00 Australia/Sydney' as ts_geo UNION ALL -SELECT TIMESTAMPTZ '2022-01-01 01:10:00 Antarctica/Vostok' as ts_geo - UNION ALL SELECT TIMESTAMPTZ '2022-01-01 01:10:00 Africa/Johannesburg' as ts_geo UNION ALL SELECT TIMESTAMPTZ '2022-01-01 01:10:00 America/Los_Angeles' as ts_geo ---- 2021-12-31T14:10:00Z -2021-12-31T19:10:00Z 2021-12-31T23:10:00Z 2022-01-01T09:10:00Z @@ -1678,3 +1769,128 @@ SELECT TIMESTAMPTZ '2023-03-11 02:00:00 America/Los_Angeles' as ts_geo # postgresql: accepts statement error SELECT TIMESTAMPTZ '2023-03-12 02:00:00 America/Los_Angeles' as ts_geo + + + +########## +## Timezone column tests +########## + +# create a table with a non-UTC time zone. +statement ok +SET TIME ZONE = '+05:00' + +statement ok +CREATE TABLE foo (time TIMESTAMPTZ) AS VALUES + ('2020-01-01T00:00:00+05:00'), + ('2020-01-01T01:00:00+05:00'), + ('2020-01-01T02:00:00+05:00'), + ('2020-01-01T03:00:00+05:00') + +statement ok +SET TIME ZONE = '+00' + +# verify column type +query T +SELECT arrow_typeof(time) FROM foo LIMIT 1 +---- +Timestamp(Nanosecond, Some("+05:00")) + +# check date_trunc +query P +SELECT date_trunc('day', time) FROM foo +---- +2020-01-01T00:00:00+05:00 +2020-01-01T00:00:00+05:00 +2020-01-01T00:00:00+05:00 +2020-01-01T00:00:00+05:00 + +# verify date_trunc column type +query T +SELECT arrow_typeof(date_trunc('day', time)) FROM foo LIMIT 1 +---- +Timestamp(Nanosecond, Some("+05:00")) + +# check date_bin +query P +SELECT date_bin(INTERVAL '1 day', time, '1970-01-01T00:00:00+05:00') FROM foo +---- +2020-01-01T00:00:00+05:00 +2020-01-01T00:00:00+05:00 +2020-01-01T00:00:00+05:00 +2020-01-01T00:00:00+05:00 + +# verify date_trunc column type +query T +SELECT arrow_typeof(date_bin(INTERVAL '1 day', time, '1970-01-01T00:00:00+05:00')) FROM foo LIMIT 1 +---- +Timestamp(Nanosecond, Some("+05:00")) + + +# timestamp comparison with and without timezone +query B +SELECT TIMESTAMPTZ '2022-01-01 20:10:00Z' = TIMESTAMP '2020-01-01' +---- +false + +query B +SELECT TIMESTAMPTZ '2020-01-01 00:00:00Z' = TIMESTAMP '2020-01-01' +---- +true + +# verify timestamp cast with integer input +query PPPPPP +SELECT to_timestamp(null), to_timestamp(0), to_timestamp(1926632005), to_timestamp(1), to_timestamp(-1), to_timestamp(0-1) +---- +NULL 1970-01-01T00:00:00 2031-01-19T23:33:25 1970-01-01T00:00:01 1969-12-31T23:59:59 1969-12-31T23:59:59 + +# verify timestamp syntax stlyes are consistent +query BBBBBBBBBBBBB +SELECT to_timestamp(null) is null as c1, + null::timestamp is null as c2, + cast(null as timestamp) is null as c3, + to_timestamp(0) = 0::timestamp as c4, + to_timestamp(1926632005) = 1926632005::timestamp as c5, + to_timestamp(1) = 1::timestamp as c6, + to_timestamp(-1) = -1::timestamp as c7, + to_timestamp(0-1) = (0-1)::timestamp as c8, + to_timestamp(0) = cast(0 as timestamp) as c9, + to_timestamp(1926632005) = cast(1926632005 as timestamp) as c10, + to_timestamp(1) = cast(1 as timestamp) as c11, + to_timestamp(-1) = cast(-1 as timestamp) as c12, + to_timestamp(0-1) = cast(0-1 as timestamp) as c13 +---- +true true true true true true true true true true true true true + +# verify timestamp output types +query TTT +SELECT arrow_typeof(to_timestamp(1)), arrow_typeof(to_timestamp(null)), arrow_typeof(to_timestamp('2023-01-10 12:34:56.000')) +---- +Timestamp(Nanosecond, None) Timestamp(Nanosecond, None) Timestamp(Nanosecond, None) + +# verify timestamp output types using timestamp literal syntax +query BBBBBB +SELECT arrow_typeof(to_timestamp(1)) = arrow_typeof(1::timestamp) as c1, + arrow_typeof(to_timestamp(null)) = arrow_typeof(null::timestamp) as c2, + arrow_typeof(to_timestamp('2023-01-10 12:34:56.000')) = arrow_typeof('2023-01-10 12:34:56.000'::timestamp) as c3, + arrow_typeof(to_timestamp(1)) = arrow_typeof(cast(1 as timestamp)) as c4, + arrow_typeof(to_timestamp(null)) = arrow_typeof(cast(null as timestamp)) as c5, + arrow_typeof(to_timestamp('2023-01-10 12:34:56.000')) = arrow_typeof(cast('2023-01-10 12:34:56.000' as timestamp)) as c6 +---- +true true true true true true + +# known issues. currently overflows (expects default precision to be microsecond instead of nanoseconds. Work pending) +#verify extreme values +#query PPPPPPPP +#SELECT to_timestamp(-62125747200), to_timestamp(1926632005177), -62125747200::timestamp, 1926632005177::timestamp, cast(-62125747200 as timestamp), cast(1926632005177 as timestamp) +#---- +#0001-04-25T00:00:00 +63022-07-16T12:59:37 0001-04-25T00:00:00 +63022-07-16T12:59:37 0001-04-25T00:00:00 +63022-07-16T12:59:37 + +########## +## Test binary temporal coercion for Date and Timestamp +########## + +query B +select arrow_cast(now(), 'Date64') < arrow_cast('2022-02-02 02:02:02', 'Timestamp(Nanosecond, None)'); +---- +false diff --git a/datafusion/sqllogictest/test_files/topk.slt b/datafusion/sqllogictest/test_files/topk.slt new file mode 100644 index 000000000000..5eba20fdc655 --- /dev/null +++ b/datafusion/sqllogictest/test_files/topk.slt @@ -0,0 +1,232 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Tests for development + +statement ok +create table topk(x int) as values (10), (2), (3), (0), (5), (4), (3), (2), (1), (3), (8); + +query I +select * from topk order by x; +---- +0 +1 +2 +2 +3 +3 +3 +4 +5 +8 +10 + +query I +select * from topk order by x limit 3; +---- +0 +1 +2 + +query I +select * from topk order by x desc limit 3; +---- +10 +8 +5 + + + + +statement ok +CREATE EXTERNAL TABLE aggregate_test_100 ( + c1 VARCHAR NOT NULL, + c2 TINYINT NOT NULL, + c3 SMALLINT NOT NULL, + c4 SMALLINT, + c5 INT, + c6 BIGINT NOT NULL, + c7 SMALLINT NOT NULL, + c8 INT NOT NULL, + c9 BIGINT UNSIGNED NOT NULL, + c10 VARCHAR NOT NULL, + c11 FLOAT NOT NULL, + c12 DOUBLE NOT NULL, + c13 VARCHAR NOT NULL +) +STORED AS CSV +WITH HEADER ROW +LOCATION '../../testing/data/csv/aggregate_test_100.csv' + +query TT +explain select * from aggregate_test_100 ORDER BY c13 desc limit 5; +---- +logical_plan +Limit: skip=0, fetch=5 +--Sort: aggregate_test_100.c13 DESC NULLS FIRST, fetch=5 +----TableScan: aggregate_test_100 projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13] +physical_plan +GlobalLimitExec: skip=0, fetch=5 +--SortExec: TopK(fetch=5), expr=[c13@12 DESC] +----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], has_header=true + + + + +query T +select c13 from aggregate_test_100 ORDER BY c13; +---- +0VVIHzxWtNOFLtnhjHEKjXaJOSLJfm +0keZ5G8BffGwgF2RwQD59TFzMStxCB +0og6hSkhbX8AC1ktFS4kounvTzy8Vo +1aOcrEGd0cOqZe2I5XBOm0nDcwtBZO +2T3wSlHdEmASmO0xcXHnndkKEt6bz8 +3BEOHQsMEFZ58VcNTOJYShTBpAPzbt +4HX6feIvmNXBN7XGqgO4YVBkhu8GDI +4JznSdBajNWhu4hRQwjV1FjTTxY68i +52mKlRE3aHCBZtjECq6sY9OqVf8Dze +56MZa5O1hVtX4c5sbnCfxuX5kDChqI +6FPJlLAcaQ5uokyOWZ9HGdLZObFvOZ +6WfVFBVGJSQb7FhA7E0lBwdvjfZnSW +6oIXZuIPIqEoPBvFmbt2Nxy3tryGUE +6x93sxYioWuq5c9Kkk8oTAAORM7cH0 +802bgTGl6Bk5TlkPYYTxp5JkKyaYUA +8LIh0b6jmDGm87BmIyjdxNIpX4ugjD +90gAtmGEeIqUTbo1ZrxCvWtsseukXC +9UbObCsVkmYpJGcGrgfK90qOnwb2Lj +AFGCj7OWlEB5QfniEFgonMq90Tq5uH +ALuRhobVWbnQTTWZdSOk0iVe8oYFhW +Amn2K87Db5Es3dFQO9cw9cvpAM6h35 +AyYVExXK6AR2qUTxNZ7qRHQOVGMLcz +BJqx5WokrmrrezZA0dUbleMYkG5U2O +BPtQMxnuSPpxMExYV9YkDa6cAN7GP3 +BsM5ZAYifRh5Lw3Y8X1r53I0cTJnfE +C2GT5KVyOPZpgKVl110TyZO0NcJ434 +DuJNG8tufSqW0ZstHqWj3aGvFLMg4A +EcCuckwsF3gV1Ecgmh5v4KM8g1ozif +ErJFw6hzZ5fmI5r8bhE4JzlscnhKZU +F7NSTjWvQJyBburN7CXRUlbgp2dIrA +Fi4rJeTQq4eXj8Lxg3Hja5hBVTVV5u +H5j5ZHy1FGesOAHjkQEDYCucbpKWRu +HKSMQ9nTnwXCJIte1JrM1dtYnDtJ8g +IWl0G3ZlMNf7WT8yjIB49cx7MmYOmr +IZTkHMLvIKuiLjhDjYMmIHxh166we4 +Ig1QcuKsjHXkproePdERo2w0mYzIqd +JHNgc2UCaiXOdmkxwDDyGhRlO0mnBQ +JN0VclewmjwYlSl8386MlWv5rEhWCz +JafwVLSVk5AVoXFuzclesQ000EE2k1 +KJFcmTVjdkCMv94wYCtfHMFhzyRsmH +Ktb7GQ0N1DrxwkCkEUsTaIXk0xYinn +Ld2ej8NEv5zNcqU60FwpHeZKBhfpiV +LiEBxds3X0Uw0lxiYjDqrkAaAwoiIW +MXhhH1Var3OzzJCtI9VNyYvA0q8UyJ +MeSTAXq8gVxVjbEjgkvU9YLte0X9uE +NEhyk8uIx4kEULJGa8qIyFjjBcP2G6 +O66j6PaYuZhEUtqV6fuU7TyjM2WxC5 +OF7fQ37GzaZ5ikA2oMyvleKtgnLjXh +OPwBqCEK5PWTjWaiOyL45u2NLTaDWv +Oq6J4Rx6nde0YlhOIJkFsX2MsSvAQ0 +Ow5PGpfTm4dXCfTDsXAOTatXRoAydR +QEHVvcP8gxI6EMJIrvcnIhgzPNjIvv +QJYm7YRA3YetcBHI5wkMZeLXVmfuNy +QYlaIAnJA6r8rlAb6f59wcxvcPcWFf +RilTlL1tKkPOUFuzmLydHAVZwv1OGl +Sfx0vxv1skzZWT1PqVdoRDdO6Sb6xH +TTQUwpMNSXZqVBKAFvXu7OlWvKXJKX +TtDKUZxzVxsq758G6AWPSYuZgVgbcl +VDhtJkYjAYPykCgOU9x3v7v3t4SO1a +VY0zXmXeksCT8BzvpzpPLbmU9Kp9Y4 +Vp3gmWunM5A7wOC9YW2JroFqTWjvTi +WHmjWk2AY4c6m7DA4GitUx6nmb1yYS +XemNcT1xp61xcM1Qz3wZ1VECCnq06O +Z2sWcQr0qyCJRMHDpRy3aQr7PkHtkK +aDxBtor7Icd9C5hnTvvw5NrIre740e +akiiY5N0I44CMwEnBL6RTBk7BRkxEj +b3b9esRhTzFEawbs6XhpKnD9ojutHB +bgK1r6v3BCTh0aejJUhkA1Hn6idXGp +cBGc0kSm32ylBDnxogG727C0uhZEYZ +cq4WSAIFwx3wwTUS5bp1wCe71R6U5I +dVdvo6nUD5FgCgsbOZLds28RyGTpnx +e2Gh6Ov8XkXoFdJWhl0EjwEHlMDYyG +f9ALCzwDAKmdu7Rk2msJaB1wxe5IBX +fuyvs0w7WsKSlXqJ1e6HFSoLmx03AG +gTpyQnEODMcpsPnJMZC66gh33i3m0b +gpo8K5qtYePve6jyPt6xgJx4YOVjms +gxfHWUF8XgY2KdFxigxvNEXe2V2XMl +i6RQVXKUh7MzuGMDaNclUYnFUAireU +ioEncce3mPOXD2hWhpZpCPWGATG6GU +jQimhdepw3GKmioWUlVSWeBVRKFkY3 +l7uwDoTepWwnAP0ufqtHJS3CRi7RfP +lqhzgLsXZ8JhtpeeUWWNbMz8PHI705 +m6jD0LBIQWaMfenwRCTANI9eOdyyto +mhjME0zBHbrK6NMkytMTQzOssOa1gF +mzbkwXKrPeZnxg2Kn1LRF5hYSsmksS +nYVJnVicpGRqKZibHyBAmtmzBXAFfT +oHJMNvWuunsIMIWFnYG31RCfkOo2V7 +oLZ21P2JEDooxV1pU31cIxQHEeeoLu +okOkcWflkNXIy4R8LzmySyY1EC3sYd +pLk3i59bZwd5KBZrI1FiweYTd5hteG +pTeu0WMjBRTaNRT15rLCuEh3tBJVc5 +qnPOOmslCJaT45buUisMRnM0rc77EK +t6fQUjJejPcjc04wHvHTPe55S65B4V +ukOiFGGFnQJDHFgZxHMpvhD3zybF0M +ukyD7b0Efj7tNlFSRmzZ0IqkEzg2a8 +waIGbOGl1PM6gnzZ4uuZt4E2yDWRHs +wwXqSGKLyBQyPkonlzBNYUJTCo4LRS +xipQ93429ksjNcXPX5326VSg1xJZcW +y7C453hRWd4E7ImjNDWlpexB8nUqjh +ydkwycaISlYSlEq3TlkS2m15I2pcp8 + + +query TIIIIIIIITRRT +select * from aggregate_test_100 ORDER BY c13 desc limit 5; +---- +a 4 -38 20744 762932956 308913475857409919 7 45465 1787652631 878137512938218976 0.7459874 0.021825780392 ydkwycaISlYSlEq3TlkS2m15I2pcp8 +d 1 -98 13630 -1991133944 1184110014998006843 220 2986 225513085 9634106610243643486 0.89651865 0.164088254508 y7C453hRWd4E7ImjNDWlpexB8nUqjh +e 2 52 -12056 -1090239422 9011500141803970147 238 4168 2013662838 12565360638488684051 0.6694766 0.391444365692 xipQ93429ksjNcXPX5326VSg1xJZcW +d 1 -72 25590 1188089983 3090286296481837049 241 832 3542840110 5885937420286765261 0.41980565 0.215354023438 wwXqSGKLyBQyPkonlzBNYUJTCo4LRS +a 1 -5 12636 794623392 2909750622865366631 15 24022 2669374863 4776679784701509574 0.29877836 0.253725340799 waIGbOGl1PM6gnzZ4uuZt4E2yDWRHs + + + +## -- make tiny batches to trigger batch compaction +statement ok +set datafusion.execution.batch_size = 2 + +query TIIIIIIIITRRT +select * from aggregate_test_100 ORDER BY c13 desc limit 5; +---- +a 4 -38 20744 762932956 308913475857409919 7 45465 1787652631 878137512938218976 0.7459874 0.021825780392 ydkwycaISlYSlEq3TlkS2m15I2pcp8 +d 1 -98 13630 -1991133944 1184110014998006843 220 2986 225513085 9634106610243643486 0.89651865 0.164088254508 y7C453hRWd4E7ImjNDWlpexB8nUqjh +e 2 52 -12056 -1090239422 9011500141803970147 238 4168 2013662838 12565360638488684051 0.6694766 0.391444365692 xipQ93429ksjNcXPX5326VSg1xJZcW +d 1 -72 25590 1188089983 3090286296481837049 241 832 3542840110 5885937420286765261 0.41980565 0.215354023438 wwXqSGKLyBQyPkonlzBNYUJTCo4LRS +a 1 -5 12636 794623392 2909750622865366631 15 24022 2669374863 4776679784701509574 0.29877836 0.253725340799 waIGbOGl1PM6gnzZ4uuZt4E2yDWRHs + + +## make an example for dictionary encoding + +statement ok +create table dict as select c1, c2, c3, c13, arrow_cast(c13, 'Dictionary(Int32, Utf8)') as c13_dict from aggregate_test_100; + +query TIIT? +select * from dict order by c13 desc limit 5; +---- +a 4 -38 ydkwycaISlYSlEq3TlkS2m15I2pcp8 ydkwycaISlYSlEq3TlkS2m15I2pcp8 +d 1 -98 y7C453hRWd4E7ImjNDWlpexB8nUqjh y7C453hRWd4E7ImjNDWlpexB8nUqjh +e 2 52 xipQ93429ksjNcXPX5326VSg1xJZcW xipQ93429ksjNcXPX5326VSg1xJZcW +d 1 -72 wwXqSGKLyBQyPkonlzBNYUJTCo4LRS wwXqSGKLyBQyPkonlzBNYUJTCo4LRS +a 1 -5 waIGbOGl1PM6gnzZ4uuZt4E2yDWRHs waIGbOGl1PM6gnzZ4uuZt4E2yDWRHs diff --git a/datafusion/sqllogictest/test_files/tpch/q10.slt.part b/datafusion/sqllogictest/test_files/tpch/q10.slt.part index 708bcb3c9b6f..eb0b66f024de 100644 --- a/datafusion/sqllogictest/test_files/tpch/q10.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q10.slt.part @@ -72,7 +72,7 @@ Limit: skip=0, fetch=10 physical_plan GlobalLimitExec: skip=0, fetch=10 --SortPreservingMergeExec: [revenue@2 DESC], fetch=10 -----SortExec: fetch=10, expr=[revenue@2 DESC] +----SortExec: TopK(fetch=10), expr=[revenue@2 DESC] ------ProjectionExec: expr=[c_custkey@0 as c_custkey, c_name@1 as c_name, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@7 as revenue, c_acctbal@2 as c_acctbal, n_name@4 as n_name, c_address@5 as c_address, c_phone@3 as c_phone, c_comment@6 as c_comment] --------AggregateExec: mode=FinalPartitioned, gby=[c_custkey@0 as c_custkey, c_name@1 as c_name, c_acctbal@2 as c_acctbal, c_phone@3 as c_phone, n_name@4 as n_name, c_address@5 as c_address, c_comment@6 as c_comment], aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] ----------CoalesceBatchesExec: target_batch_size=8192 diff --git a/datafusion/sqllogictest/test_files/tpch/q11.slt.part b/datafusion/sqllogictest/test_files/tpch/q11.slt.part index 0a045d4f77ca..4efa29e2c0ac 100644 --- a/datafusion/sqllogictest/test_files/tpch/q11.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q11.slt.part @@ -76,7 +76,7 @@ Limit: skip=0, fetch=10 physical_plan GlobalLimitExec: skip=0, fetch=10 --SortPreservingMergeExec: [value@1 DESC], fetch=10 -----SortExec: fetch=10, expr=[value@1 DESC] +----SortExec: TopK(fetch=10), expr=[value@1 DESC] ------ProjectionExec: expr=[ps_partkey@0 as ps_partkey, SUM(partsupp.ps_supplycost * partsupp.ps_availqty)@1 as value] --------NestedLoopJoinExec: join_type=Inner, filter=CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty)@0 AS Decimal128(38, 15)) > SUM(partsupp.ps_supplycost * partsupp.ps_availqty) * Float64(0.0001)@1 ----------AggregateExec: mode=FinalPartitioned, gby=[ps_partkey@0 as ps_partkey], aggr=[SUM(partsupp.ps_supplycost * partsupp.ps_availqty)] diff --git a/datafusion/sqllogictest/test_files/tpch/q13.slt.part b/datafusion/sqllogictest/test_files/tpch/q13.slt.part index bb33c2ad3419..5cf6ace8b27b 100644 --- a/datafusion/sqllogictest/test_files/tpch/q13.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q13.slt.part @@ -56,7 +56,7 @@ Limit: skip=0, fetch=10 physical_plan GlobalLimitExec: skip=0, fetch=10 --SortPreservingMergeExec: [custdist@1 DESC,c_count@0 DESC], fetch=10 -----SortExec: fetch=10, expr=[custdist@1 DESC,c_count@0 DESC] +----SortExec: TopK(fetch=10), expr=[custdist@1 DESC,c_count@0 DESC] ------ProjectionExec: expr=[c_count@0 as c_count, COUNT(*)@1 as custdist] --------AggregateExec: mode=FinalPartitioned, gby=[c_count@0 as c_count], aggr=[COUNT(*)] ----------CoalesceBatchesExec: target_batch_size=8192 diff --git a/datafusion/sqllogictest/test_files/tpch/q15.slt.part b/datafusion/sqllogictest/test_files/tpch/q15.slt.part index 4515b8ae1fb4..a872e96acf04 100644 --- a/datafusion/sqllogictest/test_files/tpch/q15.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q15.slt.part @@ -95,20 +95,19 @@ SortPreservingMergeExec: [s_suppkey@0 ASC NULLS LAST] ----------------------------------FilterExec: l_shipdate@3 >= 9496 AND l_shipdate@3 < 9587 ------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], has_header=false ----------CoalesceBatchesExec: target_batch_size=8192 -------------RepartitionExec: partitioning=Hash([MAX(revenue0.total_revenue)@0], 4), input_partitions=4 ---------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -----------------AggregateExec: mode=Final, gby=[], aggr=[MAX(revenue0.total_revenue)] -------------------CoalescePartitionsExec ---------------------AggregateExec: mode=Partial, gby=[], aggr=[MAX(revenue0.total_revenue)] -----------------------ProjectionExec: expr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@1 as total_revenue] -------------------------AggregateExec: mode=FinalPartitioned, gby=[l_suppkey@0 as l_suppkey], aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] ---------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------RepartitionExec: partitioning=Hash([l_suppkey@0], 4), input_partitions=4 -------------------------------AggregateExec: mode=Partial, gby=[l_suppkey@0 as l_suppkey], aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] ---------------------------------ProjectionExec: expr=[l_suppkey@0 as l_suppkey, l_extendedprice@1 as l_extendedprice, l_discount@2 as l_discount] -----------------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------------FilterExec: l_shipdate@3 >= 9496 AND l_shipdate@3 < 9587 ---------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], has_header=false +------------RepartitionExec: partitioning=Hash([MAX(revenue0.total_revenue)@0], 4), input_partitions=1 +--------------AggregateExec: mode=Final, gby=[], aggr=[MAX(revenue0.total_revenue)] +----------------CoalescePartitionsExec +------------------AggregateExec: mode=Partial, gby=[], aggr=[MAX(revenue0.total_revenue)] +--------------------ProjectionExec: expr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@1 as total_revenue] +----------------------AggregateExec: mode=FinalPartitioned, gby=[l_suppkey@0 as l_suppkey], aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] +------------------------CoalesceBatchesExec: target_batch_size=8192 +--------------------------RepartitionExec: partitioning=Hash([l_suppkey@0], 4), input_partitions=4 +----------------------------AggregateExec: mode=Partial, gby=[l_suppkey@0 as l_suppkey], aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] +------------------------------ProjectionExec: expr=[l_suppkey@0 as l_suppkey, l_extendedprice@1 as l_extendedprice, l_discount@2 as l_discount] +--------------------------------CoalesceBatchesExec: target_batch_size=8192 +----------------------------------FilterExec: l_shipdate@3 >= 9496 AND l_shipdate@3 < 9587 +------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], has_header=false query ITTTR with revenue0 (supplier_no, total_revenue) as ( diff --git a/datafusion/sqllogictest/test_files/tpch/q16.slt.part b/datafusion/sqllogictest/test_files/tpch/q16.slt.part index 5247fbc90d7c..b93872929fe5 100644 --- a/datafusion/sqllogictest/test_files/tpch/q16.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q16.slt.part @@ -52,9 +52,9 @@ limit 10; logical_plan Limit: skip=0, fetch=10 --Sort: supplier_cnt DESC NULLS FIRST, part.p_brand ASC NULLS LAST, part.p_type ASC NULLS LAST, part.p_size ASC NULLS LAST, fetch=10 -----Projection: group_alias_0 AS part.p_brand, group_alias_1 AS part.p_type, group_alias_2 AS part.p_size, COUNT(alias1) AS supplier_cnt -------Aggregate: groupBy=[[group_alias_0, group_alias_1, group_alias_2]], aggr=[[COUNT(alias1)]] ---------Aggregate: groupBy=[[part.p_brand AS group_alias_0, part.p_type AS group_alias_1, part.p_size AS group_alias_2, partsupp.ps_suppkey AS alias1]], aggr=[[]] +----Projection: part.p_brand, part.p_type, part.p_size, COUNT(alias1) AS supplier_cnt +------Aggregate: groupBy=[[part.p_brand, part.p_type, part.p_size]], aggr=[[COUNT(alias1)]] +--------Aggregate: groupBy=[[part.p_brand, part.p_type, part.p_size, partsupp.ps_suppkey AS alias1]], aggr=[[]] ----------LeftAnti Join: partsupp.ps_suppkey = __correlated_sq_1.s_suppkey ------------Projection: partsupp.ps_suppkey, part.p_brand, part.p_type, part.p_size --------------Inner Join: partsupp.ps_partkey = part.p_partkey @@ -68,16 +68,16 @@ Limit: skip=0, fetch=10 physical_plan GlobalLimitExec: skip=0, fetch=10 --SortPreservingMergeExec: [supplier_cnt@3 DESC,p_brand@0 ASC NULLS LAST,p_type@1 ASC NULLS LAST,p_size@2 ASC NULLS LAST], fetch=10 -----SortExec: fetch=10, expr=[supplier_cnt@3 DESC,p_brand@0 ASC NULLS LAST,p_type@1 ASC NULLS LAST,p_size@2 ASC NULLS LAST] -------ProjectionExec: expr=[group_alias_0@0 as part.p_brand, group_alias_1@1 as part.p_type, group_alias_2@2 as part.p_size, COUNT(alias1)@3 as supplier_cnt] ---------AggregateExec: mode=FinalPartitioned, gby=[group_alias_0@0 as group_alias_0, group_alias_1@1 as group_alias_1, group_alias_2@2 as group_alias_2], aggr=[COUNT(alias1)] +----SortExec: TopK(fetch=10), expr=[supplier_cnt@3 DESC,p_brand@0 ASC NULLS LAST,p_type@1 ASC NULLS LAST,p_size@2 ASC NULLS LAST] +------ProjectionExec: expr=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size, COUNT(alias1)@3 as supplier_cnt] +--------AggregateExec: mode=FinalPartitioned, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size], aggr=[COUNT(alias1)] ----------CoalesceBatchesExec: target_batch_size=8192 -------------RepartitionExec: partitioning=Hash([group_alias_0@0, group_alias_1@1, group_alias_2@2], 4), input_partitions=4 ---------------AggregateExec: mode=Partial, gby=[group_alias_0@0 as group_alias_0, group_alias_1@1 as group_alias_1, group_alias_2@2 as group_alias_2], aggr=[COUNT(alias1)] -----------------AggregateExec: mode=FinalPartitioned, gby=[group_alias_0@0 as group_alias_0, group_alias_1@1 as group_alias_1, group_alias_2@2 as group_alias_2, alias1@3 as alias1], aggr=[] +------------RepartitionExec: partitioning=Hash([p_brand@0, p_type@1, p_size@2], 4), input_partitions=4 +--------------AggregateExec: mode=Partial, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size], aggr=[COUNT(alias1)] +----------------AggregateExec: mode=FinalPartitioned, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size, alias1@3 as alias1], aggr=[] ------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------RepartitionExec: partitioning=Hash([group_alias_0@0, group_alias_1@1, group_alias_2@2, alias1@3], 4), input_partitions=4 -----------------------AggregateExec: mode=Partial, gby=[p_brand@1 as group_alias_0, p_type@2 as group_alias_1, p_size@3 as group_alias_2, ps_suppkey@0 as alias1], aggr=[] +--------------------RepartitionExec: partitioning=Hash([p_brand@0, p_type@1, p_size@2, alias1@3], 4), input_partitions=4 +----------------------AggregateExec: mode=Partial, gby=[p_brand@1 as p_brand, p_type@2 as p_type, p_size@3 as p_size, ps_suppkey@0 as alias1], aggr=[] ------------------------CoalesceBatchesExec: target_batch_size=8192 --------------------------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(ps_suppkey@0, s_suppkey@0)] ----------------------------CoalesceBatchesExec: target_batch_size=8192 diff --git a/datafusion/sqllogictest/test_files/tpch/q17.slt.part b/datafusion/sqllogictest/test_files/tpch/q17.slt.part index 50661b9b10a8..4d4aa4b1395f 100644 --- a/datafusion/sqllogictest/test_files/tpch/q17.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q17.slt.part @@ -58,21 +58,19 @@ ProjectionExec: expr=[CAST(SUM(lineitem.l_extendedprice)@0 AS Float64) / 7 as av --------ProjectionExec: expr=[l_extendedprice@1 as l_extendedprice] ----------CoalesceBatchesExec: target_batch_size=8192 ------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(p_partkey@2, l_partkey@1)], filter=CAST(l_quantity@0 AS Decimal128(30, 15)) < Float64(0.2) * AVG(lineitem.l_quantity)@1 ---------------CoalesceBatchesExec: target_batch_size=8192 -----------------RepartitionExec: partitioning=Hash([p_partkey@2], 4), input_partitions=4 -------------------ProjectionExec: expr=[l_quantity@1 as l_quantity, l_extendedprice@2 as l_extendedprice, p_partkey@3 as p_partkey] +--------------ProjectionExec: expr=[l_quantity@1 as l_quantity, l_extendedprice@2 as l_extendedprice, p_partkey@3 as p_partkey] +----------------CoalesceBatchesExec: target_batch_size=8192 +------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)] --------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)] -------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------RepartitionExec: partitioning=Hash([l_partkey@0], 4), input_partitions=4 -----------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_partkey, l_quantity, l_extendedprice], has_header=false -------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 -----------------------------ProjectionExec: expr=[p_partkey@0 as p_partkey] -------------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------------FilterExec: p_brand@1 = Brand#23 AND p_container@2 = MED BOX -----------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_container], has_header=false +----------------------RepartitionExec: partitioning=Hash([l_partkey@0], 4), input_partitions=4 +------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_partkey, l_quantity, l_extendedprice], has_header=false +--------------------CoalesceBatchesExec: target_batch_size=8192 +----------------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 +------------------------ProjectionExec: expr=[p_partkey@0 as p_partkey] +--------------------------CoalesceBatchesExec: target_batch_size=8192 +----------------------------FilterExec: p_brand@1 = Brand#23 AND p_container@2 = MED BOX +------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_container], has_header=false --------------ProjectionExec: expr=[CAST(0.2 * CAST(AVG(lineitem.l_quantity)@1 AS Float64) AS Decimal128(30, 15)) as Float64(0.2) * AVG(lineitem.l_quantity), l_partkey@0 as l_partkey] ----------------AggregateExec: mode=FinalPartitioned, gby=[l_partkey@0 as l_partkey], aggr=[AVG(lineitem.l_quantity)] ------------------CoalesceBatchesExec: target_batch_size=8192 diff --git a/datafusion/sqllogictest/test_files/tpch/q2.slt.part b/datafusion/sqllogictest/test_files/tpch/q2.slt.part index f98634033bf8..ed950db190bb 100644 --- a/datafusion/sqllogictest/test_files/tpch/q2.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q2.slt.part @@ -102,7 +102,7 @@ Limit: skip=0, fetch=10 physical_plan GlobalLimitExec: skip=0, fetch=10 --SortPreservingMergeExec: [s_acctbal@0 DESC,n_name@2 ASC NULLS LAST,s_name@1 ASC NULLS LAST,p_partkey@3 ASC NULLS LAST], fetch=10 -----SortExec: fetch=10, expr=[s_acctbal@0 DESC,n_name@2 ASC NULLS LAST,s_name@1 ASC NULLS LAST,p_partkey@3 ASC NULLS LAST] +----SortExec: TopK(fetch=10), expr=[s_acctbal@0 DESC,n_name@2 ASC NULLS LAST,s_name@1 ASC NULLS LAST,p_partkey@3 ASC NULLS LAST] ------ProjectionExec: expr=[s_acctbal@5 as s_acctbal, s_name@2 as s_name, n_name@8 as n_name, p_partkey@0 as p_partkey, p_mfgr@1 as p_mfgr, s_address@3 as s_address, s_phone@4 as s_phone, s_comment@6 as s_comment] --------CoalesceBatchesExec: target_batch_size=8192 ----------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(p_partkey@0, ps_partkey@1), (ps_supplycost@7, MIN(partsupp.ps_supplycost)@0)] @@ -238,7 +238,7 @@ order by p_partkey limit 10; ---- -9828.21 Supplier#000000647 UNITED KINGDOM 13120 Manufacturer#5 x5U7MBZmwfG9 33-258-202-4782 s the slyly even ideas poach fluffily +9828.21 Supplier#000000647 UNITED KINGDOM 13120 Manufacturer#5 x5U7MBZmwfG9 33-258-202-4782 s the slyly even ideas poach fluffily 9508.37 Supplier#000000070 FRANCE 3563 Manufacturer#1 INWNH2w,OOWgNDq0BRCcBwOMQc6PdFDc4 16-821-608-1166 ests sleep quickly express ideas. ironic ideas haggle about the final T 9508.37 Supplier#000000070 FRANCE 17268 Manufacturer#4 INWNH2w,OOWgNDq0BRCcBwOMQc6PdFDc4 16-821-608-1166 ests sleep quickly express ideas. ironic ideas haggle about the final T 9453.01 Supplier#000000802 ROMANIA 10021 Manufacturer#5 ,6HYXb4uaHITmtMBj4Ak57Pd 29-342-882-6463 gular frets. permanently special multipliers believe blithely alongs diff --git a/datafusion/sqllogictest/test_files/tpch/q3.slt.part b/datafusion/sqllogictest/test_files/tpch/q3.slt.part index 634f06d0bf50..85f2d9986c27 100644 --- a/datafusion/sqllogictest/test_files/tpch/q3.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q3.slt.part @@ -61,7 +61,7 @@ Limit: skip=0, fetch=10 physical_plan GlobalLimitExec: skip=0, fetch=10 --SortPreservingMergeExec: [revenue@1 DESC,o_orderdate@2 ASC NULLS LAST], fetch=10 -----SortExec: fetch=10, expr=[revenue@1 DESC,o_orderdate@2 ASC NULLS LAST] +----SortExec: TopK(fetch=10), expr=[revenue@1 DESC,o_orderdate@2 ASC NULLS LAST] ------ProjectionExec: expr=[l_orderkey@0 as l_orderkey, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@3 as revenue, o_orderdate@1 as o_orderdate, o_shippriority@2 as o_shippriority] --------AggregateExec: mode=FinalPartitioned, gby=[l_orderkey@0 as l_orderkey, o_orderdate@1 as o_orderdate, o_shippriority@2 as o_shippriority], aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] ----------CoalesceBatchesExec: target_batch_size=8192 diff --git a/datafusion/sqllogictest/test_files/tpch/q9.slt.part b/datafusion/sqllogictest/test_files/tpch/q9.slt.part index fc5f82008dad..5db97f79bdb1 100644 --- a/datafusion/sqllogictest/test_files/tpch/q9.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q9.slt.part @@ -78,7 +78,7 @@ Limit: skip=0, fetch=10 physical_plan GlobalLimitExec: skip=0, fetch=10 --SortPreservingMergeExec: [nation@0 ASC NULLS LAST,o_year@1 DESC], fetch=10 -----SortExec: fetch=10, expr=[nation@0 ASC NULLS LAST,o_year@1 DESC] +----SortExec: TopK(fetch=10), expr=[nation@0 ASC NULLS LAST,o_year@1 DESC] ------ProjectionExec: expr=[nation@0 as nation, o_year@1 as o_year, SUM(profit.amount)@2 as sum_profit] --------AggregateExec: mode=FinalPartitioned, gby=[nation@0 as nation, o_year@1 as o_year], aggr=[SUM(profit.amount)] ----------CoalesceBatchesExec: target_batch_size=8192 diff --git a/datafusion/sqllogictest/test_files/union.slt b/datafusion/sqllogictest/test_files/union.slt index 05eaa10dabde..b4e338875e24 100644 --- a/datafusion/sqllogictest/test_files/union.slt +++ b/datafusion/sqllogictest/test_files/union.slt @@ -82,6 +82,11 @@ SELECT 2 as x 1 2 +query I +select count(*) from (select id from t1 union all select id from t2) +---- +6 + # csv_union_all statement ok CREATE EXTERNAL TABLE aggregate_test_100 ( @@ -174,6 +179,76 @@ UNION ALL Alice John +# nested_union +query T rowsort +SELECT name FROM t1 UNION (SELECT name from t2 UNION SELECT name || '_new' from t2) +---- +Alex +Alex_new +Alice +Bob +Bob_new +John +John_new + +# should be un-nested, with a single (logical) aggregate +query TT +EXPLAIN SELECT name FROM t1 UNION (SELECT name from t2 UNION SELECT name || '_new' from t2) +---- +logical_plan +Aggregate: groupBy=[[t1.name]], aggr=[[]] +--Union +----TableScan: t1 projection=[name] +----TableScan: t2 projection=[name] +----Projection: t2.name || Utf8("_new") AS name +------TableScan: t2 projection=[name] +physical_plan +AggregateExec: mode=FinalPartitioned, gby=[name@0 as name], aggr=[] +--CoalesceBatchesExec: target_batch_size=8192 +----RepartitionExec: partitioning=Hash([name@0], 4), input_partitions=4 +------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=3 +--------AggregateExec: mode=Partial, gby=[name@0 as name], aggr=[] +----------UnionExec +------------MemoryExec: partitions=1, partition_sizes=[1] +------------MemoryExec: partitions=1, partition_sizes=[1] +------------ProjectionExec: expr=[name@0 || _new as name] +--------------MemoryExec: partitions=1, partition_sizes=[1] + +# nested_union_all +query T rowsort +SELECT name FROM t1 UNION ALL (SELECT name from t2 UNION ALL SELECT name || '_new' from t2) +---- +Alex +Alex +Alex_new +Alice +Bob +Bob +Bob_new +John +John_new + +# Plan is unnested +query TT +EXPLAIN SELECT name FROM t1 UNION ALL (SELECT name from t2 UNION ALL SELECT name || '_new' from t2) +---- +logical_plan +Union +--TableScan: t1 projection=[name] +--TableScan: t2 projection=[name] +--Projection: t2.name || Utf8("_new") AS name +----TableScan: t2 projection=[name] +physical_plan +UnionExec +--MemoryExec: partitions=1, partition_sizes=[1] +--MemoryExec: partitions=1, partition_sizes=[1] +--ProjectionExec: expr=[name@0 || _new as name] +----MemoryExec: partitions=1, partition_sizes=[1] + +# Make sure to choose a small batch size to introduce parallelism to the plan. +statement ok +set datafusion.execution.batch_size = 2; + # union_with_type_coercion query TT explain @@ -202,33 +277,36 @@ Union ------TableScan: t1 projection=[id, name] physical_plan UnionExec ---ProjectionExec: expr=[id@0 as id, name@1 as name] -----CoalesceBatchesExec: target_batch_size=8192 -------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(id@0, CAST(t2.id AS Int32)@2), (name@1, name@1)] ---------AggregateExec: mode=FinalPartitioned, gby=[id@0 as id, name@1 as name], aggr=[] -----------CoalesceBatchesExec: target_batch_size=8192 -------------RepartitionExec: partitioning=Hash([id@0, name@1], 4), input_partitions=4 ---------------AggregateExec: mode=Partial, gby=[id@0 as id, name@1 as name], aggr=[] -----------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] ---------CoalesceBatchesExec: target_batch_size=8192 +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(id@0, CAST(t2.id AS Int32)@2), (name@1, name@1)] +------AggregateExec: mode=FinalPartitioned, gby=[id@0 as id, name@1 as name], aggr=[] +--------CoalesceBatchesExec: target_batch_size=2 +----------RepartitionExec: partitioning=Hash([id@0, name@1], 4), input_partitions=4 +------------AggregateExec: mode=Partial, gby=[id@0 as id, name@1 as name], aggr=[] +--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------------MemoryExec: partitions=1, partition_sizes=[1] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([CAST(t2.id AS Int32)@2, name@1], 4), input_partitions=4 +----------ProjectionExec: expr=[id@0 as id, name@1 as name, CAST(id@0 AS Int32) as CAST(t2.id AS Int32)] +------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] +--ProjectionExec: expr=[CAST(id@0 AS Int32) as id, name@1 as name] +----CoalesceBatchesExec: target_batch_size=2 +------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(CAST(t2.id AS Int32)@2, id@0), (name@1, name@1)] +--------CoalesceBatchesExec: target_batch_size=2 ----------RepartitionExec: partitioning=Hash([CAST(t2.id AS Int32)@2, name@1], 4), input_partitions=4 ------------ProjectionExec: expr=[id@0 as id, name@1 as name, CAST(id@0 AS Int32) as CAST(t2.id AS Int32)] ---------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] ---ProjectionExec: expr=[CAST(id@0 AS Int32) as id, name@1 as name] -----ProjectionExec: expr=[id@0 as id, name@1 as name] -------CoalesceBatchesExec: target_batch_size=8192 ---------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(CAST(t2.id AS Int32)@2, id@0), (name@1, name@1)] -----------CoalesceBatchesExec: target_batch_size=8192 -------------RepartitionExec: partitioning=Hash([CAST(t2.id AS Int32)@2, name@1], 4), input_partitions=4 ---------------ProjectionExec: expr=[id@0 as id, name@1 as name, CAST(id@0 AS Int32) as CAST(t2.id AS Int32)] -----------------AggregateExec: mode=FinalPartitioned, gby=[id@0 as id, name@1 as name], aggr=[] -------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------RepartitionExec: partitioning=Hash([id@0, name@1], 4), input_partitions=4 -----------------------AggregateExec: mode=Partial, gby=[id@0 as id, name@1 as name], aggr=[] -------------------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] -----------CoalesceBatchesExec: target_batch_size=8192 -------------RepartitionExec: partitioning=Hash([id@0, name@1], 4), input_partitions=4 ---------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +--------------AggregateExec: mode=FinalPartitioned, gby=[id@0 as id, name@1 as name], aggr=[] +----------------CoalesceBatchesExec: target_batch_size=2 +------------------RepartitionExec: partitioning=Hash([id@0, name@1], 4), input_partitions=4 +--------------------AggregateExec: mode=Partial, gby=[id@0 as id, name@1 as name], aggr=[] +----------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------------------MemoryExec: partitions=1, partition_sizes=[1] +--------CoalesceBatchesExec: target_batch_size=2 +----------RepartitionExec: partitioning=Hash([id@0, name@1], 4), input_partitions=4 +------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] + query IT rowsort ( @@ -273,26 +351,30 @@ Union ----TableScan: t1 projection=[name] physical_plan InterleaveExec ---CoalesceBatchesExec: target_batch_size=8192 +--CoalesceBatchesExec: target_batch_size=2 ----HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(name@0, name@0)] ------AggregateExec: mode=FinalPartitioned, gby=[name@0 as name], aggr=[] ---------CoalesceBatchesExec: target_batch_size=8192 +--------CoalesceBatchesExec: target_batch_size=2 ----------RepartitionExec: partitioning=Hash([name@0], 4), input_partitions=4 ------------AggregateExec: mode=Partial, gby=[name@0 as name], aggr=[] ---------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] -------CoalesceBatchesExec: target_batch_size=8192 +--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------------MemoryExec: partitions=1, partition_sizes=[1] +------CoalesceBatchesExec: target_batch_size=2 --------RepartitionExec: partitioning=Hash([name@0], 4), input_partitions=4 -----------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] ---CoalesceBatchesExec: target_batch_size=8192 +----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------MemoryExec: partitions=1, partition_sizes=[1] +--CoalesceBatchesExec: target_batch_size=2 ----HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(name@0, name@0)] ------AggregateExec: mode=FinalPartitioned, gby=[name@0 as name], aggr=[] ---------CoalesceBatchesExec: target_batch_size=8192 +--------CoalesceBatchesExec: target_batch_size=2 ----------RepartitionExec: partitioning=Hash([name@0], 4), input_partitions=4 ------------AggregateExec: mode=Partial, gby=[name@0 as name], aggr=[] ---------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] -------CoalesceBatchesExec: target_batch_size=8192 +--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------------MemoryExec: partitions=1, partition_sizes=[1] +------CoalesceBatchesExec: target_batch_size=2 --------RepartitionExec: partitioning=Hash([name@0], 4), input_partitions=4 -----------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------MemoryExec: partitions=1, partition_sizes=[1] # union_upcast_types query TT @@ -350,15 +432,17 @@ ProjectionExec: expr=[COUNT(*)@1 as COUNT(*)] --AggregateExec: mode=SinglePartitioned, gby=[name@0 as name], aggr=[COUNT(*)] ----InterleaveExec ------AggregateExec: mode=FinalPartitioned, gby=[name@0 as name], aggr=[] ---------CoalesceBatchesExec: target_batch_size=8192 +--------CoalesceBatchesExec: target_batch_size=2 ----------RepartitionExec: partitioning=Hash([name@0], 4), input_partitions=4 ------------AggregateExec: mode=Partial, gby=[name@0 as name], aggr=[] ---------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------------MemoryExec: partitions=1, partition_sizes=[1] ------AggregateExec: mode=FinalPartitioned, gby=[name@0 as name], aggr=[] ---------CoalesceBatchesExec: target_batch_size=8192 +--------CoalesceBatchesExec: target_batch_size=2 ----------RepartitionExec: partitioning=Hash([name@0], 4), input_partitions=4 ------------AggregateExec: mode=Partial, gby=[name@0 as name], aggr=[] ---------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------------MemoryExec: partitions=1, partition_sizes=[1] ######## @@ -464,15 +548,14 @@ physical_plan UnionExec --ProjectionExec: expr=[Int64(1)@0 as a] ----AggregateExec: mode=FinalPartitioned, gby=[Int64(1)@0 as Int64(1)], aggr=[] -------CoalesceBatchesExec: target_batch_size=8192 ---------RepartitionExec: partitioning=Hash([Int64(1)@0], 4), input_partitions=4 +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([Int64(1)@0], 4), input_partitions=1 ----------AggregateExec: mode=Partial, gby=[1 as Int64(1)], aggr=[] -------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------EmptyExec: produce_one_row=true +------------PlaceholderRowExec --ProjectionExec: expr=[2 as a] -----EmptyExec: produce_one_row=true +----PlaceholderRowExec --ProjectionExec: expr=[3 as a] -----EmptyExec: produce_one_row=true +----PlaceholderRowExec # test UNION ALL aliases correctly with aliased subquery query TT @@ -496,16 +579,11 @@ physical_plan UnionExec --ProjectionExec: expr=[COUNT(*)@1 as count, n@0 as n] ----AggregateExec: mode=FinalPartitioned, gby=[n@0 as n], aggr=[COUNT(*)] -------CoalesceBatchesExec: target_batch_size=8192 ---------RepartitionExec: partitioning=Hash([n@0], 4), input_partitions=4 +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([n@0], 4), input_partitions=1 ----------AggregateExec: mode=Partial, gby=[n@0 as n], aggr=[COUNT(*)] -------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------ProjectionExec: expr=[5 as n] -----------------EmptyExec: produce_one_row=true ---ProjectionExec: expr=[x@0 as count, y@1 as n] -----ProjectionExec: expr=[1 as x, MAX(Int64(10))@0 as y] -------AggregateExec: mode=Final, gby=[], aggr=[MAX(Int64(10))] ---------CoalescePartitionsExec -----------AggregateExec: mode=Partial, gby=[], aggr=[MAX(Int64(10))] -------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------EmptyExec: produce_one_row=true +------------ProjectionExec: expr=[5 as n] +--------------PlaceholderRowExec +--ProjectionExec: expr=[1 as count, MAX(Int64(10))@0 as n] +----AggregateExec: mode=Single, gby=[], aggr=[MAX(Int64(10))] +------PlaceholderRowExec diff --git a/datafusion/sqllogictest/test_files/update.slt b/datafusion/sqllogictest/test_files/update.slt new file mode 100644 index 000000000000..6412c3ca859e --- /dev/null +++ b/datafusion/sqllogictest/test_files/update.slt @@ -0,0 +1,92 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +## Update Tests +########## + +statement ok +create table t1(a int, b varchar, c double, d int); + +# Turn off the optimizer to make the logical plan closer to the initial one +statement ok +set datafusion.optimizer.max_passes = 0; + +query TT +explain update t1 set a=1, b=2, c=3.0, d=NULL; +---- +logical_plan +Dml: op=[Update] table=[t1] +--Projection: CAST(Int64(1) AS Int32) AS a, CAST(Int64(2) AS Utf8) AS b, Float64(3) AS c, CAST(NULL AS Int32) AS d +----TableScan: t1 + +query TT +explain update t1 set a=c+1, b=a, c=c+1.0, d=b; +---- +logical_plan +Dml: op=[Update] table=[t1] +--Projection: CAST(t1.c + CAST(Int64(1) AS Float64) AS Int32) AS a, CAST(t1.a AS Utf8) AS b, t1.c + Float64(1) AS c, CAST(t1.b AS Int32) AS d +----TableScan: t1 + +statement ok +create table t2(a int, b varchar, c double, d int); + +## set from subquery +query TT +explain update t1 set b = (select max(b) from t2 where t1.a = t2.a) +---- +logical_plan +Dml: op=[Update] table=[t1] +--Projection: t1.a AS a, () AS b, t1.c AS c, t1.d AS d +----Subquery: +------Projection: MAX(t2.b) +--------Aggregate: groupBy=[[]], aggr=[[MAX(t2.b)]] +----------Filter: outer_ref(t1.a) = t2.a +------------TableScan: t2 +----TableScan: t1 + +# set from other table +query TT +explain update t1 set b = t2.b, c = t2.a, d = 1 from t2 where t1.a = t2.a and t1.b > 'foo' and t2.c > 1.0; +---- +logical_plan +Dml: op=[Update] table=[t1] +--Projection: t1.a AS a, t2.b AS b, CAST(t2.a AS Float64) AS c, CAST(Int64(1) AS Int32) AS d +----Filter: t1.a = t2.a AND t1.b > Utf8("foo") AND t2.c > Float64(1) +------CrossJoin: +--------TableScan: t1 +--------TableScan: t2 + +statement ok +create table t3(a int, b varchar, c double, d int); + +# set from mutiple tables, sqlparser only supports from one table +query error DataFusion error: SQL error: ParserError\("Expected end of statement, found: ,"\) +explain update t1 set b = t2.b, c = t3.a, d = 1 from t2, t3 where t1.a = t2.a and t1.a = t3.a; + +# test table alias +query TT +explain update t1 as T set b = t2.b, c = t.a, d = 1 from t2 where t.a = t2.a and t.b > 'foo' and t2.c > 1.0; +---- +logical_plan +Dml: op=[Update] table=[t1] +--Projection: t.a AS a, t2.b AS b, CAST(t.a AS Float64) AS c, CAST(Int64(1) AS Int32) AS d +----Filter: t.a = t2.a AND t.b > Utf8("foo") AND t2.c > Float64(1) +------CrossJoin: +--------SubqueryAlias: t +----------TableScan: t1 +--------TableScan: t2 diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 3d9f7511be26..100c2143837a 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -279,13 +279,13 @@ SortPreservingMergeExec: [b@0 ASC NULLS LAST] ------------AggregateExec: mode=Partial, gby=[b@1 as b], aggr=[MAX(d.a)] --------------UnionExec ----------------ProjectionExec: expr=[1 as a, aa as b] -------------------EmptyExec: produce_one_row=true +------------------PlaceholderRowExec ----------------ProjectionExec: expr=[3 as a, aa as b] -------------------EmptyExec: produce_one_row=true +------------------PlaceholderRowExec ----------------ProjectionExec: expr=[5 as a, bb as b] -------------------EmptyExec: produce_one_row=true +------------------PlaceholderRowExec ----------------ProjectionExec: expr=[7 as a, bb as b] -------------------EmptyExec: produce_one_row=true +------------------PlaceholderRowExec # Check actual result: query TI @@ -357,7 +357,7 @@ Sort: d.b ASC NULLS LAST physical_plan SortPreservingMergeExec: [b@0 ASC NULLS LAST] --ProjectionExec: expr=[b@0 as b, MAX(d.a)@1 as max_a, MAX(d.seq)@2 as MAX(d.seq)] -----AggregateExec: mode=SinglePartitioned, gby=[b@2 as b], aggr=[MAX(d.a), MAX(d.seq)], ordering_mode=FullyOrdered +----AggregateExec: mode=SinglePartitioned, gby=[b@2 as b], aggr=[MAX(d.a), MAX(d.seq)], ordering_mode=Sorted ------ProjectionExec: expr=[ROW_NUMBER() PARTITION BY [s.b] ORDER BY [s.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as seq, a@0 as a, b@1 as b] --------BoundedWindowAggExec: wdw=[ROW_NUMBER() PARTITION BY [s.b] ORDER BY [s.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() PARTITION BY [s.b] ORDER BY [s.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }], mode=[Sorted] ----------SortExec: expr=[b@1 ASC NULLS LAST,a@0 ASC NULLS LAST] @@ -365,13 +365,13 @@ SortPreservingMergeExec: [b@0 ASC NULLS LAST] --------------RepartitionExec: partitioning=Hash([b@1], 4), input_partitions=4 ----------------UnionExec ------------------ProjectionExec: expr=[1 as a, aa as b] ---------------------EmptyExec: produce_one_row=true +--------------------PlaceholderRowExec ------------------ProjectionExec: expr=[3 as a, aa as b] ---------------------EmptyExec: produce_one_row=true +--------------------PlaceholderRowExec ------------------ProjectionExec: expr=[5 as a, bb as b] ---------------------EmptyExec: produce_one_row=true +--------------------PlaceholderRowExec ------------------ProjectionExec: expr=[7 as a, bb as b] ---------------------EmptyExec: produce_one_row=true +--------------------PlaceholderRowExec # check actual result @@ -895,14 +895,14 @@ SELECT statement ok create table temp as values -(1664264591000000000), -(1664264592000000000), -(1664264592000000000), -(1664264593000000000), -(1664264594000000000), -(1664364594000000000), -(1664464594000000000), -(1664564594000000000); +(1664264591), +(1664264592), +(1664264592), +(1664264593), +(1664264594), +(1664364594), +(1664464594), +(1664564594); statement ok create table t as select cast(column1 as timestamp) as ts from temp; @@ -1731,26 +1731,28 @@ logical_plan Projection: COUNT(*) AS global_count --Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] ----SubqueryAlias: a -------Sort: aggregate_test_100.c1 ASC NULLS LAST ---------Aggregate: groupBy=[[aggregate_test_100.c1]], aggr=[[]] -----------Projection: aggregate_test_100.c1 -------------Filter: aggregate_test_100.c13 != Utf8("C2GT5KVyOPZpgKVl110TyZO0NcJ434") ---------------TableScan: aggregate_test_100 projection=[c1, c13], partial_filters=[aggregate_test_100.c13 != Utf8("C2GT5KVyOPZpgKVl110TyZO0NcJ434")] +------Projection: +--------Sort: aggregate_test_100.c1 ASC NULLS LAST +----------Aggregate: groupBy=[[aggregate_test_100.c1]], aggr=[[]] +------------Projection: aggregate_test_100.c1 +--------------Filter: aggregate_test_100.c13 != Utf8("C2GT5KVyOPZpgKVl110TyZO0NcJ434") +----------------TableScan: aggregate_test_100 projection=[c1, c13], partial_filters=[aggregate_test_100.c13 != Utf8("C2GT5KVyOPZpgKVl110TyZO0NcJ434")] physical_plan ProjectionExec: expr=[COUNT(*)@0 as global_count] --AggregateExec: mode=Final, gby=[], aggr=[COUNT(*)] ----CoalescePartitionsExec ------AggregateExec: mode=Partial, gby=[], aggr=[COUNT(*)] --------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=2 -----------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[] -------------CoalesceBatchesExec: target_batch_size=4096 ---------------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 -----------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[] -------------------ProjectionExec: expr=[c1@0 as c1] ---------------------CoalesceBatchesExec: target_batch_size=4096 -----------------------FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434 -------------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ---------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c13], has_header=true +----------ProjectionExec: expr=[] +------------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[] +--------------CoalesceBatchesExec: target_batch_size=4096 +----------------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 +------------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[] +--------------------ProjectionExec: expr=[c1@0 as c1] +----------------------CoalesceBatchesExec: target_batch_size=4096 +------------------------FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434 +--------------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c13], has_header=true query I SELECT count(*) as global_count FROM @@ -1957,7 +1959,7 @@ Sort: aggregate_test_100.c1 ASC NULLS LAST ----WindowAggr: windowExpr=[[ROW_NUMBER() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] ------TableScan: aggregate_test_100 projection=[c1] physical_plan -SortPreservingMergeExec: [c1@0 ASC NULLS LAST] +SortPreservingMergeExec: [c1@0 ASC NULLS LAST,rn1@1 ASC NULLS LAST] --ProjectionExec: expr=[c1@0 as c1, ROW_NUMBER() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as rn1] ----BoundedWindowAggExec: wdw=[ROW_NUMBER() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "ROW_NUMBER() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)) }], mode=[Sorted] ------SortExec: expr=[c1@0 ASC NULLS LAST] @@ -2015,7 +2017,7 @@ ProjectionExec: expr=[ARRAY_AGG(aggregate_test_100.c13)@0 as array_agg1] ------AggregateExec: mode=Partial, gby=[], aggr=[ARRAY_AGG(aggregate_test_100.c13)] --------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ----------GlobalLimitExec: skip=0, fetch=1 -------------SortExec: fetch=1, expr=[c13@0 ASC NULLS LAST] +------------SortExec: TopK(fetch=1), expr=[c13@0 ASC NULLS LAST] --------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c13], has_header=true @@ -2075,7 +2077,7 @@ Limit: skip=0, fetch=5 ----------------TableScan: aggregate_test_100 projection=[c1, c2, c8, c9] physical_plan GlobalLimitExec: skip=0, fetch=5 ---SortExec: fetch=5, expr=[c9@0 ASC NULLS LAST] +--SortExec: TopK(fetch=5), expr=[c9@0 ASC NULLS LAST] ----ProjectionExec: expr=[c9@2 as c9, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@4 as sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@6 as sum2, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@3 as sum3, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@5 as sum4] ------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }], mode=[Sorted] --------ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c9@3 as c9, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@4 as SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@5 as SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@6 as SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING] @@ -2136,15 +2138,12 @@ ProjectionExec: expr=[c9@1 as c9, SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER B ----BoundedWindowAggExec: wdw=[SUM(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }], mode=[Sorted] ------ProjectionExec: expr=[c2@0 as c2, c9@2 as c9, c1_alias@3 as c1_alias, SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@4 as SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING, SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@5 as SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, SUM(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@6 as SUM(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING] --------WindowAggExec: wdw=[SUM(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(NULL)) }] -----------SortExec: expr=[c2@0 ASC NULLS LAST,c1_alias@3 ASC NULLS LAST,c9@2 ASC NULLS LAST,c8@1 ASC NULLS LAST] -------------ProjectionExec: expr=[c2@1 as c2, c8@2 as c8, c9@3 as c9, c1_alias@4 as c1_alias, SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@5 as SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING, SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@6 as SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING] ---------------BoundedWindowAggExec: wdw=[SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }], mode=[Sorted] -----------------WindowAggExec: wdw=[SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(NULL)) }] -------------------SortExec: expr=[c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST,c9@3 ASC NULLS LAST,c8@2 ASC NULLS LAST] ---------------------ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c8@2 as c8, c9@3 as c9, c1@0 as c1_alias] -----------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c8, c9], has_header=true - - +----------ProjectionExec: expr=[c2@1 as c2, c8@2 as c8, c9@3 as c9, c1_alias@4 as c1_alias, SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@5 as SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING, SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@6 as SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING] +------------BoundedWindowAggExec: wdw=[SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }], mode=[Sorted] +--------------WindowAggExec: wdw=[SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(NULL)) }] +----------------SortExec: expr=[c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST,c9@3 ASC NULLS LAST,c8@2 ASC NULLS LAST] +------------------ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c8@2 as c8, c9@3 as c9, c1@0 as c1_alias] +--------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c8, c9], has_header=true query IIIII SELECT c9, @@ -2182,7 +2181,7 @@ Projection: sum1, sum2 physical_plan ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2] --GlobalLimitExec: skip=0, fetch=5 -----SortExec: fetch=5, expr=[c9@2 ASC NULLS LAST] +----SortExec: TopK(fetch=5), expr=[c9@2 ASC NULLS LAST] ------ProjectionExec: expr=[SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as sum1, SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING@4 as sum2, c9@1 as c9] --------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING: Ok(Field { name: "SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Groups, start_bound: Preceding(UInt64(5)), end_bound: Preceding(UInt64(3)) }], mode=[Sorted] ----------ProjectionExec: expr=[c1@0 as c1, c9@2 as c9, c12@3 as c12, SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING] @@ -2299,7 +2298,7 @@ Limit: skip=0, fetch=5 ----------TableScan: aggregate_test_100 projection=[c9] physical_plan GlobalLimitExec: skip=0, fetch=5 ---SortExec: fetch=5, expr=[rn1@1 DESC] +--SortExec: TopK(fetch=5), expr=[rn1@1 DESC] ----ProjectionExec: expr=[c9@0 as c9, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as rn1] ------BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }], mode=[Sorted] --------SortExec: expr=[c9@0 DESC] @@ -2342,10 +2341,11 @@ Limit: skip=0, fetch=5 ----------TableScan: aggregate_test_100 projection=[c9] physical_plan GlobalLimitExec: skip=0, fetch=5 ---ProjectionExec: expr=[c9@0 as c9, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as rn1] -----BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }], mode=[Sorted] -------SortExec: expr=[c9@0 DESC] ---------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], has_header=true +--SortExec: TopK(fetch=5), expr=[rn1@1 ASC NULLS LAST,c9@0 ASC NULLS LAST] +----ProjectionExec: expr=[c9@0 as c9, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as rn1] +------BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }], mode=[Sorted] +--------SortExec: expr=[c9@0 DESC] +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], has_header=true query II SELECT c9, rn1 FROM (SELECT c9, @@ -2550,7 +2550,7 @@ Projection: sum1, sum2, sum3, min1, min2, min3, max1, max2, max3, cnt1, cnt2, su physical_plan ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2, sum3@2 as sum3, min1@3 as min1, min2@4 as min2, min3@5 as min3, max1@6 as max1, max2@7 as max2, max3@8 as max3, cnt1@9 as cnt1, cnt2@10 as cnt2, sumr1@11 as sumr1, sumr2@12 as sumr2, sumr3@13 as sumr3, minr1@14 as minr1, minr2@15 as minr2, minr3@16 as minr3, maxr1@17 as maxr1, maxr2@18 as maxr2, maxr3@19 as maxr3, cntr1@20 as cntr1, cntr2@21 as cntr2, sum4@22 as sum4, cnt3@23 as cnt3] --GlobalLimitExec: skip=0, fetch=5 -----SortExec: fetch=5, expr=[inc_col@24 DESC] +----SortExec: TopK(fetch=5), expr=[inc_col@24 DESC] ------ProjectionExec: expr=[SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@13 as sum1, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@14 as sum2, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@15 as sum3, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@16 as min1, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@17 as min2, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@18 as min3, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@19 as max1, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@20 as max2, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@21 as max3, COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING@22 as cnt1, COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@23 as cnt2, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING@2 as sumr1, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING@3 as sumr2, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@4 as sumr3, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@5 as minr1, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@6 as minr2, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@7 as minr3, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@8 as maxr1, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@9 as maxr2, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@10 as maxr3, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@11 as cntr1, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@12 as cntr2, SUM(annotated_data_finite.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@24 as sum4, COUNT(*) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@25 as cnt3, inc_col@0 as inc_col] --------BoundedWindowAggExec: wdw=[SUM(annotated_data_finite.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)) }, COUNT(*) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)) }], mode=[Sorted] ----------ProjectionExec: expr=[inc_col@2 as inc_col, desc_col@3 as desc_col, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING@4 as SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING@5 as SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@6 as SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@7 as MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@8 as MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@9 as MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@10 as MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@11 as MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@12 as MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@13 as COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@14 as COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@15 as SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@16 as SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@17 as SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@18 as MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@19 as MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@20 as MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@21 as MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@22 as MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@23 as MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING@24 as COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING, COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@25 as COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING] @@ -2601,6 +2601,7 @@ SELECT # test_source_sorted_builtin query TT EXPLAIN SELECT + ts, FIRST_VALUE(inc_col) OVER(ORDER BY ts RANGE BETWEEN 10 PRECEDING and 1 FOLLOWING) as fv1, FIRST_VALUE(inc_col) OVER(ORDER BY ts ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as fv2, LAST_VALUE(inc_col) OVER(ORDER BY ts RANGE BETWEEN 10 PRECEDING and 1 FOLLOWING) as lv1, @@ -2630,24 +2631,23 @@ EXPLAIN SELECT LIMIT 5; ---- logical_plan -Projection: fv1, fv2, lv1, lv2, nv1, nv2, rn1, rn2, rank1, rank2, dense_rank1, dense_rank2, lag1, lag2, lead1, lead2, fvr1, fvr2, lvr1, lvr2, lagr1, lagr2, leadr1, leadr2 ---Limit: skip=0, fetch=5 -----Sort: annotated_data_finite.ts DESC NULLS FIRST, fetch=5 -------Projection: FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fv1, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fv2, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lv1, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lv2, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS nv1, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS nv2, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS rn1, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS rn2, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS rank1, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS rank2, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS dense_rank1, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS dense_rank2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lag1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lag2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lead1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lead2, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fvr1, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fvr2, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lvr1, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lvr2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lagr1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lagr2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS leadr1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS leadr2, annotated_data_finite.ts ---------WindowAggr: windowExpr=[[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(1), Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(2), Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(-1), Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(4), Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] -----------WindowAggr: windowExpr=[[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(1), Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(2), Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(-1), Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(4), Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] -------------TableScan: annotated_data_finite projection=[ts, inc_col] +Limit: skip=0, fetch=5 +--Sort: annotated_data_finite.ts DESC NULLS FIRST, fetch=5 +----Projection: annotated_data_finite.ts, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fv1, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fv2, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lv1, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lv2, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS nv1, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS nv2, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS rn1, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS rn2, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS rank1, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS rank2, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS dense_rank1, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS dense_rank2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lag1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lag2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lead1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lead2, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fvr1, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fvr2, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lvr1, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lvr2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lagr1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lagr2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS leadr1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS leadr2 +------WindowAggr: windowExpr=[[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(1), Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(2), Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(-1), Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(4), Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] +--------WindowAggr: windowExpr=[[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(1), Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(2), Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(-1), Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(4), Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] +----------TableScan: annotated_data_finite projection=[ts, inc_col] physical_plan -ProjectionExec: expr=[fv1@0 as fv1, fv2@1 as fv2, lv1@2 as lv1, lv2@3 as lv2, nv1@4 as nv1, nv2@5 as nv2, rn1@6 as rn1, rn2@7 as rn2, rank1@8 as rank1, rank2@9 as rank2, dense_rank1@10 as dense_rank1, dense_rank2@11 as dense_rank2, lag1@12 as lag1, lag2@13 as lag2, lead1@14 as lead1, lead2@15 as lead2, fvr1@16 as fvr1, fvr2@17 as fvr2, lvr1@18 as lvr1, lvr2@19 as lvr2, lagr1@20 as lagr1, lagr2@21 as lagr2, leadr1@22 as leadr1, leadr2@23 as leadr2] ---GlobalLimitExec: skip=0, fetch=5 -----SortExec: fetch=5, expr=[ts@24 DESC] -------ProjectionExec: expr=[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@10 as fv1, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@11 as fv2, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@12 as lv1, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@13 as lv2, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@14 as nv1, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@15 as nv2, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@16 as rn1, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@17 as rn2, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@18 as rank1, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@19 as rank2, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@20 as dense_rank1, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@21 as dense_rank2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@22 as lag1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@23 as lag2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@24 as lead1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@25 as lead2, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as fvr1, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@3 as fvr2, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@4 as lvr1, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@5 as lvr2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@6 as lagr1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@7 as lagr2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@8 as leadr1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@9 as leadr2, ts@0 as ts] ---------BoundedWindowAggExec: wdw=[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }], mode=[Sorted] -----------BoundedWindowAggExec: wdw=[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }], mode=[Sorted] -------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts, inc_col], output_ordering=[ts@0 ASC NULLS LAST], has_header=true +GlobalLimitExec: skip=0, fetch=5 +--SortExec: TopK(fetch=5), expr=[ts@0 DESC] +----ProjectionExec: expr=[ts@0 as ts, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@10 as fv1, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@11 as fv2, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@12 as lv1, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@13 as lv2, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@14 as nv1, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@15 as nv2, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@16 as rn1, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@17 as rn2, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@18 as rank1, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@19 as rank2, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@20 as dense_rank1, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@21 as dense_rank2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@22 as lag1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@23 as lag2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@24 as lead1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@25 as lead2, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as fvr1, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@3 as fvr2, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@4 as lvr1, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@5 as lvr2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@6 as lagr1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@7 as lagr2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@8 as leadr1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@9 as leadr2] +------BoundedWindowAggExec: wdw=[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }], mode=[Sorted] +--------BoundedWindowAggExec: wdw=[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }], mode=[Sorted] +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts, inc_col], output_ordering=[ts@0 ASC NULLS LAST], has_header=true -query IIIIIIIIIIIIIIIIIIIIIIII +query IIIIIIIIIIIIIIIIIIIIIIIII SELECT + ts, FIRST_VALUE(inc_col) OVER(ORDER BY ts RANGE BETWEEN 10 PRECEDING and 1 FOLLOWING) as fv1, FIRST_VALUE(inc_col) OVER(ORDER BY ts ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as fv2, LAST_VALUE(inc_col) OVER(ORDER BY ts RANGE BETWEEN 10 PRECEDING and 1 FOLLOWING) as lv1, @@ -2673,14 +2673,14 @@ SELECT LEAD(inc_col, -1, 1001) OVER(ORDER BY ts DESC RANGE BETWEEN 1 PRECEDING and 10 FOLLOWING) AS leadr1, LEAD(inc_col, 4, 1004) OVER(ORDER BY ts DESC ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as leadr2 FROM annotated_data_finite - ORDER BY ts DESC + ORDER BY ts DESC, fv2 LIMIT 5; ---- -289 269 305 305 305 283 100 100 99 99 86 86 301 296 301 1004 305 305 301 301 1001 1002 1001 289 -289 266 305 305 305 278 99 99 99 99 86 86 296 291 296 1004 305 305 301 296 305 1002 305 286 -289 261 296 301 NULL 275 98 98 98 98 85 85 291 289 291 1004 305 305 296 291 301 305 301 283 -286 259 291 296 NULL 272 97 97 97 97 84 84 289 286 289 1004 305 305 291 289 296 301 296 278 -275 254 289 291 289 269 96 96 96 96 83 83 286 283 286 305 305 305 289 286 291 296 291 275 +264 289 266 305 305 305 278 99 99 99 99 86 86 296 291 296 1004 305 305 301 296 305 1002 305 286 +264 289 269 305 305 305 283 100 100 99 99 86 86 301 296 301 1004 305 305 301 301 1001 1002 1001 289 +262 289 261 296 301 NULL 275 98 98 98 98 85 85 291 289 291 1004 305 305 296 291 301 305 301 283 +258 286 259 291 296 NULL 272 97 97 97 97 84 84 289 286 289 1004 305 305 291 289 296 301 296 278 +254 275 254 289 291 289 269 96 96 96 96 83 83 286 283 286 305 305 305 289 286 291 296 291 275 # test_source_sorted_unbounded_preceding @@ -2712,7 +2712,7 @@ Projection: sum1, sum2, min1, min2, max1, max2, count1, count2, avg1, avg2 physical_plan ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2, min1@2 as min1, min2@3 as min2, max1@4 as max1, max2@5 as max2, count1@6 as count1, count2@7 as count2, avg1@8 as avg1, avg2@9 as avg2] --GlobalLimitExec: skip=0, fetch=5 -----SortExec: fetch=5, expr=[inc_col@10 ASC NULLS LAST] +----SortExec: TopK(fetch=5), expr=[inc_col@10 ASC NULLS LAST] ------ProjectionExec: expr=[SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING@7 as sum1, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@2 as sum2, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING@8 as min1, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@3 as min2, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING@9 as max1, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@4 as max2, COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING@10 as count1, COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@5 as count2, AVG(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING@11 as avg1, AVG(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@6 as avg2, inc_col@1 as inc_col] --------BoundedWindowAggExec: wdw=[SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)) }, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Ok(Field { name: "MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)) }, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Ok(Field { name: "MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)) }, COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Ok(Field { name: "COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)) }, AVG(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Ok(Field { name: "AVG(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)) }], mode=[Sorted] ----------BoundedWindowAggExec: wdw=[SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)) }, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)) }, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)) }, COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)) }, AVG(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "AVG(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)) }], mode=[Sorted] @@ -2765,7 +2765,7 @@ Projection: first_value1, first_value2, last_value1, last_value2, nth_value1 physical_plan ProjectionExec: expr=[first_value1@0 as first_value1, first_value2@1 as first_value2, last_value1@2 as last_value1, last_value2@3 as last_value2, nth_value1@4 as nth_value1] --GlobalLimitExec: skip=0, fetch=5 -----SortExec: fetch=5, expr=[inc_col@5 ASC NULLS LAST] +----SortExec: TopK(fetch=5), expr=[inc_col@5 ASC NULLS LAST] ------ProjectionExec: expr=[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@4 as first_value1, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@2 as first_value2, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@5 as last_value1, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@3 as last_value2, NTH_VALUE(annotated_data_finite.inc_col,Int64(2)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@6 as nth_value1, inc_col@1 as inc_col] --------BoundedWindowAggExec: wdw=[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)) }, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)) }, NTH_VALUE(annotated_data_finite.inc_col,Int64(2)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(2)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)) }], mode=[Sorted] ----------BoundedWindowAggExec: wdw=[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)) }, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)) }], mode=[Sorted] @@ -2814,7 +2814,7 @@ ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2, count1@2 as count1, count2 ----ProjectionExec: expr=[SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@4 as sum1, SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@2 as sum2, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@5 as count1, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@3 as count2, ts@0 as ts] ------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)) }, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)) }], mode=[Sorted] --------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)) }, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)) }], mode=[Sorted] -----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts, inc_col], infinite_source=true, output_ordering=[ts@0 ASC NULLS LAST], has_header=true +----------StreamingTableExec: partition_sizes=1, projection=[ts, inc_col], infinite_source=true, output_ordering=[ts@0 ASC NULLS LAST] query IIII @@ -2860,7 +2860,7 @@ ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2, count1@2 as count1, count2 ----ProjectionExec: expr=[SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@4 as sum1, SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@2 as sum2, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@5 as count1, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@3 as count2, ts@0 as ts] ------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)) }, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)) }], mode=[Sorted] --------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)) }, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)) }], mode=[Sorted] -----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts, inc_col], infinite_source=true, output_ordering=[ts@0 ASC NULLS LAST], has_header=true +----------StreamingTableExec: partition_sizes=1, projection=[ts, inc_col], infinite_source=true, output_ordering=[ts@0 ASC NULLS LAST] query IIII @@ -2964,7 +2964,7 @@ ProjectionExec: expr=[a@1 as a, b@2 as b, c@3 as c, SUM(annotated_data_infinite2 ------------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)) }, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: CurrentRow }], mode=[PartiallySorted([0, 1])] --------------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)) }, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(5)) }], mode=[Sorted] ----------------ProjectionExec: expr=[CAST(c@2 AS Int64) as CAST(annotated_data_infinite2.c AS Int64)annotated_data_infinite2.c, a@0 as a, b@1 as b, c@2 as c, d@3 as d] -------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true +------------------StreamingTableExec: partition_sizes=1, projection=[a, b, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] query IIIIIIIIIIIIIII @@ -3026,7 +3026,7 @@ Limit: skip=0, fetch=5 --------------------TableScan: annotated_data_finite2 projection=[a, b, c, d] physical_plan GlobalLimitExec: skip=0, fetch=5 ---SortExec: fetch=5, expr=[c@2 ASC NULLS LAST] +--SortExec: TopK(fetch=5), expr=[c@2 ASC NULLS LAST] ----ProjectionExec: expr=[a@1 as a, b@2 as b, c@3 as c, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@9 as sum1, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING@10 as sum2, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@15 as sum3, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING@16 as sum4, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@5 as sum5, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING@6 as sum6, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@11 as sum7, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING@12 as sum8, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@7 as sum9, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW@8 as sum10, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@13 as sum11, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING@14 as sum12] ------BoundedWindowAggExec: wdw=[SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)) }, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING: Ok(Field { name: "SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Preceding(UInt64(1)) }], mode=[Sorted] --------SortExec: expr=[d@4 ASC NULLS LAST,a@1 ASC NULLS LAST,b@2 ASC NULLS LAST,c@3 ASC NULLS LAST] @@ -3106,7 +3106,7 @@ CoalesceBatchesExec: target_batch_size=4096 ----GlobalLimitExec: skip=0, fetch=5 ------ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, ROW_NUMBER() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as rn1] --------BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] -----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], infinite_source=true, output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true +----------StreamingTableExec: partition_sizes=1, projection=[a0, a, b, c, d], infinite_source=true, output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST] # this is a negative test for asserting that window functions (other than ROW_NUMBER) # are not added to ordering equivalence @@ -3128,7 +3128,7 @@ Limit: skip=0, fetch=5 ----------TableScan: aggregate_test_100 projection=[c9] physical_plan GlobalLimitExec: skip=0, fetch=5 ---SortExec: fetch=5, expr=[sum1@1 ASC NULLS LAST,c9@0 DESC] +--SortExec: TopK(fetch=5), expr=[sum1@1 ASC NULLS LAST,c9@0 DESC] ----ProjectionExec: expr=[c9@0 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as sum1] ------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }], mode=[Sorted] --------SortExec: expr=[c9@0 DESC] @@ -3197,6 +3197,72 @@ SELECT a_new, d, rn1 FROM (SELECT d, a as a_new, 0 0 4 0 1 5 +query TT +EXPLAIN SELECT SUM(a) OVER(partition by a, b order by c) as sum1, +SUM(a) OVER(partition by b, a order by c) as sum2, + SUM(a) OVER(partition by a, d order by b) as sum3, + SUM(a) OVER(partition by d order by a) as sum4 +FROM annotated_data_infinite2; +---- +logical_plan +Projection: SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum1, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum2, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum3, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum4 +--WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite2.a AS Int64)) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +----Projection: annotated_data_infinite2.a, annotated_data_infinite2.d, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW +------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite2.a AS Int64)) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +--------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite2.a AS Int64)) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +----------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite2.a AS Int64)) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +------------TableScan: annotated_data_infinite2 projection=[a, b, c, d] +physical_plan +ProjectionExec: expr=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as sum1, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as sum2, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as sum3, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as sum4] +--BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Linear] +----ProjectionExec: expr=[a@0 as a, d@3 as d, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] +------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] +--------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[PartiallySorted([0])] +----------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] +------------StreamingTableExec: partition_sizes=1, projection=[a, b, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] + +statement ok +set datafusion.execution.target_partitions = 2; + +# re-execute the same query in multi partitions. +# final plan should still be streamable +query TT +EXPLAIN SELECT SUM(a) OVER(partition by a, b order by c) as sum1, + SUM(a) OVER(partition by b, a order by c) as sum2, + SUM(a) OVER(partition by a, d order by b) as sum3, + SUM(a) OVER(partition by d order by a) as sum4 +FROM annotated_data_infinite2; +---- +logical_plan +Projection: SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum1, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum2, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum3, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum4 +--WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite2.a AS Int64)) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +----Projection: annotated_data_infinite2.a, annotated_data_infinite2.d, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW +------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite2.a AS Int64)) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +--------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite2.a AS Int64)) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +----------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite2.a AS Int64)) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +------------TableScan: annotated_data_infinite2 projection=[a, b, c, d] +physical_plan +ProjectionExec: expr=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as sum1, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as sum2, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as sum3, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as sum4] +--BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Linear] +----CoalesceBatchesExec: target_batch_size=4096 +------RepartitionExec: partitioning=Hash([d@1], 2), input_partitions=2, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST +--------ProjectionExec: expr=[a@0 as a, d@3 as d, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] +----------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] +------------CoalesceBatchesExec: target_batch_size=4096 +--------------RepartitionExec: partitioning=Hash([b@1, a@0], 2), input_partitions=2, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST,b@1 ASC NULLS LAST,c@2 ASC NULLS LAST +----------------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[PartiallySorted([0])] +------------------CoalesceBatchesExec: target_batch_size=4096 +--------------------RepartitionExec: partitioning=Hash([a@0, d@3], 2), input_partitions=2, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST,b@1 ASC NULLS LAST,c@2 ASC NULLS LAST +----------------------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] +------------------------CoalesceBatchesExec: target_batch_size=4096 +--------------------------RepartitionExec: partitioning=Hash([a@0, b@1], 2), input_partitions=2, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST,b@1 ASC NULLS LAST,c@2 ASC NULLS LAST +----------------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------------------------StreamingTableExec: partition_sizes=1, projection=[a, b, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] + +# reset the partition number 1 again +statement ok +set datafusion.execution.target_partitions = 1; + statement ok drop table annotated_data_finite2 @@ -3205,25 +3271,27 @@ drop table annotated_data_infinite2 # window3 spec is not used in window functions. # The query should still work. -query RR +query IRR SELECT - MAX(c12) OVER window1, - MIN(c12) OVER window2 as max1 + C3, + MAX(c12) OVER window1 as max1, + MIN(c12) OVER window2 as max2 FROM aggregate_test_100 WINDOW window1 AS (ORDER BY C12), window2 AS (PARTITION BY C11), window3 AS (ORDER BY C1) - ORDER BY C3 + ORDER BY C3, max2 LIMIT 5 ---- -0.970671228336 0.970671228336 -0.850672105305 0.850672105305 -0.152498292972 0.152498292972 -0.369363046006 0.369363046006 -0.56535284223 0.56535284223 +-117 0.850672105305 0.850672105305 +-117 0.970671228336 0.970671228336 +-111 0.152498292972 0.152498292972 +-107 0.369363046006 0.369363046006 +-106 0.56535284223 0.56535284223 query TT EXPLAIN SELECT + C3, MAX(c12) OVER window1 as min1, MIN(c12) OVER window2 as max1 FROM aggregate_test_100 @@ -3234,42 +3302,41 @@ EXPLAIN SELECT LIMIT 5 ---- logical_plan -Projection: min1, max1 ---Limit: skip=0, fetch=5 -----Sort: aggregate_test_100.c3 ASC NULLS LAST, fetch=5 -------Projection: MAX(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS min1, MIN(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS max1, aggregate_test_100.c3 ---------WindowAggr: windowExpr=[[MAX(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -----------Projection: aggregate_test_100.c3, aggregate_test_100.c12, MIN(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING -------------WindowAggr: windowExpr=[[MIN(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] ---------------TableScan: aggregate_test_100 projection=[c3, c11, c12] +Limit: skip=0, fetch=5 +--Sort: aggregate_test_100.c3 ASC NULLS LAST, fetch=5 +----Projection: aggregate_test_100.c3, MAX(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS min1, MIN(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS max1 +------WindowAggr: windowExpr=[[MAX(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +--------Projection: aggregate_test_100.c3, aggregate_test_100.c12, MIN(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING +----------WindowAggr: windowExpr=[[MIN(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +------------TableScan: aggregate_test_100 projection=[c3, c11, c12] physical_plan -ProjectionExec: expr=[min1@0 as min1, max1@1 as max1] ---GlobalLimitExec: skip=0, fetch=5 -----SortExec: fetch=5, expr=[c3@2 ASC NULLS LAST] -------ProjectionExec: expr=[MAX(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as min1, MIN(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@2 as max1, c3@0 as c3] ---------BoundedWindowAggExec: wdw=[MAX(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "MAX(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Float64(NULL)), end_bound: CurrentRow }], mode=[Sorted] -----------SortExec: expr=[c12@1 ASC NULLS LAST] -------------ProjectionExec: expr=[c3@0 as c3, c12@2 as c12, MIN(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@3 as MIN(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING] ---------------WindowAggExec: wdw=[MIN(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "MIN(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)) }] -----------------SortExec: expr=[c11@1 ASC NULLS LAST] -------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3, c11, c12], has_header=true +GlobalLimitExec: skip=0, fetch=5 +--SortExec: TopK(fetch=5), expr=[c3@0 ASC NULLS LAST] +----ProjectionExec: expr=[c3@0 as c3, MAX(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as min1, MIN(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@2 as max1] +------BoundedWindowAggExec: wdw=[MAX(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "MAX(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Float64(NULL)), end_bound: CurrentRow }], mode=[Sorted] +--------SortExec: expr=[c12@1 ASC NULLS LAST] +----------ProjectionExec: expr=[c3@0 as c3, c12@2 as c12, MIN(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@3 as MIN(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING] +------------WindowAggExec: wdw=[MIN(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "MIN(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)) }] +--------------SortExec: expr=[c11@1 ASC NULLS LAST] +----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3, c11, c12], has_header=true # window1 spec is used multiple times under different aggregations. # The query should still work. -query RR +query IRR SELECT + C3, MAX(c12) OVER window1 as min1, MIN(c12) OVER window1 as max1 FROM aggregate_test_100 WINDOW window1 AS (ORDER BY C12) - ORDER BY C3 + ORDER BY C3, min1 LIMIT 5 ---- -0.970671228336 0.014793053078 -0.850672105305 0.014793053078 -0.152498292972 0.014793053078 -0.369363046006 0.014793053078 -0.56535284223 0.014793053078 +-117 0.850672105305 0.014793053078 +-117 0.970671228336 0.014793053078 +-111 0.152498292972 0.014793053078 +-107 0.369363046006 0.014793053078 +-106 0.56535284223 0.014793053078 query TT EXPLAIN SELECT @@ -3290,7 +3357,7 @@ Projection: min1, max1 physical_plan ProjectionExec: expr=[min1@0 as min1, max1@1 as max1] --GlobalLimitExec: skip=0, fetch=5 -----SortExec: fetch=5, expr=[c3@2 ASC NULLS LAST] +----SortExec: TopK(fetch=5), expr=[c3@2 ASC NULLS LAST] ------ProjectionExec: expr=[MAX(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as min1, MIN(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as max1, c3@0 as c3] --------BoundedWindowAggExec: wdw=[MAX(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "MAX(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Float64(NULL)), end_bound: CurrentRow }, MIN(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "MIN(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Float64(NULL)), end_bound: CurrentRow }], mode=[Sorted] ----------SortExec: expr=[c12@1 ASC NULLS LAST] @@ -3315,3 +3382,492 @@ SELECT window1 AS (ORDER BY C3) ORDER BY C3 LIMIT 5 + +# Create a source where there is multiple orderings. +statement ok +CREATE EXTERNAL TABLE multiple_ordered_table ( + a0 INTEGER, + a INTEGER, + b INTEGER, + c INTEGER, + d INTEGER +) +STORED AS CSV +WITH HEADER ROW +WITH ORDER (a ASC, b ASC) +WITH ORDER (c ASC) +LOCATION '../core/tests/data/window_2.csv'; + +# Create an unbounded source where there is multiple orderings. +statement ok +CREATE UNBOUNDED EXTERNAL TABLE multiple_ordered_table_inf ( + a0 INTEGER, + a INTEGER, + b INTEGER, + c INTEGER, + d INTEGER +) +STORED AS CSV +WITH HEADER ROW +WITH ORDER (a ASC, b ASC) +WITH ORDER (c ASC) +LOCATION '../core/tests/data/window_2.csv'; + +# All of the window execs in the physical plan should work in the +# sorted mode. +query TT +EXPLAIN SELECT MIN(d) OVER(ORDER BY c ASC) as min1, + MAX(d) OVER(PARTITION BY b, a ORDER BY c ASC) as max1 +FROM multiple_ordered_table +---- +logical_plan +Projection: MIN(multiple_ordered_table.d) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS min1, MAX(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS max1 +--WindowAggr: windowExpr=[[MIN(multiple_ordered_table.d) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +----Projection: multiple_ordered_table.c, multiple_ordered_table.d, MAX(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW +------WindowAggr: windowExpr=[[MAX(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +--------TableScan: multiple_ordered_table projection=[a, b, c, d] +physical_plan +ProjectionExec: expr=[MIN(multiple_ordered_table.d) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as min1, MAX(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as max1] +--BoundedWindowAggExec: wdw=[MIN(multiple_ordered_table.d) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "MIN(multiple_ordered_table.d) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] +----ProjectionExec: expr=[c@2 as c, d@3 as d, MAX(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as MAX(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] +------BoundedWindowAggExec: wdw=[MAX(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "MAX(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], output_orderings=[[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], [c@2 ASC NULLS LAST]], has_header=true + +query TT +EXPLAIN SELECT MAX(c) OVER(PARTITION BY d ORDER BY c ASC) as max_c +FROM( + SELECT * + FROM multiple_ordered_table + WHERE d=0) +---- +logical_plan +Projection: MAX(multiple_ordered_table.c) PARTITION BY [multiple_ordered_table.d] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS max_c +--WindowAggr: windowExpr=[[MAX(multiple_ordered_table.c) PARTITION BY [multiple_ordered_table.d] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +----Filter: multiple_ordered_table.d = Int32(0) +------TableScan: multiple_ordered_table projection=[c, d], partial_filters=[multiple_ordered_table.d = Int32(0)] +physical_plan +ProjectionExec: expr=[MAX(multiple_ordered_table.c) PARTITION BY [multiple_ordered_table.d] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as max_c] +--BoundedWindowAggExec: wdw=[MAX(multiple_ordered_table.c) PARTITION BY [multiple_ordered_table.d] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "MAX(multiple_ordered_table.c) PARTITION BY [multiple_ordered_table.d] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] +----CoalesceBatchesExec: target_batch_size=4096 +------FilterExec: d@1 = 0 +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c, d], output_ordering=[c@0 ASC NULLS LAST], has_header=true + +query TT +explain SELECT SUM(d) OVER(PARTITION BY c ORDER BY a ASC) +FROM multiple_ordered_table; +---- +logical_plan +Projection: SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW +--WindowAggr: windowExpr=[[SUM(CAST(multiple_ordered_table.d AS Int64)) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +----TableScan: multiple_ordered_table projection=[a, c, d] +physical_plan +ProjectionExec: expr=[SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] +--BoundedWindowAggExec: wdw=[SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] +----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true + +query TT +explain SELECT SUM(d) OVER(PARTITION BY c, a ORDER BY b ASC) +FROM multiple_ordered_table; +---- +logical_plan +Projection: SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW +--WindowAggr: windowExpr=[[SUM(CAST(multiple_ordered_table.d AS Int64)) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +----TableScan: multiple_ordered_table projection=[a, b, c, d] +physical_plan +ProjectionExec: expr=[SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] +--BoundedWindowAggExec: wdw=[SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] +----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], output_orderings=[[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], [c@2 ASC NULLS LAST]], has_header=true + +query I +SELECT SUM(d) OVER(PARTITION BY c, a ORDER BY b ASC) +FROM multiple_ordered_table +LIMIT 5; +---- +0 +2 +0 +0 +1 + +# simple window query +query II +select sum(1) over() x, sum(1) over () y +---- +1 1 + +# NTH_VALUE requirement is c DESC, However existing ordering is c ASC +# if we reverse window expression: "NTH_VALUE(c, 2) OVER(order by c DESC ) as nv1" +# as "NTH_VALUE(c, -2) OVER(order by c ASC RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) as nv1" +# Please note that: "NTH_VALUE(c, 2) OVER(order by c DESC ) as nv1" is same with +# "NTH_VALUE(c, 2) OVER(order by c DESC RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as nv1" " +# we can produce same result without re-sorting the table. +# Unfortunately since window expression names are string, this change is not seen the plan (we do not do string manipulation). +# TODO: Reflect window expression reversal in the plans. +query TT +EXPLAIN SELECT c, NTH_VALUE(c, 2) OVER(order by c DESC) as nv1 + FROM multiple_ordered_table + ORDER BY c ASC + LIMIT 5 +---- +logical_plan +Limit: skip=0, fetch=5 +--Sort: multiple_ordered_table.c ASC NULLS LAST, fetch=5 +----Projection: multiple_ordered_table.c, NTH_VALUE(multiple_ordered_table.c,Int64(2)) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS nv1 +------WindowAggr: windowExpr=[[NTH_VALUE(multiple_ordered_table.c, Int64(2)) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +--------TableScan: multiple_ordered_table projection=[c] +physical_plan +GlobalLimitExec: skip=0, fetch=5 +--ProjectionExec: expr=[c@0 as c, NTH_VALUE(multiple_ordered_table.c,Int64(2)) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as nv1] +----WindowAggExec: wdw=[NTH_VALUE(multiple_ordered_table.c,Int64(2)) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "NTH_VALUE(multiple_ordered_table.c,Int64(2)) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int32(NULL)) }] +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c], output_ordering=[c@0 ASC NULLS LAST], has_header=true + +query II +SELECT c, NTH_VALUE(c, 2) OVER(order by c DESC) as nv1 + FROM multiple_ordered_table + ORDER BY c ASC + LIMIT 5 +---- +0 98 +1 98 +2 98 +3 98 +4 98 + +query II +SELECT c, NTH_VALUE(c, 2) OVER(order by c DESC) as nv1 + FROM multiple_ordered_table + ORDER BY c DESC + LIMIT 5 +---- +99 NULL +98 98 +97 98 +96 98 +95 98 + +statement ok +set datafusion.execution.target_partitions = 2; + +# source is ordered by [a ASC, b ASC], [c ASC] +# after sort preserving repartition and sort preserving merge +# we should still have the orderings [a ASC, b ASC], [c ASC]. +query TT +EXPLAIN SELECT *, + AVG(d) OVER sliding_window AS avg_d +FROM multiple_ordered_table_inf +WINDOW sliding_window AS ( + PARTITION BY d + ORDER BY a RANGE 10 PRECEDING +) +ORDER BY c +---- +logical_plan +Sort: multiple_ordered_table_inf.c ASC NULLS LAST +--Projection: multiple_ordered_table_inf.a0, multiple_ordered_table_inf.a, multiple_ordered_table_inf.b, multiple_ordered_table_inf.c, multiple_ordered_table_inf.d, AVG(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW AS avg_d +----WindowAggr: windowExpr=[[AVG(CAST(multiple_ordered_table_inf.d AS Float64)) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW]] +------TableScan: multiple_ordered_table_inf projection=[a0, a, b, c, d] +physical_plan +SortPreservingMergeExec: [c@3 ASC NULLS LAST] +--ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, AVG(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW@5 as avg_d] +----BoundedWindowAggExec: wdw=[AVG(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW: Ok(Field { name: "AVG(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: CurrentRow }], mode=[Linear] +------CoalesceBatchesExec: target_batch_size=4096 +--------RepartitionExec: partitioning=Hash([d@4], 2), input_partitions=2, preserve_order=true, sort_exprs=a@1 ASC NULLS LAST,b@2 ASC NULLS LAST +----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------StreamingTableExec: partition_sizes=1, projection=[a0, a, b, c, d], infinite_source=true, output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST] + +# CTAS with NTILE function +statement ok +CREATE TABLE new_table AS SELECT NTILE(2) OVER(ORDER BY c1) AS ntile_2 FROM aggregate_test_100; + +statement ok +DROP TABLE new_table; + +statement ok +CREATE TABLE t1 (a int) AS VALUES (1), (2), (3); + +query I +SELECT NTILE(9223377) OVER(ORDER BY a) FROM t1; +---- +1 +2 +3 + +query I +SELECT NTILE(9223372036854775809) OVER(ORDER BY a) FROM t1; +---- +1 +2 +3 + +query error DataFusion error: Execution error: NTILE requires a positive integer +SELECT NTILE(-922337203685477580) OVER(ORDER BY a) FROM t1; + +query error DataFusion error: Execution error: Table 't' doesn't exist\. +DROP TABLE t; + +# NTILE with PARTITION BY, those tests from duckdb: https://github.com/duckdb/duckdb/blob/main/test/sql/window/test_ntile.test +statement ok +CREATE TABLE score_board (team_name VARCHAR, player VARCHAR, score INTEGER) as VALUES + ('Mongrels', 'Apu', 350), + ('Mongrels', 'Ned', 666), + ('Mongrels', 'Meg', 1030), + ('Mongrels', 'Burns', 1270), + ('Simpsons', 'Homer', 1), + ('Simpsons', 'Lisa', 710), + ('Simpsons', 'Marge', 990), + ('Simpsons', 'Bart', 2010) + +query TTII +SELECT + team_name, + player, + score, + NTILE(2) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s +ORDER BY team_name, score; +---- +Mongrels Apu 350 1 +Mongrels Ned 666 1 +Mongrels Meg 1030 2 +Mongrels Burns 1270 2 +Simpsons Homer 1 1 +Simpsons Lisa 710 1 +Simpsons Marge 990 2 +Simpsons Bart 2010 2 + +query TTII +SELECT + team_name, + player, + score, + NTILE(2) OVER (ORDER BY score ASC) AS NTILE +FROM score_board s +ORDER BY score; +---- +Simpsons Homer 1 1 +Mongrels Apu 350 1 +Mongrels Ned 666 1 +Simpsons Lisa 710 1 +Simpsons Marge 990 2 +Mongrels Meg 1030 2 +Mongrels Burns 1270 2 +Simpsons Bart 2010 2 + +query TTII +SELECT + team_name, + player, + score, + NTILE(1000) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s +ORDER BY team_name, score; +---- +Mongrels Apu 350 1 +Mongrels Ned 666 2 +Mongrels Meg 1030 3 +Mongrels Burns 1270 4 +Simpsons Homer 1 1 +Simpsons Lisa 710 2 +Simpsons Marge 990 3 +Simpsons Bart 2010 4 + +query TTII +SELECT + team_name, + player, + score, + NTILE(1) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s +ORDER BY team_name, score; +---- +Mongrels Apu 350 1 +Mongrels Ned 666 1 +Mongrels Meg 1030 1 +Mongrels Burns 1270 1 +Simpsons Homer 1 1 +Simpsons Lisa 710 1 +Simpsons Marge 990 1 +Simpsons Bart 2010 1 + +# incorrect number of parameters for ntile +query error DataFusion error: Execution error: NTILE requires a positive integer, but finds NULL +SELECT + NTILE(NULL) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s + +query error DataFusion error: Execution error: NTILE requires a positive integer +SELECT + NTILE(-1) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s + +query error DataFusion error: Execution error: NTILE requires a positive integer +SELECT + NTILE(0) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s + +statement error +SELECT + NTILE() OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s + +statement error +SELECT + NTILE(1,2) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s + +statement error +SELECT + NTILE(1,2,3) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s + +statement error +SELECT + NTILE(1,2,3,4) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s + +statement ok +DROP TABLE score_board; + +# Regularize RANGE frame +query error DataFusion error: Error during planning: RANGE requires exactly one ORDER BY column +select a, + rank() over (order by a, a + 1 RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a + +query II +select a, + rank() over (order by a RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a +---- +1 1 +2 2 + +query error DataFusion error: Error during planning: RANGE requires exactly one ORDER BY column +select a, + rank() over (RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a + +query II +select a, + rank() over (order by a, a + 1 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a +---- +1 1 +2 2 + +query II +select a, + rank() over (order by a RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a +---- +1 1 +2 2 + +query II +select a, + rank() over (RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a +---- +1 1 +2 1 + +query I +select rank() over (RANGE between UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk + from (select 1 a union select 2 a) q; +---- +1 +1 + +query II +select a, + rank() over (order by 1 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a +---- +1 1 +2 1 + +query II +select a, + rank() over (order by null RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a +---- +1 1 +2 1 + +# support scalar value in ORDER BY +query I +select rank() over (order by 1) rnk from (select 1 a union all select 2 a) x +---- +1 +1 + +# support scalar value in ORDER BY +query I +select dense_rank() over () rnk from (select 1 a union all select 2 a) x +---- +1 +1 + +# support scalar value in both ORDER BY and PARTITION BY, RANK function +query IIIIII +select rank() over (partition by 1 order by 1) rnk, + rank() over (partition by a, 1 order by 1) rnk1, + rank() over (partition by a, 1 order by a, 1) rnk2, + rank() over (partition by 1) rnk3, + rank() over (partition by null) rnk4, + rank() over (partition by 1, null, a) rnk5 +from (select 1 a union all select 2 a) x +---- +1 1 1 1 1 1 +1 1 1 1 1 1 + +# support scalar value in both ORDER BY and PARTITION BY, ROW_NUMBER function +query IIIIII +select row_number() over (partition by 1 order by 1) rn, + row_number() over (partition by a, 1 order by 1) rn1, + row_number() over (partition by a, 1 order by a, 1) rn2, + row_number() over (partition by 1) rn3, + row_number() over (partition by null) rn4, + row_number() over (partition by 1, null, a) rn5 +from (select 1 a union all select 2 a) x; +---- +1 1 1 1 1 1 +2 1 1 2 2 1 + +# when partition by expression is empty row number result will be unique. +query TII +SELECT * +FROM (SELECT c1, c2, ROW_NUMBER() OVER() as rn + FROM aggregate_test_100 + LIMIT 5) +GROUP BY rn +ORDER BY rn; +---- +c 2 1 +d 5 2 +b 1 3 +a 1 4 +b 5 5 + +# when partition by expression is constant row number result will be unique. +query TII +SELECT * +FROM (SELECT c1, c2, ROW_NUMBER() OVER(PARTITION BY 3) as rn + FROM aggregate_test_100 + LIMIT 5) +GROUP BY rn +ORDER BY rn; +---- +c 2 1 +d 5 2 +b 1 3 +a 1 4 +b 5 5 + +statement error DataFusion error: Error during planning: Projection references non-aggregate values: Expression aggregate_test_100.c1 could not be resolved from available columns: rn +SELECT * +FROM (SELECT c1, c2, ROW_NUMBER() OVER(PARTITION BY c1) as rn + FROM aggregate_test_100 + LIMIT 5) +GROUP BY rn +ORDER BY rn; diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index 2f7816c6488a..0a9a6e8dd12b 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -18,9 +18,9 @@ [package] name = "datafusion-substrait" description = "DataFusion Substrait Producer and Consumer" +readme = "README.md" version = { workspace = true } edition = { workspace = true } -readme = { workspace = true } homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } @@ -30,12 +30,12 @@ rust-version = "1.70" [dependencies] async-recursion = "1.0" chrono = { workspace = true } -datafusion = { version = "31.0.0", path = "../core" } -itertools = "0.11" -object_store = "0.7.0" -prost = "0.11" -prost-types = "0.11" -substrait = "0.14.0" +datafusion = { workspace = true } +itertools = { workspace = true } +object_store = { workspace = true } +prost = "0.12" +prost-types = "0.12" +substrait = "0.21.0" tokio = "1.17" [features] diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 32b8f8ea547f..a4ec3e7722a2 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -17,24 +17,30 @@ use async_recursion::async_recursion; use datafusion::arrow::datatypes::{DataType, Field, TimeUnit}; -use datafusion::common::{not_impl_err, DFField, DFSchema, DFSchemaRef}; +use datafusion::common::{ + not_impl_err, substrait_datafusion_err, substrait_err, DFField, DFSchema, DFSchemaRef, +}; + +use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::{ - aggregate_function, window_function::find_df_window_func, BinaryExpr, - BuiltinScalarFunction, Case, Expr, LogicalPlan, Operator, + aggregate_function, expr::find_df_window_func, BinaryExpr, BuiltinScalarFunction, + Case, Expr, LogicalPlan, Operator, }; use datafusion::logical_expr::{ - expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, WindowFrameBound, - WindowFrameUnits, + expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning, + Repartition, Subquery, WindowFrameBound, WindowFrameUnits, }; use datafusion::prelude::JoinType; use datafusion::sql::TableReference; use datafusion::{ error::{DataFusionError, Result}, - optimizer::utils::split_conjunction, + logical_expr::utils::split_conjunction, prelude::{Column, SessionContext}, scalar::ScalarValue, }; -use substrait::proto::expression::{Literal, ScalarFunction}; +use substrait::proto::exchange_rel::ExchangeKind; +use substrait::proto::expression::subquery::SubqueryType; +use substrait::proto::expression::{FieldReference, Literal, ScalarFunction}; use substrait::proto::{ aggregate_function::AggregationInvocation, expression::{ @@ -56,7 +62,7 @@ use substrait::proto::{ use substrait::proto::{FunctionArgument, SortField}; use datafusion::common::plan_err; -use datafusion::logical_expr::expr::{InList, Sort}; +use datafusion::logical_expr::expr::{InList, InSubquery, Sort}; use std::collections::HashMap; use std::str::FromStr; use std::sync::Arc; @@ -71,12 +77,7 @@ use crate::variation_const::{ enum ScalarFunctionType { Builtin(BuiltinScalarFunction), Op(Operator), - /// [Expr::Not] - Not, - /// [Expr::Like] Used for filtering rows based on the given wildcard pattern. Case sensitive - Like, - /// [Expr::Like] Case insensitive operator counterpart of `Like` - ILike, + Expr(BuiltinExprBuilder), } pub fn name_to_op(name: &str) -> Result { @@ -121,17 +122,61 @@ fn scalar_function_type_from_str(name: &str) -> Result { return Ok(ScalarFunctionType::Builtin(fun)); } - match name { - "not" => Ok(ScalarFunctionType::Not), - "like" => Ok(ScalarFunctionType::Like), - "ilike" => Ok(ScalarFunctionType::ILike), - others => not_impl_err!("Unsupported function name: {others:?}"), + if let Some(builder) = BuiltinExprBuilder::try_from_name(name) { + return Ok(ScalarFunctionType::Expr(builder)); } + + not_impl_err!("Unsupported function name: {name:?}") +} + +fn split_eq_and_noneq_join_predicate_with_nulls_equality( + filter: &Expr, +) -> (Vec<(Column, Column)>, bool, Option) { + let exprs = split_conjunction(filter); + + let mut accum_join_keys: Vec<(Column, Column)> = vec![]; + let mut accum_filters: Vec = vec![]; + let mut nulls_equal_nulls = false; + + for expr in exprs { + match expr { + Expr::BinaryExpr(binary_expr) => match binary_expr { + x @ (BinaryExpr { + left, + op: Operator::Eq, + right, + } + | BinaryExpr { + left, + op: Operator::IsNotDistinctFrom, + right, + }) => { + nulls_equal_nulls = match x.op { + Operator::Eq => false, + Operator::IsNotDistinctFrom => true, + _ => unreachable!(), + }; + + match (left.as_ref(), right.as_ref()) { + (Expr::Column(l), Expr::Column(r)) => { + accum_join_keys.push((l.clone(), r.clone())); + } + _ => accum_filters.push(expr.clone()), + } + } + _ => accum_filters.push(expr.clone()), + }, + _ => accum_filters.push(expr.clone()), + } + } + + let join_filter = accum_filters.into_iter().reduce(Expr::and); + (accum_join_keys, nulls_equal_nulls, join_filter) } /// Convert Substrait Plan to DataFusion DataFrame pub async fn from_substrait_plan( - ctx: &mut SessionContext, + ctx: &SessionContext, plan: &Plan, ) -> Result { // Register function extension @@ -173,7 +218,7 @@ pub async fn from_substrait_plan( /// Convert Substrait Rel to DataFusion DataFrame #[async_recursion] pub async fn from_substrait_rel( - ctx: &mut SessionContext, + ctx: &SessionContext, rel: &Rel, extensions: &HashMap, ) -> Result { @@ -186,7 +231,8 @@ pub async fn from_substrait_rel( let mut exprs: Vec = vec![]; for e in &p.expressions { let x = - from_substrait_rex(e, input.clone().schema(), extensions).await?; + from_substrait_rex(ctx, e, input.clone().schema(), extensions) + .await?; // if the expression is WindowFunction, wrap in a Window relation // before returning and do not add to list of this Projection's expression list // otherwise, add expression to the Projection's expression list @@ -212,7 +258,8 @@ pub async fn from_substrait_rel( ); if let Some(condition) = filter.condition.as_ref() { let expr = - from_substrait_rex(condition, input.schema(), extensions).await?; + from_substrait_rex(ctx, condition, input.schema(), extensions) + .await?; input.filter(expr.as_ref().clone())?.build() } else { not_impl_err!("Filter without an condition is not valid") @@ -227,8 +274,13 @@ pub async fn from_substrait_rel( from_substrait_rel(ctx, input, extensions).await?, ); let offset = fetch.offset as usize; - let count = fetch.count as usize; - input.limit(offset, Some(count))?.build() + // Since protobuf can't directly distinguish `None` vs `0` `None` is encoded as `MAX` + let count = if fetch.count as usize == usize::MAX { + None + } else { + Some(fetch.count as usize) + }; + input.limit(offset, count)?.build() } else { not_impl_err!("Fetch without an input is not valid") } @@ -239,7 +291,8 @@ pub async fn from_substrait_rel( from_substrait_rel(ctx, input, extensions).await?, ); let sorts = - from_substrait_sorts(&sort.sorts, input.schema(), extensions).await?; + from_substrait_sorts(ctx, &sort.sorts, input.schema(), extensions) + .await?; input.sort(sorts)?.build() } else { not_impl_err!("Sort without an input is not valid") @@ -257,7 +310,8 @@ pub async fn from_substrait_rel( 1 => { for e in &agg.groupings[0].grouping_expressions { let x = - from_substrait_rex(e, input.schema(), extensions).await?; + from_substrait_rex(ctx, e, input.schema(), extensions) + .await?; group_expr.push(x.as_ref().clone()); } } @@ -266,8 +320,13 @@ pub async fn from_substrait_rel( for grouping in &agg.groupings { let mut grouping_set = vec![]; for e in &grouping.grouping_expressions { - let x = from_substrait_rex(e, input.schema(), extensions) - .await?; + let x = from_substrait_rex( + ctx, + e, + input.schema(), + extensions, + ) + .await?; grouping_set.push(x.as_ref().clone()); } grouping_sets.push(grouping_set); @@ -285,7 +344,7 @@ pub async fn from_substrait_rel( for m in &agg.measures { let filter = match &m.filter { Some(fil) => Some(Box::new( - from_substrait_rex(fil, input.schema(), extensions) + from_substrait_rex(ctx, fil, input.schema(), extensions) .await? .as_ref() .clone(), @@ -308,6 +367,7 @@ pub async fn from_substrait_rel( _ => false, }; from_substrait_agg_func( + ctx, f, input.schema(), extensions, @@ -331,7 +391,13 @@ pub async fn from_substrait_rel( } } Some(RelType::Join(join)) => { - let left = LogicalPlanBuilder::from( + if join.post_join_filter.is_some() { + return not_impl_err!( + "JoinRel with post_join_filter is not yet supported" + ); + } + + let left: LogicalPlanBuilder = LogicalPlanBuilder::from( from_substrait_rel(ctx, join.left.as_ref().unwrap(), extensions).await?, ); let right = LogicalPlanBuilder::from( @@ -341,67 +407,43 @@ pub async fn from_substrait_rel( // The join condition expression needs full input schema and not the output schema from join since we lose columns from // certain join types such as semi and anti joins let in_join_schema = left.schema().join(right.schema())?; - // Parse post join filter if exists - let join_filter = match &join.post_join_filter { - Some(filter) => { - let parsed_filter = - from_substrait_rex(filter, &in_join_schema, extensions).await?; - Some(parsed_filter.as_ref().clone()) - } - None => None, - }; + // If join expression exists, parse the `on` condition expression, build join and return - // Otherwise, build join with koin filter, without join keys + // Otherwise, build join with only the filter, without join keys match &join.expression.as_ref() { Some(expr) => { - let on = - from_substrait_rex(expr, &in_join_schema, extensions).await?; - let predicates = split_conjunction(&on); - // TODO: collect only one null_eq_null - let join_exprs: Vec<(Column, Column, bool)> = predicates - .iter() - .map(|p| match p { - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - match (left.as_ref(), right.as_ref()) { - (Expr::Column(l), Expr::Column(r)) => match op { - Operator::Eq => Ok((l.clone(), r.clone(), false)), - Operator::IsNotDistinctFrom => { - Ok((l.clone(), r.clone(), true)) - } - _ => plan_err!("invalid join condition op"), - }, - _ => plan_err!("invalid join condition expression"), - } - } - _ => plan_err!( - "Non-binary expression is not supported in join condition" - ), - }) - .collect::>>()?; - let (left_cols, right_cols, null_eq_nulls): (Vec<_>, Vec<_>, Vec<_>) = - itertools::multiunzip(join_exprs); + let on = from_substrait_rex(ctx, expr, &in_join_schema, extensions) + .await?; + // The join expression can contain both equal and non-equal ops. + // As of datafusion 31.0.0, the equal and non equal join conditions are in separate fields. + // So we extract each part as follows: + // - If an Eq or IsNotDistinctFrom op is encountered, add the left column, right column and is_null_equal_nulls to `join_ons` vector + // - Otherwise we add the expression to join_filter (use conjunction if filter already exists) + let (join_ons, nulls_equal_nulls, join_filter) = + split_eq_and_noneq_join_predicate_with_nulls_equality(&on); + let (left_cols, right_cols): (Vec<_>, Vec<_>) = + itertools::multiunzip(join_ons); left.join_detailed( right.build()?, join_type, (left_cols, right_cols), join_filter, - null_eq_nulls[0], + nulls_equal_nulls, )? .build() } - None => match &join_filter { - Some(_) => left - .join( - right.build()?, - join_type, - (Vec::::new(), Vec::::new()), - join_filter, - )? - .build(), - None => plan_err!("Join without join keys require a valid filter"), - }, + None => plan_err!("JoinRel without join condition is not allowed"), } } + Some(RelType::Cross(cross)) => { + let left: LogicalPlanBuilder = LogicalPlanBuilder::from( + from_substrait_rel(ctx, cross.left.as_ref().unwrap(), extensions).await?, + ); + let right = + from_substrait_rel(ctx, cross.right.as_ref().unwrap(), extensions) + .await?; + left.cross_join(right)?.build() + } Some(RelType::Read(read)) => match &read.as_ref().read_type { Some(ReadType::NamedTable(nt)) => { let table_reference = match nt.names.len() { @@ -456,8 +498,8 @@ pub async fn from_substrait_rel( } _ => not_impl_err!("Only NamedTable reads are supported"), }, - Some(RelType::Set(set)) => match set_rel::SetOp::from_i32(set.op) { - Some(set_op) => match set_op { + Some(RelType::Set(set)) => match set_rel::SetOp::try_from(set.op) { + Ok(set_op) => match set_op { set_rel::SetOp::UnionAll => { if !set.inputs.is_empty() { let mut union_builder = Ok(LogicalPlanBuilder::from( @@ -474,13 +516,11 @@ pub async fn from_substrait_rel( } _ => not_impl_err!("Unsupported set operator: {set_op:?}"), }, - None => not_impl_err!("Invalid set operation type None"), + Err(e) => not_impl_err!("Invalid set operation type {}: {e}", set.op), }, Some(RelType::ExtensionLeaf(extension)) => { let Some(ext_detail) = &extension.detail else { - return Err(DataFusionError::Substrait( - "Unexpected empty detail in ExtensionLeafRel".to_string(), - )); + return substrait_err!("Unexpected empty detail in ExtensionLeafRel"); }; let plan = ctx .state() @@ -490,18 +530,16 @@ pub async fn from_substrait_rel( } Some(RelType::ExtensionSingle(extension)) => { let Some(ext_detail) = &extension.detail else { - return Err(DataFusionError::Substrait( - "Unexpected empty detail in ExtensionSingleRel".to_string(), - )); + return substrait_err!("Unexpected empty detail in ExtensionSingleRel"); }; let plan = ctx .state() .serializer_registry() .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; let Some(input_rel) = &extension.input else { - return Err(DataFusionError::Substrait( - "ExtensionSingleRel doesn't contains input rel. Try use ExtensionLeafRel instead".to_string() - )); + return substrait_err!( + "ExtensionSingleRel doesn't contains input rel. Try use ExtensionLeafRel instead" + ); }; let input_plan = from_substrait_rel(ctx, input_rel, extensions).await?; let plan = plan.from_template(&plan.expressions(), &[input_plan]); @@ -509,9 +547,7 @@ pub async fn from_substrait_rel( } Some(RelType::ExtensionMulti(extension)) => { let Some(ext_detail) = &extension.detail else { - return Err(DataFusionError::Substrait( - "Unexpected empty detail in ExtensionSingleRel".to_string(), - )); + return substrait_err!("Unexpected empty detail in ExtensionSingleRel"); }; let plan = ctx .state() @@ -525,12 +561,51 @@ pub async fn from_substrait_rel( let plan = plan.from_template(&plan.expressions(), &inputs); Ok(LogicalPlan::Extension(Extension { node: plan })) } + Some(RelType::Exchange(exchange)) => { + let Some(input) = exchange.input.as_ref() else { + return substrait_err!("Unexpected empty input in ExchangeRel"); + }; + let input = Arc::new(from_substrait_rel(ctx, input, extensions).await?); + + let Some(exchange_kind) = &exchange.exchange_kind else { + return substrait_err!("Unexpected empty input in ExchangeRel"); + }; + + // ref: https://substrait.io/relations/physical_relations/#exchange-types + let partitioning_scheme = match exchange_kind { + ExchangeKind::ScatterByFields(scatter_fields) => { + let mut partition_columns = vec![]; + let input_schema = input.schema(); + for field_ref in &scatter_fields.fields { + let column = + from_substrait_field_reference(field_ref, input_schema)?; + partition_columns.push(column); + } + Partitioning::Hash( + partition_columns, + exchange.partition_count as usize, + ) + } + ExchangeKind::RoundRobin(_) => { + Partitioning::RoundRobinBatch(exchange.partition_count as usize) + } + ExchangeKind::SingleTarget(_) + | ExchangeKind::MultiTarget(_) + | ExchangeKind::Broadcast(_) => { + return not_impl_err!("Unsupported exchange kind: {exchange_kind:?}"); + } + }; + Ok(LogicalPlan::Repartition(Repartition { + input, + partitioning_scheme, + })) + } _ => not_impl_err!("Unsupported RelType: {:?}", rel.rel_type), } } fn from_substrait_jointype(join_type: i32) -> Result { - if let Some(substrait_join_type) = join_rel::JoinType::from_i32(join_type) { + if let Ok(substrait_join_type) = join_rel::JoinType::try_from(join_type) { match substrait_join_type { join_rel::JoinType::Inner => Ok(JoinType::Inner), join_rel::JoinType::Left => Ok(JoinType::Left), @@ -547,18 +622,20 @@ fn from_substrait_jointype(join_type: i32) -> Result { /// Convert Substrait Sorts to DataFusion Exprs pub async fn from_substrait_sorts( + ctx: &SessionContext, substrait_sorts: &Vec, input_schema: &DFSchema, extensions: &HashMap, ) -> Result> { let mut sorts: Vec = vec![]; for s in substrait_sorts { - let expr = from_substrait_rex(s.expr.as_ref().unwrap(), input_schema, extensions) - .await?; + let expr = + from_substrait_rex(ctx, s.expr.as_ref().unwrap(), input_schema, extensions) + .await?; let asc_nullfirst = match &s.sort_kind { Some(k) => match k { Direction(d) => { - let Some(direction) = SortDirection::from_i32(*d) else { + let Ok(direction) = SortDirection::try_from(*d) else { return not_impl_err!( "Unsupported Substrait SortDirection value {d}" ); @@ -595,13 +672,14 @@ pub async fn from_substrait_sorts( /// Convert Substrait Expressions to DataFusion Exprs pub async fn from_substrait_rex_vec( + ctx: &SessionContext, exprs: &Vec, input_schema: &DFSchema, extensions: &HashMap, ) -> Result> { let mut expressions: Vec = vec![]; for expr in exprs { - let expression = from_substrait_rex(expr, input_schema, extensions).await?; + let expression = from_substrait_rex(ctx, expr, input_schema, extensions).await?; expressions.push(expression.as_ref().clone()); } Ok(expressions) @@ -609,6 +687,7 @@ pub async fn from_substrait_rex_vec( /// Convert Substrait FunctionArguments to DataFusion Exprs pub async fn from_substriat_func_args( + ctx: &SessionContext, arguments: &Vec, input_schema: &DFSchema, extensions: &HashMap, @@ -617,7 +696,7 @@ pub async fn from_substriat_func_args( for arg in arguments { let arg_expr = match &arg.arg_type { Some(ArgType::Value(e)) => { - from_substrait_rex(e, input_schema, extensions).await + from_substrait_rex(ctx, e, input_schema, extensions).await } _ => { not_impl_err!("Aggregated function argument non-Value type not supported") @@ -630,6 +709,7 @@ pub async fn from_substriat_func_args( /// Convert Substrait AggregateFunction to DataFusion Expr pub async fn from_substrait_agg_func( + ctx: &SessionContext, f: &AggregateFunction, input_schema: &DFSchema, extensions: &HashMap, @@ -641,7 +721,7 @@ pub async fn from_substrait_agg_func( for arg in &f.arguments { let arg_expr = match &arg.arg_type { Some(ArgType::Value(e)) => { - from_substrait_rex(e, input_schema, extensions).await + from_substrait_rex(ctx, e, input_schema, extensions).await } _ => { not_impl_err!("Aggregated function argument non-Value type not supported") @@ -650,28 +730,36 @@ pub async fn from_substrait_agg_func( args.push(arg_expr?.as_ref().clone()); } - let fun = match extensions.get(&f.function_reference) { - Some(function_name) => { - aggregate_function::AggregateFunction::from_str(function_name) - } - None => not_impl_err!( - "Aggregated function not found: function anchor = {:?}", + let Some(function_name) = extensions.get(&f.function_reference) else { + return plan_err!( + "Aggregate function not registered: function anchor = {:?}", f.function_reference - ), + ); }; - Ok(Arc::new(Expr::AggregateFunction(expr::AggregateFunction { - fun: fun.unwrap(), - args, - distinct, - filter, - order_by, - }))) + // try udaf first, then built-in aggr fn. + if let Ok(fun) = ctx.udaf(function_name) { + Ok(Arc::new(Expr::AggregateFunction( + expr::AggregateFunction::new_udf(fun, args, distinct, filter, order_by), + ))) + } else if let Ok(fun) = aggregate_function::AggregateFunction::from_str(function_name) + { + Ok(Arc::new(Expr::AggregateFunction( + expr::AggregateFunction::new(fun, args, distinct, filter, order_by), + ))) + } else { + not_impl_err!( + "Aggregated function {} is not supported: function anchor = {:?}", + function_name, + f.function_reference + ) + } } /// Convert Substrait Rex to DataFusion Expr #[async_recursion] pub async fn from_substrait_rex( + ctx: &SessionContext, e: &Expression, input_schema: &DFSchema, extensions: &HashMap, @@ -682,37 +770,24 @@ pub async fn from_substrait_rex( let substrait_list = s.options.as_ref(); Ok(Arc::new(Expr::InList(InList { expr: Box::new( - from_substrait_rex(substrait_expr, input_schema, extensions) + from_substrait_rex(ctx, substrait_expr, input_schema, extensions) .await? .as_ref() .clone(), ), - list: from_substrait_rex_vec(substrait_list, input_schema, extensions) - .await?, + list: from_substrait_rex_vec( + ctx, + substrait_list, + input_schema, + extensions, + ) + .await?, negated: false, }))) } - Some(RexType::Selection(field_ref)) => match &field_ref.reference_type { - Some(DirectReference(direct)) => match &direct.reference_type.as_ref() { - Some(StructField(x)) => match &x.child.as_ref() { - Some(_) => not_impl_err!( - "Direct reference StructField with child is not supported" - ), - None => { - let column = - input_schema.field(x.field as usize).qualified_column(); - Ok(Arc::new(Expr::Column(Column { - relation: column.relation, - name: column.name, - }))) - } - }, - _ => not_impl_err!( - "Direct reference with types other than StructField is not supported" - ), - }, - _ => not_impl_err!("unsupported field ref type"), - }, + Some(RexType::Selection(field_ref)) => Ok(Arc::new( + from_substrait_field_reference(field_ref, input_schema)?, + )), Some(RexType::IfThen(if_then)) => { // Parse `ifs` // If the first element does not have a `then` part, then we can assume it's a base expression @@ -724,6 +799,7 @@ pub async fn from_substrait_rex( if if_expr.then.is_none() { expr = Some(Box::new( from_substrait_rex( + ctx, if_expr.r#if.as_ref().unwrap(), input_schema, extensions, @@ -738,6 +814,7 @@ pub async fn from_substrait_rex( when_then_expr.push(( Box::new( from_substrait_rex( + ctx, if_expr.r#if.as_ref().unwrap(), input_schema, extensions, @@ -748,6 +825,7 @@ pub async fn from_substrait_rex( ), Box::new( from_substrait_rex( + ctx, if_expr.then.as_ref().unwrap(), input_schema, extensions, @@ -761,7 +839,7 @@ pub async fn from_substrait_rex( // Parse `else` let else_expr = match &if_then.r#else { Some(e) => Some(Box::new( - from_substrait_rex(e, input_schema, extensions) + from_substrait_rex(ctx, e, input_schema, extensions) .await? .as_ref() .clone(), @@ -788,7 +866,7 @@ pub async fn from_substrait_rex( for arg in &f.arguments { let arg_expr = match &arg.arg_type { Some(ArgType::Value(e)) => { - from_substrait_rex(e, input_schema, extensions).await + from_substrait_rex(ctx, e, input_schema, extensions).await } _ => not_impl_err!( "Aggregated function argument non-Value type not supported" @@ -796,10 +874,9 @@ pub async fn from_substrait_rex( }; args.push(arg_expr?.as_ref().clone()); } - Ok(Arc::new(Expr::ScalarFunction(expr::ScalarFunction { - fun, - args, - }))) + Ok(Arc::new(Expr::ScalarFunction(expr::ScalarFunction::new( + fun, args, + )))) } ScalarFunctionType::Op(op) => { if f.arguments.len() != 2 { @@ -814,14 +891,14 @@ pub async fn from_substrait_rex( (Some(ArgType::Value(l)), Some(ArgType::Value(r))) => { Ok(Arc::new(Expr::BinaryExpr(BinaryExpr { left: Box::new( - from_substrait_rex(l, input_schema, extensions) + from_substrait_rex(ctx, l, input_schema, extensions) .await? .as_ref() .clone(), ), op, right: Box::new( - from_substrait_rex(r, input_schema, extensions) + from_substrait_rex(ctx, r, input_schema, extensions) .await? .as_ref() .clone(), @@ -833,28 +910,8 @@ pub async fn from_substrait_rex( ), } } - ScalarFunctionType::Not => { - let arg = f.arguments.first().ok_or_else(|| { - DataFusionError::Substrait( - "expect one argument for `NOT` expr".to_string(), - ) - })?; - match &arg.arg_type { - Some(ArgType::Value(e)) => { - let expr = from_substrait_rex(e, input_schema, extensions) - .await? - .as_ref() - .clone(); - Ok(Arc::new(Expr::Not(Box::new(expr)))) - } - _ => not_impl_err!("Invalid arguments for Not expression"), - } - } - ScalarFunctionType::Like => { - make_datafusion_like(false, f, input_schema, extensions).await - } - ScalarFunctionType::ILike => { - make_datafusion_like(true, f, input_schema, extensions).await + ScalarFunctionType::Expr(builder) => { + builder.build(ctx, f, input_schema, extensions).await } } } @@ -866,6 +923,7 @@ pub async fn from_substrait_rex( Some(output_type) => Ok(Arc::new(Expr::Cast(Cast::new( Box::new( from_substrait_rex( + ctx, cast.as_ref().input.as_ref().unwrap().as_ref(), input_schema, extensions, @@ -876,9 +934,7 @@ pub async fn from_substrait_rex( ), from_substrait_type(output_type)?, )))), - None => Err(DataFusionError::Substrait( - "Cast experssion without output type is not allowed".to_string(), - )), + None => substrait_err!("Cast experssion without output type is not allowed"), }, Some(RexType::WindowFunction(window)) => { let fun = match extensions.get(&window.function_reference) { @@ -889,7 +945,8 @@ pub async fn from_substrait_rex( ), }; let order_by = - from_substrait_sorts(&window.sorts, input_schema, extensions).await?; + from_substrait_sorts(ctx, &window.sorts, input_schema, extensions) + .await?; // Substrait does not encode WindowFrameUnits so we're using a simple logic to determine the units // If there is no `ORDER BY`, then by default, the frame counts each row from the lower up to upper boundary // If there is `ORDER BY`, then by default, each frame is a range starting from unbounded preceding to current row @@ -902,12 +959,14 @@ pub async fn from_substrait_rex( Ok(Arc::new(Expr::WindowFunction(expr::WindowFunction { fun: fun?.unwrap(), args: from_substriat_func_args( + ctx, &window.arguments, input_schema, extensions, ) .await?, partition_by: from_substrait_rex_vec( + ctx, &window.partitions, input_schema, extensions, @@ -921,6 +980,51 @@ pub async fn from_substrait_rex( }, }))) } + Some(RexType::Subquery(subquery)) => match &subquery.as_ref().subquery_type { + Some(subquery_type) => match subquery_type { + SubqueryType::InPredicate(in_predicate) => { + if in_predicate.needles.len() != 1 { + Err(DataFusionError::Substrait( + "InPredicate Subquery type must have exactly one Needle expression" + .to_string(), + )) + } else { + let needle_expr = &in_predicate.needles[0]; + let haystack_expr = &in_predicate.haystack; + if let Some(haystack_expr) = haystack_expr { + let haystack_expr = + from_substrait_rel(ctx, haystack_expr, extensions) + .await?; + let outer_refs = haystack_expr.all_out_ref_exprs(); + Ok(Arc::new(Expr::InSubquery(InSubquery { + expr: Box::new( + from_substrait_rex( + ctx, + needle_expr, + input_schema, + extensions, + ) + .await? + .as_ref() + .clone(), + ), + subquery: Subquery { + subquery: Arc::new(haystack_expr), + outer_ref_columns: outer_refs, + }, + negated: false, + }))) + } else { + substrait_err!("InPredicate Subquery type must have a Haystack expression") + } + } + } + _ => substrait_err!("Subquery type not implemented"), + }, + None => { + substrait_err!("Subquery experssion without SubqueryType is not allowed") + } + }, _ => not_impl_err!("unsupported rex_type"), } } @@ -1003,9 +1107,7 @@ fn from_substrait_type(dt: &substrait::proto::Type) -> Result { r#type::Kind::List(list) => { let inner_type = from_substrait_type(list.r#type.as_ref().ok_or_else(|| { - DataFusionError::Substrait( - "List type must have inner type".to_string(), - ) + substrait_datafusion_err!("List type must have inner type") })?)?; let field = Arc::new(Field::new("list_item", inner_type, true)); match list.type_variation_reference { @@ -1057,9 +1159,7 @@ fn from_substrait_bound( } } }, - None => Err(DataFusionError::Substrait( - "WindowFunction missing Substrait Bound kind".to_string(), - )), + None => substrait_err!("WindowFunction missing Substrait Bound kind"), }, None => { if is_lower { @@ -1078,36 +1178,28 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result { DEFAULT_TYPE_REF => ScalarValue::Int8(Some(*n as i8)), UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt8(Some(*n as u8)), others => { - return Err(DataFusionError::Substrait(format!( - "Unknown type variation reference {others}", - ))); + return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::I16(n)) => match lit.type_variation_reference { DEFAULT_TYPE_REF => ScalarValue::Int16(Some(*n as i16)), UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt16(Some(*n as u16)), others => { - return Err(DataFusionError::Substrait(format!( - "Unknown type variation reference {others}", - ))); + return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::I32(n)) => match lit.type_variation_reference { DEFAULT_TYPE_REF => ScalarValue::Int32(Some(*n)), UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt32(Some(*n as u32)), others => { - return Err(DataFusionError::Substrait(format!( - "Unknown type variation reference {others}", - ))); + return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::I64(n)) => match lit.type_variation_reference { DEFAULT_TYPE_REF => ScalarValue::Int64(Some(*n)), UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt64(Some(*n as u64)), others => { - return Err(DataFusionError::Substrait(format!( - "Unknown type variation reference {others}", - ))); + return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::Fp32(f)) => ScalarValue::Float32(Some(*f)), @@ -1118,9 +1210,7 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result { TIMESTAMP_MICRO_TYPE_REF => ScalarValue::TimestampMicrosecond(Some(*t), None), TIMESTAMP_NANO_TYPE_REF => ScalarValue::TimestampNanosecond(Some(*t), None), others => { - return Err(DataFusionError::Substrait(format!( - "Unknown type variation reference {others}", - ))); + return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::Date(d)) => ScalarValue::Date32(Some(*d)), @@ -1128,38 +1218,30 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result { DEFAULT_CONTAINER_TYPE_REF => ScalarValue::Utf8(Some(s.clone())), LARGE_CONTAINER_TYPE_REF => ScalarValue::LargeUtf8(Some(s.clone())), others => { - return Err(DataFusionError::Substrait(format!( - "Unknown type variation reference {others}", - ))); + return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::Binary(b)) => match lit.type_variation_reference { DEFAULT_CONTAINER_TYPE_REF => ScalarValue::Binary(Some(b.clone())), LARGE_CONTAINER_TYPE_REF => ScalarValue::LargeBinary(Some(b.clone())), others => { - return Err(DataFusionError::Substrait(format!( - "Unknown type variation reference {others}", - ))); + return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::FixedBinary(b)) => { ScalarValue::FixedSizeBinary(b.len() as _, Some(b.clone())) } Some(LiteralType::Decimal(d)) => { - let value: [u8; 16] = - d.value - .clone() - .try_into() - .or(Err(DataFusionError::Substrait( - "Failed to parse decimal value".to_string(), - )))?; + let value: [u8; 16] = d + .value + .clone() + .try_into() + .or(substrait_err!("Failed to parse decimal value"))?; let p = d.precision.try_into().map_err(|e| { - DataFusionError::Substrait(format!( - "Failed to parse decimal precision: {e}" - )) + substrait_datafusion_err!("Failed to parse decimal precision: {e}") })?; let s = d.scale.try_into().map_err(|e| { - DataFusionError::Substrait(format!("Failed to parse decimal scale: {e}")) + substrait_datafusion_err!("Failed to parse decimal scale: {e}") })?; ScalarValue::Decimal128( Some(std::primitive::i128::from_le_bytes(value)), @@ -1257,50 +1339,157 @@ fn from_substrait_null(null_type: &Type) -> Result { } } -async fn make_datafusion_like( - case_insensitive: bool, - f: &ScalarFunction, +fn from_substrait_field_reference( + field_ref: &FieldReference, input_schema: &DFSchema, - extensions: &HashMap, -) -> Result> { - let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" }; - if f.arguments.len() != 3 { - return not_impl_err!("Expect three arguments for `{fn_name}` expr"); +) -> Result { + match &field_ref.reference_type { + Some(DirectReference(direct)) => match &direct.reference_type.as_ref() { + Some(StructField(x)) => match &x.child.as_ref() { + Some(_) => not_impl_err!( + "Direct reference StructField with child is not supported" + ), + None => { + let column = input_schema.field(x.field as usize).qualified_column(); + Ok(Expr::Column(Column { + relation: column.relation, + name: column.name, + })) + } + }, + _ => not_impl_err!( + "Direct reference with types other than StructField is not supported" + ), + }, + _ => not_impl_err!("unsupported field ref type"), } +} - let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { - return not_impl_err!("Invalid arguments type for `{fn_name}` expr"); - }; - let expr = from_substrait_rex(expr_substrait, input_schema, extensions) - .await? - .as_ref() - .clone(); - let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type else { - return not_impl_err!("Invalid arguments type for `{fn_name}` expr"); - }; - let pattern = from_substrait_rex(pattern_substrait, input_schema, extensions) - .await? - .as_ref() - .clone(); - let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type else { - return not_impl_err!("Invalid arguments type for `{fn_name}` expr"); - }; - let escape_char_expr = - from_substrait_rex(escape_char_substrait, input_schema, extensions) +/// Build [`Expr`] from its name and required inputs. +struct BuiltinExprBuilder { + expr_name: String, +} + +impl BuiltinExprBuilder { + pub fn try_from_name(name: &str) -> Option { + match name { + "not" | "like" | "ilike" | "is_null" | "is_not_null" | "is_true" + | "is_false" | "is_not_true" | "is_not_false" | "is_unknown" + | "is_not_unknown" | "negative" => Some(Self { + expr_name: name.to_string(), + }), + _ => None, + } + } + + pub async fn build( + self, + ctx: &SessionContext, + f: &ScalarFunction, + input_schema: &DFSchema, + extensions: &HashMap, + ) -> Result> { + match self.expr_name.as_str() { + "like" => { + Self::build_like_expr(ctx, false, f, input_schema, extensions).await + } + "ilike" => { + Self::build_like_expr(ctx, true, f, input_schema, extensions).await + } + "not" | "negative" | "is_null" | "is_not_null" | "is_true" | "is_false" + | "is_not_true" | "is_not_false" | "is_unknown" | "is_not_unknown" => { + Self::build_unary_expr(ctx, &self.expr_name, f, input_schema, extensions) + .await + } + _ => { + not_impl_err!("Unsupported builtin expression: {}", self.expr_name) + } + } + } + + async fn build_unary_expr( + ctx: &SessionContext, + fn_name: &str, + f: &ScalarFunction, + input_schema: &DFSchema, + extensions: &HashMap, + ) -> Result> { + if f.arguments.len() != 1 { + return substrait_err!("Expect one argument for {fn_name} expr"); + } + let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { + return substrait_err!("Invalid arguments type for {fn_name} expr"); + }; + let arg = from_substrait_rex(ctx, expr_substrait, input_schema, extensions) .await? .as_ref() .clone(); - let Expr::Literal(ScalarValue::Utf8(escape_char)) = escape_char_expr else { - return Err(DataFusionError::Substrait(format!( - "Expect Utf8 literal for escape char, but found {escape_char_expr:?}", - ))); - }; + let arg = Box::new(arg); - Ok(Arc::new(Expr::Like(Like { - negated: false, - expr: Box::new(expr), - pattern: Box::new(pattern), - escape_char: escape_char.map(|c| c.chars().next().unwrap()), - case_insensitive, - }))) + let expr = match fn_name { + "not" => Expr::Not(arg), + "negative" => Expr::Negative(arg), + "is_null" => Expr::IsNull(arg), + "is_not_null" => Expr::IsNotNull(arg), + "is_true" => Expr::IsTrue(arg), + "is_false" => Expr::IsFalse(arg), + "is_not_true" => Expr::IsNotTrue(arg), + "is_not_false" => Expr::IsNotFalse(arg), + "is_unknown" => Expr::IsUnknown(arg), + "is_not_unknown" => Expr::IsNotUnknown(arg), + _ => return not_impl_err!("Unsupported builtin expression: {}", fn_name), + }; + + Ok(Arc::new(expr)) + } + + async fn build_like_expr( + ctx: &SessionContext, + case_insensitive: bool, + f: &ScalarFunction, + input_schema: &DFSchema, + extensions: &HashMap, + ) -> Result> { + let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" }; + if f.arguments.len() != 3 { + return substrait_err!("Expect three arguments for `{fn_name}` expr"); + } + + let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { + return substrait_err!("Invalid arguments type for `{fn_name}` expr"); + }; + let expr = from_substrait_rex(ctx, expr_substrait, input_schema, extensions) + .await? + .as_ref() + .clone(); + let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type else { + return substrait_err!("Invalid arguments type for `{fn_name}` expr"); + }; + let pattern = + from_substrait_rex(ctx, pattern_substrait, input_schema, extensions) + .await? + .as_ref() + .clone(); + let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type else { + return substrait_err!("Invalid arguments type for `{fn_name}` expr"); + }; + let escape_char_expr = + from_substrait_rex(ctx, escape_char_substrait, input_schema, extensions) + .await? + .as_ref() + .clone(); + let Expr::Literal(ScalarValue::Utf8(escape_char)) = escape_char_expr else { + return substrait_err!( + "Expect Utf8 literal for escape char, but found {escape_char_expr:?}" + ); + }; + + Ok(Arc::new(Expr::Like(Like { + negated: false, + expr: Box::new(expr), + pattern: Box::new(pattern), + escape_char: escape_char.map(|c| c.chars().next().unwrap()), + case_insensitive, + }))) + } } diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index e17b022f3b53..ab0e8c860858 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -19,7 +19,9 @@ use std::collections::HashMap; use std::ops::Deref; use std::sync::Arc; -use datafusion::logical_expr::{Like, WindowFrameUnits}; +use datafusion::logical_expr::{ + CrossJoin, Distinct, Like, Partitioning, WindowFrameUnits, +}; use datafusion::{ arrow::datatypes::{DataType, TimeUnit}, error::{DataFusionError, Result}, @@ -28,18 +30,21 @@ use datafusion::{ scalar::ScalarValue, }; -use datafusion::common::DFSchemaRef; use datafusion::common::{exec_err, internal_err, not_impl_err}; +use datafusion::common::{substrait_err, DFSchemaRef}; #[allow(unused_imports)] use datafusion::logical_expr::aggregate_function; use datafusion::logical_expr::expr::{ - Alias, BinaryExpr, Case, Cast, GroupingSet, InList, - ScalarFunction as DFScalarFunction, Sort, WindowFunction, + AggregateFunctionDefinition, Alias, BinaryExpr, Case, Cast, GroupingSet, InList, + InSubquery, ScalarFunctionDefinition, Sort, WindowFunction, }; use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; use datafusion::prelude::Expr; use prost_types::Any as ProtoAny; +use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields}; +use substrait::proto::expression::subquery::InPredicate; use substrait::proto::expression::window_function::BoundsType; +use substrait::proto::{CrossRel, ExchangeRel}; use substrait::{ proto::{ aggregate_function::AggregationInvocation, @@ -54,7 +59,8 @@ use substrait::{ window_function::bound::Kind as BoundKind, window_function::Bound, FieldReference, IfThen, Literal, MaskExpression, ReferenceSegment, RexType, - ScalarFunction, SingularOrList, WindowFunction as SubstraitWindowFunction, + ScalarFunction, SingularOrList, Subquery, + WindowFunction as SubstraitWindowFunction, }, extensions::{ self, @@ -163,7 +169,7 @@ pub fn to_substrait_rel( let expressions = p .expr .iter() - .map(|e| to_substrait_rex(e, p.input.schema(), 0, extension_info)) + .map(|e| to_substrait_rex(ctx, e, p.input.schema(), 0, extension_info)) .collect::>>()?; Ok(Box::new(Rel { rel_type: Some(RelType::Project(Box::new(ProjectRel { @@ -177,6 +183,7 @@ pub fn to_substrait_rel( LogicalPlan::Filter(filter) => { let input = to_substrait_rel(filter.input.as_ref(), ctx, extension_info)?; let filter_expr = to_substrait_rex( + ctx, &filter.predicate, filter.input.schema(), 0, @@ -193,7 +200,8 @@ pub fn to_substrait_rel( } LogicalPlan::Limit(limit) => { let input = to_substrait_rel(limit.input.as_ref(), ctx, extension_info)?; - let limit_fetch = limit.fetch.unwrap_or(0); + // Since protobuf can't directly distinguish `None` vs `0` encode `None` as `MAX` + let limit_fetch = limit.fetch.unwrap_or(usize::MAX); Ok(Box::new(Rel { rel_type: Some(RelType::Fetch(Box::new(FetchRel { common: None, @@ -209,7 +217,9 @@ pub fn to_substrait_rel( let sort_fields = sort .expr .iter() - .map(|e| substrait_sort_field(e, sort.input.schema(), extension_info)) + .map(|e| { + substrait_sort_field(ctx, e, sort.input.schema(), extension_info) + }) .collect::>>()?; Ok(Box::new(Rel { rel_type: Some(RelType::Sort(Box::new(SortRel { @@ -223,6 +233,7 @@ pub fn to_substrait_rel( LogicalPlan::Aggregate(agg) => { let input = to_substrait_rel(agg.input.as_ref(), ctx, extension_info)?; let groupings = to_substrait_groupings( + ctx, &agg.group_expr, agg.input.schema(), extension_info, @@ -230,7 +241,9 @@ pub fn to_substrait_rel( let measures = agg .aggr_expr .iter() - .map(|e| to_substrait_agg_measure(e, agg.input.schema(), extension_info)) + .map(|e| { + to_substrait_agg_measure(ctx, e, agg.input.schema(), extension_info) + }) .collect::>>()?; Ok(Box::new(Rel { @@ -243,11 +256,11 @@ pub fn to_substrait_rel( }))), })) } - LogicalPlan::Distinct(distinct) => { + LogicalPlan::Distinct(Distinct::All(plan)) => { // Use Substrait's AggregateRel with empty measures to represent `select distinct` - let input = to_substrait_rel(distinct.input.as_ref(), ctx, extension_info)?; + let input = to_substrait_rel(plan.as_ref(), ctx, extension_info)?; // Get grouping keys from the input relation's number of output fields - let grouping = (0..distinct.input.schema().fields().len()) + let grouping = (0..plan.schema().fields().len()) .map(substrait_field_ref) .collect::>>()?; @@ -277,14 +290,16 @@ pub fn to_substrait_rel( // parse filter if exists let in_join_schema = join.left.schema().join(join.right.schema())?; let join_filter = match &join.filter { - Some(filter) => Some(Box::new(to_substrait_rex( + Some(filter) => Some(to_substrait_rex( + ctx, filter, &Arc::new(in_join_schema), 0, extension_info, - )?)), + )?), None => None, }; + // map the left and right columns to binary expressions in the form `l = r` // build a single expression for the ON condition, such as `l.a = r.a AND l.b = r.b` let eq_op = if join.null_equals_null { @@ -292,15 +307,32 @@ pub fn to_substrait_rel( } else { Operator::Eq }; - - let join_expr = to_substrait_join_expr( + let join_on = to_substrait_join_expr( + ctx, &join.on, eq_op, join.left.schema(), join.right.schema(), extension_info, - )? - .map(Box::new); + )?; + + // create conjunction between `join_on` and `join_filter` to embed all join conditions, + // whether equal or non-equal in a single expression + let join_expr = match &join_on { + Some(on_expr) => match &join_filter { + Some(filter) => Some(Box::new(make_binary_op_scalar_func( + on_expr, + filter, + Operator::And, + extension_info, + ))), + None => join_on.map(Box::new), // the join expression will only contain `join_on` if filter doesn't exist + }, + None => match &join_filter { + Some(_) => join_filter.map(Box::new), // the join expression will only contain `join_filter` if the `on` condition doesn't exist + None => None, + }, + }; Ok(Box::new(Rel { rel_type: Some(RelType::Join(Box::new(JoinRel { @@ -309,7 +341,24 @@ pub fn to_substrait_rel( right: Some(right), r#type: join_type as i32, expression: join_expr, - post_join_filter: join_filter, + post_join_filter: None, + advanced_extension: None, + }))), + })) + } + LogicalPlan::CrossJoin(cross_join) => { + let CrossJoin { + left, + right, + schema: _, + } = cross_join; + let left = to_substrait_rel(left.as_ref(), ctx, extension_info)?; + let right = to_substrait_rel(right.as_ref(), ctx, extension_info)?; + Ok(Box::new(Rel { + rel_type: Some(RelType::Cross(Box::new(CrossRel { + common: None, + left: Some(left), + right: Some(right), advanced_extension: None, }))), })) @@ -362,6 +411,7 @@ pub fn to_substrait_rel( let mut window_exprs = vec![]; for expr in &window.window_expr { window_exprs.push(to_substrait_rex( + ctx, expr, window.input.schema(), 0, @@ -374,6 +424,53 @@ pub fn to_substrait_rel( rel_type: Some(RelType::Project(project_rel)), })) } + LogicalPlan::Repartition(repartition) => { + let input = + to_substrait_rel(repartition.input.as_ref(), ctx, extension_info)?; + let partition_count = match repartition.partitioning_scheme { + Partitioning::RoundRobinBatch(num) => num, + Partitioning::Hash(_, num) => num, + Partitioning::DistributeBy(_) => { + return not_impl_err!( + "Physical plan does not support DistributeBy partitioning" + ) + } + }; + // ref: https://substrait.io/relations/physical_relations/#exchange-types + let exchange_kind = match &repartition.partitioning_scheme { + Partitioning::RoundRobinBatch(_) => { + ExchangeKind::RoundRobin(RoundRobin::default()) + } + Partitioning::Hash(exprs, _) => { + let fields = exprs + .iter() + .map(|e| { + try_to_substrait_field_reference( + e, + repartition.input.schema(), + ) + }) + .collect::>>()?; + ExchangeKind::ScatterByFields(ScatterFields { fields }) + } + Partitioning::DistributeBy(_) => { + return not_impl_err!( + "Physical plan does not support DistributeBy partitioning" + ) + } + }; + let exchange_rel = ExchangeRel { + common: None, + input: Some(input), + exchange_kind: Some(exchange_kind), + advanced_extension: None, + partition_count: partition_count as i32, + targets: vec![], + }; + Ok(Box::new(Rel { + rel_type: Some(RelType::Exchange(Box::new(exchange_rel))), + })) + } LogicalPlan::Extension(extension_plan) => { let extension_bytes = ctx .state() @@ -414,6 +511,7 @@ pub fn to_substrait_rel( } fn to_substrait_join_expr( + ctx: &SessionContext, join_conditions: &Vec<(Expr, Expr)>, eq_op: Operator, left_schema: &DFSchemaRef, @@ -427,9 +525,10 @@ fn to_substrait_join_expr( let mut exprs: Vec = vec![]; for (left, right) in join_conditions { // Parse left - let l = to_substrait_rex(left, left_schema, 0, extension_info)?; + let l = to_substrait_rex(ctx, left, left_schema, 0, extension_info)?; // Parse right let r = to_substrait_rex( + ctx, right, right_schema, left_schema.fields().len(), // offset to return the correct index @@ -490,6 +589,7 @@ pub fn operator_to_name(op: Operator) -> &'static str { } pub fn parse_flat_grouping_exprs( + ctx: &SessionContext, exprs: &[Expr], schema: &DFSchemaRef, extension_info: &mut ( @@ -499,7 +599,7 @@ pub fn parse_flat_grouping_exprs( ) -> Result { let grouping_expressions = exprs .iter() - .map(|e| to_substrait_rex(e, schema, 0, extension_info)) + .map(|e| to_substrait_rex(ctx, e, schema, 0, extension_info)) .collect::>>()?; Ok(Grouping { grouping_expressions, @@ -507,7 +607,8 @@ pub fn parse_flat_grouping_exprs( } pub fn to_substrait_groupings( - exprs: &Vec, + ctx: &SessionContext, + exprs: &[Expr], schema: &DFSchemaRef, extension_info: &mut ( Vec, @@ -522,7 +623,9 @@ pub fn to_substrait_groupings( )), GroupingSet::GroupingSets(sets) => Ok(sets .iter() - .map(|set| parse_flat_grouping_exprs(set, schema, extension_info)) + .map(|set| { + parse_flat_grouping_exprs(ctx, set, schema, extension_info) + }) .collect::>>()?), GroupingSet::Rollup(set) => { let mut sets: Vec> = vec![vec![]]; @@ -532,17 +635,21 @@ pub fn to_substrait_groupings( Ok(sets .iter() .rev() - .map(|set| parse_flat_grouping_exprs(set, schema, extension_info)) + .map(|set| { + parse_flat_grouping_exprs(ctx, set, schema, extension_info) + }) .collect::>>()?) } }, _ => Ok(vec![parse_flat_grouping_exprs( + ctx, exprs, schema, extension_info, )?]), }, _ => Ok(vec![parse_flat_grouping_exprs( + ctx, exprs, schema, extension_info, @@ -552,6 +659,7 @@ pub fn to_substrait_groupings( #[allow(deprecated)] pub fn to_substrait_agg_measure( + ctx: &SessionContext, expr: &Expr, schema: &DFSchemaRef, extension_info: &mut ( @@ -560,40 +668,75 @@ pub fn to_substrait_agg_measure( ), ) -> Result { match expr { - Expr::AggregateFunction(expr::AggregateFunction { fun, args, distinct, filter, order_by }) => { - let sorts = if let Some(order_by) = order_by { - order_by.iter().map(|expr| to_substrait_sort_field(expr, schema, extension_info)).collect::>>()? - } else { - vec![] - }; - let mut arguments: Vec = vec![]; - for arg in args { - arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) }); - } - let function_name = fun.to_string().to_lowercase(); - let function_anchor = _register_function(function_name, extension_info); - Ok(Measure { - measure: Some(AggregateFunction { - function_reference: function_anchor, - arguments, - sorts, - output_type: None, - invocation: match distinct { - true => AggregationInvocation::Distinct as i32, - false => AggregationInvocation::All as i32, - }, - phase: AggregationPhase::Unspecified as i32, - args: vec![], - options: vec![], - }), - filter: match filter { - Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?), - None => None + Expr::AggregateFunction(expr::AggregateFunction { func_def, args, distinct, filter, order_by }) => { + match func_def { + AggregateFunctionDefinition::BuiltIn (fun) => { + let sorts = if let Some(order_by) = order_by { + order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr, schema, extension_info)).collect::>>()? + } else { + vec![] + }; + let mut arguments: Vec = vec![]; + for arg in args { + arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extension_info)?)) }); + } + let function_anchor = _register_function(fun.to_string(), extension_info); + Ok(Measure { + measure: Some(AggregateFunction { + function_reference: function_anchor, + arguments, + sorts, + output_type: None, + invocation: match distinct { + true => AggregationInvocation::Distinct as i32, + false => AggregationInvocation::All as i32, + }, + phase: AggregationPhase::Unspecified as i32, + args: vec![], + options: vec![], + }), + filter: match filter { + Some(f) => Some(to_substrait_rex(ctx, f, schema, 0, extension_info)?), + None => None + } + }) } - }) + AggregateFunctionDefinition::UDF(fun) => { + let sorts = if let Some(order_by) = order_by { + order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr, schema, extension_info)).collect::>>()? + } else { + vec![] + }; + let mut arguments: Vec = vec![]; + for arg in args { + arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extension_info)?)) }); + } + let function_anchor = _register_function(fun.name().to_string(), extension_info); + Ok(Measure { + measure: Some(AggregateFunction { + function_reference: function_anchor, + arguments, + sorts, + output_type: None, + invocation: AggregationInvocation::All as i32, + phase: AggregationPhase::Unspecified as i32, + args: vec![], + options: vec![], + }), + filter: match filter { + Some(f) => Some(to_substrait_rex(ctx, f, schema, 0, extension_info)?), + None => None + } + }) + } + AggregateFunctionDefinition::Name(name) => { + internal_err!("AggregateFunctionDefinition::Name({:?}) should be resolved during `AnalyzerRule`", name) + } + } + } Expr::Alias(Alias{expr,..})=> { - to_substrait_agg_measure(expr, schema, extension_info) + to_substrait_agg_measure(ctx, expr, schema, extension_info) } _ => internal_err!( "Expression must be compatible with aggregation. Unsupported expression: {:?}. ExpressionType: {:?}", @@ -605,6 +748,7 @@ pub fn to_substrait_agg_measure( /// Converts sort expression to corresponding substrait `SortField` fn to_substrait_sort_field( + ctx: &SessionContext, expr: &Expr, schema: &DFSchemaRef, extension_info: &mut ( @@ -622,6 +766,7 @@ fn to_substrait_sort_field( }; Ok(SortField { expr: Some(to_substrait_rex( + ctx, sort.expr.deref(), schema, 0, @@ -685,8 +830,8 @@ pub fn make_binary_op_scalar_func( HashMap, ), ) -> Expression { - let function_name = operator_to_name(op).to_string().to_lowercase(); - let function_anchor = _register_function(function_name, extension_info); + let function_anchor = + _register_function(operator_to_name(op).to_string(), extension_info); Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, @@ -730,6 +875,7 @@ pub fn make_binary_op_scalar_func( /// * `extension_info` - Substrait extension info. Contains registered function information #[allow(deprecated)] pub fn to_substrait_rex( + ctx: &SessionContext, expr: &Expr, schema: &DFSchemaRef, col_ref_offset: usize, @@ -746,10 +892,10 @@ pub fn to_substrait_rex( }) => { let substrait_list = list .iter() - .map(|x| to_substrait_rex(x, schema, col_ref_offset, extension_info)) + .map(|x| to_substrait_rex(ctx, x, schema, col_ref_offset, extension_info)) .collect::>>()?; let substrait_expr = - to_substrait_rex(expr, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; let substrait_or_list = Expression { rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList { @@ -777,11 +923,12 @@ pub fn to_substrait_rex( Ok(substrait_or_list) } } - Expr::ScalarFunction(DFScalarFunction { fun, args }) => { + Expr::ScalarFunction(fun) => { let mut arguments: Vec = vec![]; - for arg in args { + for arg in &fun.args { arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex( + ctx, arg, schema, col_ref_offset, @@ -789,8 +936,14 @@ pub fn to_substrait_rex( )?)), }); } - let function_name = fun.to_string().to_lowercase(); - let function_anchor = _register_function(function_name, extension_info); + + // function should be resolved during `AnalyzerRule` + if let ScalarFunctionDefinition::Name(_) = fun.func_def { + return internal_err!("Function `Expr` with name should be resolved."); + } + + let function_anchor = + _register_function(fun.name().to_string(), extension_info); Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, @@ -810,11 +963,11 @@ pub fn to_substrait_rex( if *negated { // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) let substrait_expr = - to_substrait_rex(expr, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; let substrait_low = - to_substrait_rex(low, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, low, schema, col_ref_offset, extension_info)?; let substrait_high = - to_substrait_rex(high, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, high, schema, col_ref_offset, extension_info)?; let l_expr = make_binary_op_scalar_func( &substrait_expr, @@ -838,11 +991,11 @@ pub fn to_substrait_rex( } else { // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high) let substrait_expr = - to_substrait_rex(expr, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; let substrait_low = - to_substrait_rex(low, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, low, schema, col_ref_offset, extension_info)?; let substrait_high = - to_substrait_rex(high, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, high, schema, col_ref_offset, extension_info)?; let l_expr = make_binary_op_scalar_func( &substrait_low, @@ -870,8 +1023,8 @@ pub fn to_substrait_rex( substrait_field_ref(index + col_ref_offset) } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let l = to_substrait_rex(left, schema, col_ref_offset, extension_info)?; - let r = to_substrait_rex(right, schema, col_ref_offset, extension_info)?; + let l = to_substrait_rex(ctx, left, schema, col_ref_offset, extension_info)?; + let r = to_substrait_rex(ctx, right, schema, col_ref_offset, extension_info)?; Ok(make_binary_op_scalar_func(&l, &r, *op, extension_info)) } @@ -886,6 +1039,7 @@ pub fn to_substrait_rex( // Base expression exists ifs.push(IfClause { r#if: Some(to_substrait_rex( + ctx, e, schema, col_ref_offset, @@ -898,12 +1052,14 @@ pub fn to_substrait_rex( for (r#if, then) in when_then_expr { ifs.push(IfClause { r#if: Some(to_substrait_rex( + ctx, r#if, schema, col_ref_offset, extension_info, )?), then: Some(to_substrait_rex( + ctx, then, schema, col_ref_offset, @@ -915,6 +1071,7 @@ pub fn to_substrait_rex( // Parse outer `else` let r#else: Option> = match else_expr { Some(e) => Some(Box::new(to_substrait_rex( + ctx, e, schema, col_ref_offset, @@ -933,6 +1090,7 @@ pub fn to_substrait_rex( substrait::proto::expression::Cast { r#type: Some(to_substrait_type(data_type)?), input: Some(Box::new(to_substrait_rex( + ctx, expr, schema, col_ref_offset, @@ -945,7 +1103,7 @@ pub fn to_substrait_rex( } Expr::Literal(value) => to_substrait_literal(value), Expr::Alias(Alias { expr, .. }) => { - to_substrait_rex(expr, schema, col_ref_offset, extension_info) + to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info) } Expr::WindowFunction(WindowFunction { fun, @@ -955,13 +1113,13 @@ pub fn to_substrait_rex( window_frame, }) => { // function reference - let function_name = fun.to_string().to_lowercase(); - let function_anchor = _register_function(function_name, extension_info); + let function_anchor = _register_function(fun.to_string(), extension_info); // arguments let mut arguments: Vec = vec![]; for arg in args { arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex( + ctx, arg, schema, col_ref_offset, @@ -972,12 +1130,12 @@ pub fn to_substrait_rex( // partition by expressions let partition_by = partition_by .iter() - .map(|e| to_substrait_rex(e, schema, col_ref_offset, extension_info)) + .map(|e| to_substrait_rex(ctx, e, schema, col_ref_offset, extension_info)) .collect::>>()?; // order by expressions let order_by = order_by .iter() - .map(|e| substrait_sort_field(e, schema, extension_info)) + .map(|e| substrait_sort_field(ctx, e, schema, extension_info)) .collect::>>()?; // window frame let bounds = to_substrait_bounds(window_frame)?; @@ -998,6 +1156,7 @@ pub fn to_substrait_rex( escape_char, case_insensitive, }) => make_substrait_like_expr( + ctx, *case_insensitive, *negated, expr, @@ -1007,7 +1166,131 @@ pub fn to_substrait_rex( col_ref_offset, extension_info, ), - _ => not_impl_err!("Unsupported expression: {expr:?}"), + Expr::InSubquery(InSubquery { + expr, + subquery, + negated, + }) => { + let substrait_expr = + to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; + + let subquery_plan = + to_substrait_rel(subquery.subquery.as_ref(), ctx, extension_info)?; + + let substrait_subquery = Expression { + rex_type: Some(RexType::Subquery(Box::new(Subquery { + subquery_type: Some( + substrait::proto::expression::subquery::SubqueryType::InPredicate( + Box::new(InPredicate { + needles: (vec![substrait_expr]), + haystack: Some(subquery_plan), + }), + ), + ), + }))), + }; + if *negated { + let function_anchor = + _register_function("not".to_string(), extension_info); + + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments: vec![FunctionArgument { + arg_type: Some(ArgType::Value(substrait_subquery)), + }], + output_type: None, + args: vec![], + options: vec![], + })), + }) + } else { + Ok(substrait_subquery) + } + } + Expr::Not(arg) => to_substrait_unary_scalar_fn( + ctx, + "not", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsNull(arg) => to_substrait_unary_scalar_fn( + ctx, + "is_null", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsNotNull(arg) => to_substrait_unary_scalar_fn( + ctx, + "is_not_null", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsTrue(arg) => to_substrait_unary_scalar_fn( + ctx, + "is_true", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsFalse(arg) => to_substrait_unary_scalar_fn( + ctx, + "is_false", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsUnknown(arg) => to_substrait_unary_scalar_fn( + ctx, + "is_unknown", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsNotTrue(arg) => to_substrait_unary_scalar_fn( + ctx, + "is_not_true", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsNotFalse(arg) => to_substrait_unary_scalar_fn( + ctx, + "is_not_false", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsNotUnknown(arg) => to_substrait_unary_scalar_fn( + ctx, + "is_not_unknown", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::Negative(arg) => to_substrait_unary_scalar_fn( + ctx, + "negative", + arg, + schema, + col_ref_offset, + extension_info, + ), + _ => { + not_impl_err!("Unsupported expression: {expr:?}") + } } } @@ -1223,6 +1506,7 @@ fn make_substrait_window_function( #[allow(deprecated)] #[allow(clippy::too_many_arguments)] fn make_substrait_like_expr( + ctx: &SessionContext, ignore_case: bool, negated: bool, expr: &Expr, @@ -1240,8 +1524,8 @@ fn make_substrait_like_expr( } else { _register_function("like".to_string(), extension_info) }; - let expr = to_substrait_rex(expr, schema, col_ref_offset, extension_info)?; - let pattern = to_substrait_rex(pattern, schema, col_ref_offset, extension_info)?; + let expr = to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; + let pattern = to_substrait_rex(ctx, pattern, schema, col_ref_offset, extension_info)?; let escape_char = to_substrait_literal(&ScalarValue::Utf8(escape_char.map(|c| c.to_string())))?; let arguments = vec![ @@ -1469,6 +1753,35 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { }) } +/// Util to generate substrait [RexType::ScalarFunction] with one argument +fn to_substrait_unary_scalar_fn( + ctx: &SessionContext, + fn_name: &str, + arg: &Expr, + schema: &DFSchemaRef, + col_ref_offset: usize, + extension_info: &mut ( + Vec, + HashMap, + ), +) -> Result { + let function_anchor = _register_function(fn_name.to_string(), extension_info); + let substrait_expr = + to_substrait_rex(ctx, arg, schema, col_ref_offset, extension_info)?; + + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments: vec![FunctionArgument { + arg_type: Some(ArgType::Value(substrait_expr)), + }], + output_type: None, + options: vec![], + ..Default::default() + })), + }) +} + fn try_to_substrait_null(v: &ScalarValue) -> Result { let default_nullability = r#type::Nullability::Nullable as i32; match v { @@ -1629,7 +1942,33 @@ fn try_to_substrait_null(v: &ScalarValue) -> Result { } } +/// Try to convert an [Expr] to a [FieldReference]. +/// Returns `Err` if the [Expr] is not a [Expr::Column]. +fn try_to_substrait_field_reference( + expr: &Expr, + schema: &DFSchemaRef, +) -> Result { + match expr { + Expr::Column(col) => { + let index = schema.index_of_column(col)?; + Ok(FieldReference { + reference_type: Some(ReferenceType::DirectReference(ReferenceSegment { + reference_type: Some(reference_segment::ReferenceType::StructField( + Box::new(reference_segment::StructField { + field: index as i32, + child: None, + }), + )), + })), + root_type: None, + }) + } + _ => substrait_err!("Expect a `Column` expr, but found {expr:?}"), + } +} + fn substrait_sort_field( + ctx: &SessionContext, expr: &Expr, schema: &DFSchemaRef, extension_info: &mut ( @@ -1643,7 +1982,7 @@ fn substrait_sort_field( asc, nulls_first, }) => { - let e = to_substrait_rex(expr, schema, 0, extension_info)?; + let e = to_substrait_rex(ctx, expr, schema, 0, extension_info)?; let d = match (asc, nulls_first) { (true, true) => SortDirection::AscNullsFirst, (true, false) => SortDirection::AscNullsLast, diff --git a/datafusion/substrait/src/physical_plan/consumer.rs b/datafusion/substrait/src/physical_plan/consumer.rs index 8d25626a3bfe..3098dc386e6a 100644 --- a/datafusion/substrait/src/physical_plan/consumer.rs +++ b/datafusion/substrait/src/physical_plan/consumer.rs @@ -15,19 +15,21 @@ // specific language governing permissions and limitations // under the License. -use async_recursion::async_recursion; -use chrono::DateTime; +use std::collections::HashMap; +use std::sync::Arc; + use datafusion::arrow::datatypes::Schema; use datafusion::common::not_impl_err; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec}; use datafusion::error::{DataFusionError, Result}; -use datafusion::physical_plan::ExecutionPlan; +use datafusion::physical_plan::{ExecutionPlan, Statistics}; use datafusion::prelude::SessionContext; + +use async_recursion::async_recursion; +use chrono::DateTime; use object_store::ObjectMeta; -use std::collections::HashMap; -use std::sync::Arc; use substrait::proto::read_rel::local_files::file_or_files::PathType; use substrait::proto::{ expression::MaskExpression, read_rel::ReadType, rel::RelType, Rel, @@ -36,7 +38,7 @@ use substrait::proto::{ /// Convert Substrait Rel to DataFusion ExecutionPlan #[async_recursion] pub async fn from_substrait_rel( - _ctx: &mut SessionContext, + _ctx: &SessionContext, rel: &Rel, _extensions: &HashMap, ) -> Result> { @@ -87,6 +89,7 @@ pub async fn from_substrait_rel( location: path.into(), size, e_tag: None, + version: None, }, partition_values: vec![], range: None, @@ -104,12 +107,11 @@ pub async fn from_substrait_rel( object_store_url: ObjectStoreUrl::local_filesystem(), file_schema: Arc::new(Schema::empty()), file_groups, - statistics: Default::default(), + statistics: Statistics::new_unknown(&Schema::empty()), projection: None, limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }; if let Some(MaskExpression { select, .. }) = &read.projection { diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index f4d74ae42681..d7327caee43d 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -15,6 +15,9 @@ // specific language governing permissions and limitations // under the License. +use datafusion::arrow::array::ArrayRef; +use datafusion::physical_plan::Accumulator; +use datafusion::scalar::ScalarValue; use datafusion_substrait::logical_plan::{ consumer::from_substrait_plan, producer::to_substrait_plan, }; @@ -23,15 +26,20 @@ use std::hash::Hash; use std::sync::Arc; use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; -use datafusion::common::{DFSchema, DFSchemaRef}; -use datafusion::error::Result; +use datafusion::common::{not_impl_err, plan_err, DFSchema, DFSchemaRef}; +use datafusion::error::{DataFusionError, Result}; use datafusion::execution::context::SessionState; use datafusion::execution::registry::SerializerRegistry; use datafusion::execution::runtime_env::RuntimeEnv; -use datafusion::logical_expr::{Extension, LogicalPlan, UserDefinedLogicalNode}; +use datafusion::logical_expr::{ + Extension, LogicalPlan, Repartition, UserDefinedLogicalNode, Volatility, +}; use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST; use datafusion::prelude::*; + use substrait::proto::extensions::simple_extension_declaration::MappingType; +use substrait::proto::rel::RelType; +use substrait::proto::{plan_rel, Plan, Rel}; struct MockSerializerRegistry; @@ -188,6 +196,11 @@ async fn select_with_limit() -> Result<()> { roundtrip_fill_na("SELECT * FROM data LIMIT 100").await } +#[tokio::test] +async fn select_without_limit() -> Result<()> { + roundtrip_fill_na("SELECT * FROM data OFFSET 10").await +} + #[tokio::test] async fn select_with_limit_offset() -> Result<()> { roundtrip("SELECT * FROM data LIMIT 200 OFFSET 10").await @@ -306,6 +319,16 @@ async fn simple_scalar_function_substr() -> Result<()> { roundtrip("SELECT * FROM data WHERE a = SUBSTR('datafusion', 0, 3)").await } +#[tokio::test] +async fn simple_scalar_function_is_null() -> Result<()> { + roundtrip("SELECT * FROM data WHERE a IS NULL").await +} + +#[tokio::test] +async fn simple_scalar_function_is_not_null() -> Result<()> { + roundtrip("SELECT * FROM data WHERE a IS NOT NULL").await +} + #[tokio::test] async fn case_without_base_expression() -> Result<()> { roundtrip("SELECT (CASE WHEN a >= 0 THEN 'positive' ELSE 'negative' END) FROM data") @@ -371,6 +394,29 @@ async fn roundtrip_inlist_4() -> Result<()> { roundtrip("SELECT * FROM data WHERE f NOT IN ('a', 'b', 'c', 'd')").await } +#[tokio::test] +async fn roundtrip_inlist_5() -> Result<()> { + // on roundtrip there is an additional projection during TableScan which includes all column of the table, + // using assert_expected_plan here as a workaround + assert_expected_plan( + "SELECT a, f FROM data WHERE (f IN ('a', 'b', 'c') OR a in (SELECT data2.a FROM data2 WHERE f IN ('b', 'c', 'd')))", + "Filter: data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f = Utf8(\"c\") OR data.a IN ()\ + \n Subquery:\ + \n Projection: data2.a\ + \n Filter: data2.f IN ([Utf8(\"b\"), Utf8(\"c\"), Utf8(\"d\")])\ + \n TableScan: data2 projection=[a, b, c, d, e, f]\ + \n TableScan: data projection=[a, f], partial_filters=[data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f = Utf8(\"c\") OR data.a IN ()]\ + \n Subquery:\ + \n Projection: data2.a\ + \n Filter: data2.f IN ([Utf8(\"b\"), Utf8(\"c\"), Utf8(\"d\")])\ + \n TableScan: data2 projection=[a, b, c, d, e, f]").await +} + +#[tokio::test] +async fn roundtrip_cross_join() -> Result<()> { + roundtrip("SELECT * FROM data CROSS JOIN data2").await +} + #[tokio::test] async fn roundtrip_inner_join() -> Result<()> { roundtrip("SELECT data.a FROM data JOIN data2 ON data.a = data2.a").await @@ -378,12 +424,15 @@ async fn roundtrip_inner_join() -> Result<()> { #[tokio::test] async fn roundtrip_non_equi_inner_join() -> Result<()> { - roundtrip("SELECT data.a FROM data JOIN data2 ON data.a <> data2.a").await + roundtrip_verify_post_join_filter( + "SELECT data.a FROM data JOIN data2 ON data.a <> data2.a", + ) + .await } #[tokio::test] async fn roundtrip_non_equi_join() -> Result<()> { - roundtrip( + roundtrip_verify_post_join_filter( "SELECT data.a FROM data, data2 WHERE data.a = data2.a AND data.e > data2.a", ) .await @@ -452,6 +501,46 @@ async fn roundtrip_ilike() -> Result<()> { roundtrip("SELECT f FROM data WHERE f ILIKE 'a%b'").await } +#[tokio::test] +async fn roundtrip_not() -> Result<()> { + roundtrip("SELECT * FROM data WHERE NOT d").await +} + +#[tokio::test] +async fn roundtrip_negative() -> Result<()> { + roundtrip("SELECT * FROM data WHERE -a = 1").await +} + +#[tokio::test] +async fn roundtrip_is_true() -> Result<()> { + roundtrip("SELECT * FROM data WHERE d IS TRUE").await +} + +#[tokio::test] +async fn roundtrip_is_false() -> Result<()> { + roundtrip("SELECT * FROM data WHERE d IS FALSE").await +} + +#[tokio::test] +async fn roundtrip_is_not_true() -> Result<()> { + roundtrip("SELECT * FROM data WHERE d IS NOT TRUE").await +} + +#[tokio::test] +async fn roundtrip_is_not_false() -> Result<()> { + roundtrip("SELECT * FROM data WHERE d IS NOT FALSE").await +} + +#[tokio::test] +async fn roundtrip_is_unknown() -> Result<()> { + roundtrip("SELECT * FROM data WHERE d IS UNKNOWN").await +} + +#[tokio::test] +async fn roundtrip_is_not_unknown() -> Result<()> { + roundtrip("SELECT * FROM data WHERE d IS NOT UNKNOWN").await +} + #[tokio::test] async fn roundtrip_union() -> Result<()> { roundtrip("SELECT a, e FROM data UNION SELECT a, e FROM data").await @@ -475,10 +564,11 @@ async fn simple_intersect() -> Result<()> { assert_expected_plan( "SELECT COUNT(*) FROM (SELECT data.a FROM data INTERSECT SELECT data2.a FROM data2);", "Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\ - \n LeftSemi Join: data.a = data2.a\ - \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ - \n TableScan: data projection=[a]\ - \n TableScan: data2 projection=[a]", + \n Projection: \ + \n LeftSemi Join: data.a = data2.a\ + \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ + \n TableScan: data projection=[a]\ + \n TableScan: data2 projection=[a]", ) .await } @@ -488,10 +578,11 @@ async fn simple_intersect_table_reuse() -> Result<()> { assert_expected_plan( "SELECT COUNT(*) FROM (SELECT data.a FROM data INTERSECT SELECT data.a FROM data);", "Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\ - \n LeftSemi Join: data.a = data.a\ - \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ - \n TableScan: data projection=[a]\ - \n TableScan: data projection=[a]", + \n Projection: \ + \n LeftSemi Join: data.a = data.a\ + \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ + \n TableScan: data projection=[a]\ + \n TableScan: data projection=[a]", ) .await } @@ -595,7 +686,7 @@ async fn new_test_grammar() -> Result<()> { #[tokio::test] async fn extension_logical_plan() -> Result<()> { - let mut ctx = create_context().await?; + let ctx = create_context().await?; let validation_bytes = "MockUserDefinedLogicalPlan".as_bytes().to_vec(); let ext_plan = LogicalPlan::Extension(Extension { node: Arc::new(MockUserDefinedLogicalPlan { @@ -606,7 +697,7 @@ async fn extension_logical_plan() -> Result<()> { }); let proto = to_substrait_plan(&ext_plan, &ctx)?; - let plan2 = from_substrait_plan(&mut ctx, &proto).await?; + let plan2 = from_substrait_plan(&ctx, &proto).await?; let plan1str = format!("{ext_plan:?}"); let plan2str = format!("{plan2:?}"); @@ -615,12 +706,181 @@ async fn extension_logical_plan() -> Result<()> { Ok(()) } +#[tokio::test] +async fn roundtrip_aggregate_udf() -> Result<()> { + #[derive(Debug)] + struct Dummy {} + + impl Accumulator for Dummy { + fn state(&self) -> datafusion::error::Result> { + Ok(vec![]) + } + + fn update_batch( + &mut self, + _values: &[ArrayRef], + ) -> datafusion::error::Result<()> { + Ok(()) + } + + fn merge_batch(&mut self, _states: &[ArrayRef]) -> datafusion::error::Result<()> { + Ok(()) + } + + fn evaluate(&self) -> datafusion::error::Result { + Ok(ScalarValue::Float64(None)) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } + } + + let dummy_agg = create_udaf( + // the name; used to represent it in plan descriptions and in the registry, to use in SQL. + "dummy_agg", + // the input type; DataFusion guarantees that the first entry of `values` in `update` has this type. + vec![DataType::Int64], + // the return type; DataFusion expects this to match the type returned by `evaluate`. + Arc::new(DataType::Int64), + Volatility::Immutable, + // This is the accumulator factory; DataFusion uses it to create new accumulators. + Arc::new(|_| Ok(Box::new(Dummy {}))), + // This is the description of the state. `state()` must match the types here. + Arc::new(vec![DataType::Float64, DataType::UInt32]), + ); + + let ctx = create_context().await?; + ctx.register_udaf(dummy_agg); + + roundtrip_with_ctx("select dummy_agg(a) from data", ctx).await +} + +#[tokio::test] +async fn roundtrip_repartition_roundrobin() -> Result<()> { + let ctx = create_context().await?; + let scan_plan = ctx.sql("SELECT * FROM data").await?.into_optimized_plan()?; + let plan = LogicalPlan::Repartition(Repartition { + input: Arc::new(scan_plan), + partitioning_scheme: Partitioning::RoundRobinBatch(8), + }); + + let proto = to_substrait_plan(&plan, &ctx)?; + let plan2 = from_substrait_plan(&ctx, &proto).await?; + let plan2 = ctx.state().optimize(&plan2)?; + + assert_eq!(format!("{plan:?}"), format!("{plan2:?}")); + Ok(()) +} + +#[tokio::test] +async fn roundtrip_repartition_hash() -> Result<()> { + let ctx = create_context().await?; + let scan_plan = ctx.sql("SELECT * FROM data").await?.into_optimized_plan()?; + let plan = LogicalPlan::Repartition(Repartition { + input: Arc::new(scan_plan), + partitioning_scheme: Partitioning::Hash(vec![col("data.a")], 8), + }); + + let proto = to_substrait_plan(&plan, &ctx)?; + let plan2 = from_substrait_plan(&ctx, &proto).await?; + let plan2 = ctx.state().optimize(&plan2)?; + + assert_eq!(format!("{plan:?}"), format!("{plan2:?}")); + Ok(()) +} + +fn check_post_join_filters(rel: &Rel) -> Result<()> { + // search for target_rel and field value in proto + match &rel.rel_type { + Some(RelType::Join(join)) => { + // check if join filter is None + if join.post_join_filter.is_some() { + plan_err!( + "DataFusion generated Susbtrait plan cannot have post_join_filter in JoinRel" + ) + } else { + // recursively check JoinRels + match check_post_join_filters(join.left.as_ref().unwrap().as_ref()) { + Err(e) => Err(e), + Ok(_) => { + check_post_join_filters(join.right.as_ref().unwrap().as_ref()) + } + } + } + } + Some(RelType::Project(p)) => { + check_post_join_filters(p.input.as_ref().unwrap().as_ref()) + } + Some(RelType::Filter(filter)) => { + check_post_join_filters(filter.input.as_ref().unwrap().as_ref()) + } + Some(RelType::Fetch(fetch)) => { + check_post_join_filters(fetch.input.as_ref().unwrap().as_ref()) + } + Some(RelType::Sort(sort)) => { + check_post_join_filters(sort.input.as_ref().unwrap().as_ref()) + } + Some(RelType::Aggregate(agg)) => { + check_post_join_filters(agg.input.as_ref().unwrap().as_ref()) + } + Some(RelType::Set(set)) => { + for input in &set.inputs { + match check_post_join_filters(input) { + Err(e) => return Err(e), + Ok(_) => continue, + } + } + Ok(()) + } + Some(RelType::ExtensionSingle(ext)) => { + check_post_join_filters(ext.input.as_ref().unwrap().as_ref()) + } + Some(RelType::ExtensionMulti(ext)) => { + for input in &ext.inputs { + match check_post_join_filters(input) { + Err(e) => return Err(e), + Ok(_) => continue, + } + } + Ok(()) + } + Some(RelType::ExtensionLeaf(_)) | Some(RelType::Read(_)) => Ok(()), + _ => not_impl_err!( + "Unsupported RelType: {:?} in post join filter check", + rel.rel_type + ), + } +} + +async fn verify_post_join_filter_value(proto: Box) -> Result<()> { + for relation in &proto.relations { + match relation.rel_type.as_ref() { + Some(rt) => match rt { + plan_rel::RelType::Rel(rel) => match check_post_join_filters(rel) { + Err(e) => return Err(e), + Ok(_) => continue, + }, + plan_rel::RelType::Root(root) => { + match check_post_join_filters(root.input.as_ref().unwrap()) { + Err(e) => return Err(e), + Ok(_) => continue, + } + } + }, + None => return plan_err!("Cannot parse plan relation: None"), + } + } + + Ok(()) +} + async fn assert_expected_plan(sql: &str, expected_plan_str: &str) -> Result<()> { - let mut ctx = create_context().await?; + let ctx = create_context().await?; let df = ctx.sql(sql).await?; let plan = df.into_optimized_plan()?; let proto = to_substrait_plan(&plan, &ctx)?; - let plan2 = from_substrait_plan(&mut ctx, &proto).await?; + let plan2 = from_substrait_plan(&ctx, &proto).await?; let plan2 = ctx.state().optimize(&plan2)?; let plan2str = format!("{plan2:?}"); assert_eq!(expected_plan_str, &plan2str); @@ -628,11 +888,11 @@ async fn assert_expected_plan(sql: &str, expected_plan_str: &str) -> Result<()> } async fn roundtrip_fill_na(sql: &str) -> Result<()> { - let mut ctx = create_context().await?; + let ctx = create_context().await?; let df = ctx.sql(sql).await?; let plan1 = df.into_optimized_plan()?; let proto = to_substrait_plan(&plan1, &ctx)?; - let plan2 = from_substrait_plan(&mut ctx, &proto).await?; + let plan2 = from_substrait_plan(&ctx, &proto).await?; let plan2 = ctx.state().optimize(&plan2)?; // Format plan string and replace all None's with 0 @@ -647,15 +907,15 @@ async fn test_alias(sql_with_alias: &str, sql_no_alias: &str) -> Result<()> { // Since we ignore the SubqueryAlias in the producer, the result should be // the same as producing a Substrait plan from the same query without aliases // sql_with_alias -> substrait -> logical plan = sql_no_alias -> substrait -> logical plan - let mut ctx = create_context().await?; + let ctx = create_context().await?; let df_a = ctx.sql(sql_with_alias).await?; let proto_a = to_substrait_plan(&df_a.into_optimized_plan()?, &ctx)?; - let plan_with_alias = from_substrait_plan(&mut ctx, &proto_a).await?; + let plan_with_alias = from_substrait_plan(&ctx, &proto_a).await?; let df = ctx.sql(sql_no_alias).await?; let proto = to_substrait_plan(&df.into_optimized_plan()?, &ctx)?; - let plan = from_substrait_plan(&mut ctx, &proto).await?; + let plan = from_substrait_plan(&ctx, &proto).await?; println!("{plan_with_alias:#?}"); println!("{plan:#?}"); @@ -666,12 +926,11 @@ async fn test_alias(sql_with_alias: &str, sql_no_alias: &str) -> Result<()> { Ok(()) } -async fn roundtrip(sql: &str) -> Result<()> { - let mut ctx = create_context().await?; +async fn roundtrip_with_ctx(sql: &str, ctx: SessionContext) -> Result<()> { let df = ctx.sql(sql).await?; let plan = df.into_optimized_plan()?; let proto = to_substrait_plan(&plan, &ctx)?; - let plan2 = from_substrait_plan(&mut ctx, &proto).await?; + let plan2 = from_substrait_plan(&ctx, &proto).await?; let plan2 = ctx.state().optimize(&plan2)?; println!("{plan:#?}"); @@ -683,12 +942,35 @@ async fn roundtrip(sql: &str) -> Result<()> { Ok(()) } +async fn roundtrip(sql: &str) -> Result<()> { + roundtrip_with_ctx(sql, create_context().await?).await +} + +async fn roundtrip_verify_post_join_filter(sql: &str) -> Result<()> { + let ctx = create_context().await?; + let df = ctx.sql(sql).await?; + let plan = df.into_optimized_plan()?; + let proto = to_substrait_plan(&plan, &ctx)?; + let plan2 = from_substrait_plan(&ctx, &proto).await?; + let plan2 = ctx.state().optimize(&plan2)?; + + println!("{plan:#?}"); + println!("{plan2:#?}"); + + let plan1str = format!("{plan:?}"); + let plan2str = format!("{plan2:?}"); + assert_eq!(plan1str, plan2str); + + // verify that the join filters are None + verify_post_join_filter_value(proto).await +} + async fn roundtrip_all_types(sql: &str) -> Result<()> { - let mut ctx = create_all_type_context().await?; + let ctx = create_all_type_context().await?; let df = ctx.sql(sql).await?; let plan = df.into_optimized_plan()?; let proto = to_substrait_plan(&plan, &ctx)?; - let plan2 = from_substrait_plan(&mut ctx, &proto).await?; + let plan2 = from_substrait_plan(&ctx, &proto).await?; let plan2 = ctx.state().optimize(&plan2)?; println!("{plan:#?}"); @@ -721,12 +1003,12 @@ async fn function_extension_info(sql: &str) -> Result<(Vec, Vec)> { } async fn create_context() -> Result { - let state = SessionState::with_config_rt( + let state = SessionState::new_with_config_rt( SessionConfig::default(), Arc::new(RuntimeEnv::default()), ) .with_serializer_registry(Arc::new(MockSerializerRegistry)); - let ctx = SessionContext::with_state(state); + let ctx = SessionContext::new_with_state(state); let mut explicit_options = CsvReadOptions::new(); let schema = Schema::new(vec![ Field::new("a", DataType::Int64, true), diff --git a/datafusion/substrait/tests/cases/roundtrip_physical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_physical_plan.rs index 25d60471a9cd..e5af3f94cc05 100644 --- a/datafusion/substrait/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_physical_plan.rs @@ -15,16 +15,18 @@ // specific language governing permissions and limitations // under the License. +use std::collections::HashMap; +use std::sync::Arc; + use datafusion::arrow::datatypes::Schema; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec}; use datafusion::error::Result; -use datafusion::physical_plan::{displayable, ExecutionPlan}; +use datafusion::physical_plan::{displayable, ExecutionPlan, Statistics}; use datafusion::prelude::SessionContext; use datafusion_substrait::physical_plan::{consumer, producer}; -use std::collections::HashMap; -use std::sync::Arc; + use substrait::proto::extensions; #[tokio::test] @@ -42,12 +44,11 @@ async fn parquet_exec() -> Result<()> { 123, )], ], - statistics: Default::default(), + statistics: Statistics::new_unknown(&Schema::empty()), projection: None, limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }; let parquet_exec: Arc = Arc::new(ParquetExec::new(scan_config, None, None)); @@ -60,10 +61,10 @@ async fn parquet_exec() -> Result<()> { let substrait_rel = producer::to_substrait_rel(parquet_exec.as_ref(), &mut extension_info)?; - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let parquet_exec_roundtrip = - consumer::from_substrait_rel(&mut ctx, substrait_rel.as_ref(), &HashMap::new()) + consumer::from_substrait_rel(&ctx, substrait_rel.as_ref(), &HashMap::new()) .await?; let expected = format!("{}", displayable(parquet_exec.as_ref()).indent(true)); diff --git a/datafusion/substrait/tests/cases/serialize.rs b/datafusion/substrait/tests/cases/serialize.rs index d6dc5d7e58f2..f6736ca22279 100644 --- a/datafusion/substrait/tests/cases/serialize.rs +++ b/datafusion/substrait/tests/cases/serialize.rs @@ -30,7 +30,7 @@ mod tests { #[tokio::test] async fn serialize_simple_select() -> Result<()> { - let mut ctx = create_context().await?; + let ctx = create_context().await?; let path = "tests/simple_select.bin"; let sql = "SELECT a, b FROM data"; // Test reference @@ -42,7 +42,7 @@ mod tests { // Read substrait plan from file let proto = serializer::deserialize(path).await?; // Check plan equality - let plan = from_substrait_plan(&mut ctx, &proto).await?; + let plan = from_substrait_plan(&ctx, &proto).await?; let plan_str_ref = format!("{plan_ref:?}"); let plan_str = format!("{plan:?}"); assert_eq!(plan_str_ref, plan_str); diff --git a/datafusion/wasmtest/Cargo.toml b/datafusion/wasmtest/Cargo.toml index 0aee45a511f3..c5f795d0653a 100644 --- a/datafusion/wasmtest/Cargo.toml +++ b/datafusion/wasmtest/Cargo.toml @@ -18,9 +18,9 @@ [package] name = "datafusion-wasmtest" description = "Test library to compile datafusion crates to wasm" +readme = "README.md" version = { workspace = true } edition = { workspace = true } -readme = { workspace = true } homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } @@ -38,13 +38,13 @@ crate-type = ["cdylib", "rlib",] # code size when deploying. console_error_panic_hook = { version = "0.1.1", optional = true } -datafusion-common = { path = "../common", version = "31.0.0", default-features = false } -datafusion-expr = { path = "../expr" } -datafusion-optimizer = { path = "../optimizer" } -datafusion-physical-expr = { path = "../physical-expr" } -datafusion-sql = { path = "../sql" } +datafusion-common = { workspace = true } +datafusion-expr = { workspace = true } +datafusion-optimizer = { workspace = true } +datafusion-physical-expr = { workspace = true } +datafusion-sql = { workspace = true } # getrandom must be compiled with js feature getrandom = { version = "0.2.8", features = ["js"] } -parquet = { version = "47.0.0", default-features = false } +parquet = { workspace = true } wasm-bindgen = "0.2.87" diff --git a/datafusion/wasmtest/README.md b/datafusion/wasmtest/README.md index 5dc7bb2de45d..d26369a18ab9 100644 --- a/datafusion/wasmtest/README.md +++ b/datafusion/wasmtest/README.md @@ -17,9 +17,16 @@ under the License. --> -## wasmtest +# DataFusion wasmtest + +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. + +This crate is a submodule of DataFusion used to verify that various DataFusion crates compile successfully to the +`wasm32-unknown-unknown` target with wasm-pack. -Library crate to verify that various DataFusion crates compile successfully to the `wasm32-unknown-unknown` target with wasm-pack. +[df]: https://crates.io/crates/datafusion + +## wasmtest Some of DataFusion's downstream projects compile to WASM to run in the browser. Doing so requires special care that certain library dependencies are not included in DataFusion. diff --git a/dev/changelog/32.0.0.md b/dev/changelog/32.0.0.md new file mode 100644 index 000000000000..781fd5001552 --- /dev/null +++ b/dev/changelog/32.0.0.md @@ -0,0 +1,195 @@ + + +## [32.0.0](https://github.com/apache/arrow-datafusion/tree/32.0.0) (2023-10-07) + +[Full Changelog](https://github.com/apache/arrow-datafusion/compare/31.0.0...32.0.0) + +**Breaking changes:** + +- Remove implicit interval type coercion from ScalarValue comparison [#7514](https://github.com/apache/arrow-datafusion/pull/7514) (tustvold) +- Remove get_scan_files and ExecutionPlan::file_scan_config (#7357) [#7487](https://github.com/apache/arrow-datafusion/pull/7487) (tustvold) +- Move `FileCompressionType` out of `common` and into `core` [#7596](https://github.com/apache/arrow-datafusion/pull/7596) (haohuaijin) +- Update arrow 47.0.0 in DataFusion [#7587](https://github.com/apache/arrow-datafusion/pull/7587) (tustvold) +- Rename `bounded_order_preserving_variants` config to `prefer_exising_sort` and update docs [#7723](https://github.com/apache/arrow-datafusion/pull/7723) (alamb) + +**Implemented enhancements:** + +- Parallelize Stateless (CSV/JSON) File Write Serialization [#7452](https://github.com/apache/arrow-datafusion/pull/7452) (devinjdangelo) +- Create a Priority Queue based Aggregation with `limit` [#7192](https://github.com/apache/arrow-datafusion/pull/7192) (avantgardnerio) +- feat: add guarantees to simplification [#7467](https://github.com/apache/arrow-datafusion/pull/7467) (wjones127) +- [Minor]: Produce better plan when group by contains all of the ordering requirements [#7542](https://github.com/apache/arrow-datafusion/pull/7542) (mustafasrepo) +- Make AvroArrowArrayReader possible to scan Avro backed table which contains nested records [#7525](https://github.com/apache/arrow-datafusion/pull/7525) (sarutak) +- feat: Support spilling for hash aggregation [#7400](https://github.com/apache/arrow-datafusion/pull/7400) (kazuyukitanimura) +- Parallelize Parquet Serialization [#7562](https://github.com/apache/arrow-datafusion/pull/7562) (devinjdangelo) +- feat: natively support more data types for the `abs` function. [#7568](https://github.com/apache/arrow-datafusion/pull/7568) (jonahgao) +- feat: Parallel collecting parquet files statistics #7573 [#7595](https://github.com/apache/arrow-datafusion/pull/7595) (hengfeiyang) +- Support hashing List columns [#7616](https://github.com/apache/arrow-datafusion/pull/7616) (jonmmease) +- feat: Better large output display in datafusion-cli with --maxrows option [#7617](https://github.com/apache/arrow-datafusion/pull/7617) (2010YOUY01) +- feat: make parse_float_as_decimal work on negative numbers [#7648](https://github.com/apache/arrow-datafusion/pull/7648) (jonahgao) +- Update Default Parquet Write Compression [#7692](https://github.com/apache/arrow-datafusion/pull/7692) (devinjdangelo) +- Support all the codecs supported by Avro [#7718](https://github.com/apache/arrow-datafusion/pull/7718) (sarutak) +- Optimize "ORDER BY + LIMIT" queries for speed / memory with special TopK operator [#7721](https://github.com/apache/arrow-datafusion/pull/7721) (Dandandan) + +**Fixed bugs:** + +- fix: inconsistent behaviors when dividing floating numbers by zero [#7503](https://github.com/apache/arrow-datafusion/pull/7503) (jonahgao) +- fix: skip EliminateCrossJoin rule if inner join with filter is found [#7529](https://github.com/apache/arrow-datafusion/pull/7529) (epsio-banay) +- fix: check for precision overflow when parsing float as decimal [#7627](https://github.com/apache/arrow-datafusion/pull/7627) (jonahgao) +- fix: substrait limit when fetch is None [#7669](https://github.com/apache/arrow-datafusion/pull/7669) (waynexia) +- fix: coerce text to timestamps with timezones [#7720](https://github.com/apache/arrow-datafusion/pull/7720) (mhilton) +- fix: avro_to_arrow: Handle avro nested nullable struct (union) [#7663](https://github.com/apache/arrow-datafusion/pull/7663) (Samrose-Ahmed) + +**Documentation updates:** + +- Documentation Updates for New Write Related Features [#7520](https://github.com/apache/arrow-datafusion/pull/7520) (devinjdangelo) +- Create 2023 Q4 roadmap [#7551](https://github.com/apache/arrow-datafusion/pull/7551) (graydenshand) +- docs: add section on supports_filters_pushdown [#7680](https://github.com/apache/arrow-datafusion/pull/7680) (tshauck) +- Add LanceDB to the list of Known Users [#7716](https://github.com/apache/arrow-datafusion/pull/7716) (alamb) +- Document crate feature flags [#7713](https://github.com/apache/arrow-datafusion/pull/7713) (alamb) + +**Merged pull requests:** + +- Prepare 31.0.0 release [#7508](https://github.com/apache/arrow-datafusion/pull/7508) (andygrove) +- Minor(proto): Implement `TryFrom<&DFSchema>` for `protobuf::DfSchema` [#7505](https://github.com/apache/arrow-datafusion/pull/7505) (jonahgao) +- fix: inconsistent behaviors when dividing floating numbers by zero [#7503](https://github.com/apache/arrow-datafusion/pull/7503) (jonahgao) +- Parallelize Stateless (CSV/JSON) File Write Serialization [#7452](https://github.com/apache/arrow-datafusion/pull/7452) (devinjdangelo) +- Minor: Remove stray comment markings from encoding error message [#7512](https://github.com/apache/arrow-datafusion/pull/7512) (devinjdangelo) +- Remove implicit interval type coercion from ScalarValue comparison [#7514](https://github.com/apache/arrow-datafusion/pull/7514) (tustvold) +- Minor: deprecate ScalarValue::get_datatype() [#7507](https://github.com/apache/arrow-datafusion/pull/7507) (Weijun-H) +- Propagate error from spawned task reading spills [#7510](https://github.com/apache/arrow-datafusion/pull/7510) (viirya) +- Refactor the EnforceDistribution Rule [#7488](https://github.com/apache/arrow-datafusion/pull/7488) (mustafasrepo) +- Remove get_scan_files and ExecutionPlan::file_scan_config (#7357) [#7487](https://github.com/apache/arrow-datafusion/pull/7487) (tustvold) +- Simplify ScalarValue::distance (#7517) [#7519](https://github.com/apache/arrow-datafusion/pull/7519) (tustvold) +- typo: change `delimeter` to `delimiter` [#7521](https://github.com/apache/arrow-datafusion/pull/7521) (Weijun-H) +- Fix some simplification rules for floating-point arithmetic operations [#7515](https://github.com/apache/arrow-datafusion/pull/7515) (jonahgao) +- Documentation Updates for New Write Related Features [#7520](https://github.com/apache/arrow-datafusion/pull/7520) (devinjdangelo) +- [MINOR]: Move tests from repartition to enforce_distribution file [#7539](https://github.com/apache/arrow-datafusion/pull/7539) (mustafasrepo) +- Update the async-trait crate to resolve clippy bug [#7541](https://github.com/apache/arrow-datafusion/pull/7541) (metesynnada) +- Fix flaky `test_sort_fetch_memory_calculation` test [#7534](https://github.com/apache/arrow-datafusion/pull/7534) (viirya) +- Move common code to utils [#7545](https://github.com/apache/arrow-datafusion/pull/7545) (mustafasrepo) +- Minor: Add comments and clearer constructors to `Interval` [#7526](https://github.com/apache/arrow-datafusion/pull/7526) (alamb) +- fix: skip EliminateCrossJoin rule if inner join with filter is found [#7529](https://github.com/apache/arrow-datafusion/pull/7529) (epsio-banay) +- Create a Priority Queue based Aggregation with `limit` [#7192](https://github.com/apache/arrow-datafusion/pull/7192) (avantgardnerio) +- feat: add guarantees to simplification [#7467](https://github.com/apache/arrow-datafusion/pull/7467) (wjones127) +- [Minor]: Produce better plan when group by contains all of the ordering requirements [#7542](https://github.com/apache/arrow-datafusion/pull/7542) (mustafasrepo) +- Minor: beautify interval display [#7554](https://github.com/apache/arrow-datafusion/pull/7554) (Weijun-H) +- replace ptree with termtree [#7560](https://github.com/apache/arrow-datafusion/pull/7560) (avantgardnerio) +- Make AvroArrowArrayReader possible to scan Avro backed table which contains nested records [#7525](https://github.com/apache/arrow-datafusion/pull/7525) (sarutak) +- Fix a race condition issue on reading spilled file [#7538](https://github.com/apache/arrow-datafusion/pull/7538) (sarutak) +- [MINOR]: Add is single method [#7558](https://github.com/apache/arrow-datafusion/pull/7558) (mustafasrepo) +- Fix `describe
` to work without SessionContext [#7441](https://github.com/apache/arrow-datafusion/pull/7441) (alamb) +- Make the tests in SHJ faster [#7543](https://github.com/apache/arrow-datafusion/pull/7543) (metesynnada) +- feat: Support spilling for hash aggregation [#7400](https://github.com/apache/arrow-datafusion/pull/7400) (kazuyukitanimura) +- Make backtrace as a cargo feature [#7527](https://github.com/apache/arrow-datafusion/pull/7527) (comphead) +- Minor: Fix `clippy` by switching to `timestamp_nanos_opt` instead of (deprecated) `timestamp_nanos` [#7572](https://github.com/apache/arrow-datafusion/pull/7572) (alamb) +- Update sqllogictest requirement from 0.15.0 to 0.16.0 [#7569](https://github.com/apache/arrow-datafusion/pull/7569) (dependabot[bot]) +- extract `datafusion-physical-plan` to its own crate [#7432](https://github.com/apache/arrow-datafusion/pull/7432) (alamb) +- First and Last Accumulators should update with state row excluding is_set flag [#7565](https://github.com/apache/arrow-datafusion/pull/7565) (viirya) +- refactor: simplify code of eliminate_cross_join.rs [#7561](https://github.com/apache/arrow-datafusion/pull/7561) (jackwener) +- Update release instructions for datafusion-physical-plan crate [#7576](https://github.com/apache/arrow-datafusion/pull/7576) (alamb) +- Minor: Update chrono pin to `0.4.31` [#7575](https://github.com/apache/arrow-datafusion/pull/7575) (alamb) +- [feat] Introduce cacheManager in session ctx and make StatisticsCache share in session [#7570](https://github.com/apache/arrow-datafusion/pull/7570) (Ted-Jiang) +- Enhance/Refactor Ordering Equivalence Properties [#7566](https://github.com/apache/arrow-datafusion/pull/7566) (mustafasrepo) +- fix misplaced statements in sqllogictest [#7586](https://github.com/apache/arrow-datafusion/pull/7586) (jonahgao) +- Update substrait requirement from 0.13.1 to 0.14.0 [#7585](https://github.com/apache/arrow-datafusion/pull/7585) (dependabot[bot]) +- chore: use the `create_udwf` function in `simple_udwf`, consistent with `simple_udf` and `simple_udaf` [#7579](https://github.com/apache/arrow-datafusion/pull/7579) (tanruixiang) +- Implement protobuf serialization for AnalyzeExec [#7574](https://github.com/apache/arrow-datafusion/pull/7574) (adhish20) +- chore: fix catalog's usage docs error and add docs about `CatalogList` trait [#7582](https://github.com/apache/arrow-datafusion/pull/7582) (tanruixiang) +- Implement `CardinalityAwareRowConverter` while doing streaming merge [#7401](https://github.com/apache/arrow-datafusion/pull/7401) (JayjeetAtGithub) +- Parallelize Parquet Serialization [#7562](https://github.com/apache/arrow-datafusion/pull/7562) (devinjdangelo) +- feat: natively support more data types for the `abs` function. [#7568](https://github.com/apache/arrow-datafusion/pull/7568) (jonahgao) +- implement string_to_array [#7577](https://github.com/apache/arrow-datafusion/pull/7577) (casperhart) +- Create 2023 Q4 roadmap [#7551](https://github.com/apache/arrow-datafusion/pull/7551) (graydenshand) +- chore: reduce `physical-plan` dependencies [#7599](https://github.com/apache/arrow-datafusion/pull/7599) (crepererum) +- Minor: add githubs start/fork buttons to documentation page [#7588](https://github.com/apache/arrow-datafusion/pull/7588) (alamb) +- Minor: add more examples for `CREATE EXTERNAL TABLE` doc [#7594](https://github.com/apache/arrow-datafusion/pull/7594) (comphead) +- Update nix requirement from 0.26.1 to 0.27.1 [#7438](https://github.com/apache/arrow-datafusion/pull/7438) (dependabot[bot]) +- Update sqllogictest requirement from 0.16.0 to 0.17.0 [#7606](https://github.com/apache/arrow-datafusion/pull/7606) (dependabot[bot]) +- Fix panic in TopK [#7609](https://github.com/apache/arrow-datafusion/pull/7609) (avantgardnerio) +- Move `FileCompressionType` out of `common` and into `core` [#7596](https://github.com/apache/arrow-datafusion/pull/7596) (haohuaijin) +- Expose contents of Constraints [#7603](https://github.com/apache/arrow-datafusion/pull/7603) (tv42) +- Change the unbounded_output API default [#7605](https://github.com/apache/arrow-datafusion/pull/7605) (metesynnada) +- feat: Parallel collecting parquet files statistics #7573 [#7595](https://github.com/apache/arrow-datafusion/pull/7595) (hengfeiyang) +- Support hashing List columns [#7616](https://github.com/apache/arrow-datafusion/pull/7616) (jonmmease) +- [MINOR] Make the sink input aware of its plan [#7610](https://github.com/apache/arrow-datafusion/pull/7610) (metesynnada) +- [MINOR] Reduce complexity on SHJ [#7607](https://github.com/apache/arrow-datafusion/pull/7607) (metesynnada) +- feat: Better large output display in datafusion-cli with --maxrows option [#7617](https://github.com/apache/arrow-datafusion/pull/7617) (2010YOUY01) +- Minor: add examples for `arrow_cast` and `arrow_typeof` to user guide [#7615](https://github.com/apache/arrow-datafusion/pull/7615) (alamb) +- [MINOR]: Fix stack overflow bug for get field access expr [#7623](https://github.com/apache/arrow-datafusion/pull/7623) (mustafasrepo) +- Group By All [#7622](https://github.com/apache/arrow-datafusion/pull/7622) (berkaysynnada) +- Implement protobuf serialization for `(Bounded)WindowAggExec`. [#7557](https://github.com/apache/arrow-datafusion/pull/7557) (vrongmeal) +- Make it possible to compile datafusion-common without default features [#7625](https://github.com/apache/arrow-datafusion/pull/7625) (jonmmease) +- Minor: Adding backtrace documentation [#7628](https://github.com/apache/arrow-datafusion/pull/7628) (comphead) +- fix(5975/5976): timezone handling for timestamps and `date_trunc`, `date_part` and `date_bin` [#7614](https://github.com/apache/arrow-datafusion/pull/7614) (wiedld) +- Minor: remove unecessary `Arc`s in datetime_expressions [#7630](https://github.com/apache/arrow-datafusion/pull/7630) (alamb) +- fix: check for precision overflow when parsing float as decimal [#7627](https://github.com/apache/arrow-datafusion/pull/7627) (jonahgao) +- Update arrow 47.0.0 in DataFusion [#7587](https://github.com/apache/arrow-datafusion/pull/7587) (tustvold) +- Add test crate to compile DataFusion with wasm-pack [#7633](https://github.com/apache/arrow-datafusion/pull/7633) (jonmmease) +- Minor: Update documentation of case expression [#7646](https://github.com/apache/arrow-datafusion/pull/7646) (ongchi) +- Minor: improve docstrings on `SessionState` [#7654](https://github.com/apache/arrow-datafusion/pull/7654) (alamb) +- Update example in the DataFrame documentation. [#7650](https://github.com/apache/arrow-datafusion/pull/7650) (jsimpson-gro) +- Add HTTP object store example [#7602](https://github.com/apache/arrow-datafusion/pull/7602) (pka) +- feat: make parse_float_as_decimal work on negative numbers [#7648](https://github.com/apache/arrow-datafusion/pull/7648) (jonahgao) +- Minor: add doc comments to `ExtractEquijoinPredicate` [#7658](https://github.com/apache/arrow-datafusion/pull/7658) (alamb) +- [MINOR]: Do not add unnecessary hash repartition to the physical plan [#7667](https://github.com/apache/arrow-datafusion/pull/7667) (mustafasrepo) +- Minor: add ticket references to parallel parquet writing code [#7592](https://github.com/apache/arrow-datafusion/pull/7592) (alamb) +- Minor: Add ticket reference and add test comment [#7593](https://github.com/apache/arrow-datafusion/pull/7593) (alamb) +- Support Avro's Enum type and Fixed type [#7635](https://github.com/apache/arrow-datafusion/pull/7635) (sarutak) +- Minor: Migrate datafusion-proto tests into it own binary [#7668](https://github.com/apache/arrow-datafusion/pull/7668) (ongchi) +- Upgrade apache-avro to 0.16 [#7674](https://github.com/apache/arrow-datafusion/pull/7674) (sarutak) +- Move window analysis to the window method [#7672](https://github.com/apache/arrow-datafusion/pull/7672) (mustafasrepo) +- Don't add filters to projection in TableScan [#7670](https://github.com/apache/arrow-datafusion/pull/7670) (Dandandan) +- Minor: Improve `TableProviderFilterPushDown` docs [#7685](https://github.com/apache/arrow-datafusion/pull/7685) (alamb) +- FIX: Test timestamp with table [#7701](https://github.com/apache/arrow-datafusion/pull/7701) (jayzhan211) +- Fix bug in `SimplifyExpressions` [#7699](https://github.com/apache/arrow-datafusion/pull/7699) (Dandandan) +- Enhance Enforce Dist capabilities to fix, sub optimal bad plans [#7671](https://github.com/apache/arrow-datafusion/pull/7671) (mustafasrepo) +- docs: add section on supports_filters_pushdown [#7680](https://github.com/apache/arrow-datafusion/pull/7680) (tshauck) +- Improve cache usage in CI [#7678](https://github.com/apache/arrow-datafusion/pull/7678) (sarutak) +- fix: substrait limit when fetch is None [#7669](https://github.com/apache/arrow-datafusion/pull/7669) (waynexia) +- minor: revert parsing precedence between Aggr and UDAF [#7682](https://github.com/apache/arrow-datafusion/pull/7682) (waynexia) +- Minor: Move hash utils to common [#7684](https://github.com/apache/arrow-datafusion/pull/7684) (jayzhan211) +- Update Default Parquet Write Compression [#7692](https://github.com/apache/arrow-datafusion/pull/7692) (devinjdangelo) +- Stop using cache for the benchmark job [#7706](https://github.com/apache/arrow-datafusion/pull/7706) (sarutak) +- Change rust.yml to run benchmark [#7708](https://github.com/apache/arrow-datafusion/pull/7708) (sarutak) +- Extend infer_placeholder_types to support BETWEEN predicates [#7703](https://github.com/apache/arrow-datafusion/pull/7703) (andrelmartins) +- Minor: Add comment explaining why verify benchmark results uses release mode [#7712](https://github.com/apache/arrow-datafusion/pull/7712) (alamb) +- Support all the codecs supported by Avro [#7718](https://github.com/apache/arrow-datafusion/pull/7718) (sarutak) +- Update substrait requirement from 0.14.0 to 0.15.0 [#7719](https://github.com/apache/arrow-datafusion/pull/7719) (dependabot[bot]) +- fix: coerce text to timestamps with timezones [#7720](https://github.com/apache/arrow-datafusion/pull/7720) (mhilton) +- Add LanceDB to the list of Known Users [#7716](https://github.com/apache/arrow-datafusion/pull/7716) (alamb) +- Enable avro reading/writing in datafusion-cli [#7715](https://github.com/apache/arrow-datafusion/pull/7715) (alamb) +- Document crate feature flags [#7713](https://github.com/apache/arrow-datafusion/pull/7713) (alamb) +- Minor: Consolidate UDF tests [#7704](https://github.com/apache/arrow-datafusion/pull/7704) (alamb) +- Minor: fix CI failure due to Cargo.lock in datafusioncli [#7733](https://github.com/apache/arrow-datafusion/pull/7733) (yjshen) +- MINOR: change file to column index in page_filter trace log [#7730](https://github.com/apache/arrow-datafusion/pull/7730) (mapleFU) +- preserve array type / timezone in `date_bin` and `date_trunc` functions [#7729](https://github.com/apache/arrow-datafusion/pull/7729) (mhilton) +- Remove redundant is_numeric for DataType [#7734](https://github.com/apache/arrow-datafusion/pull/7734) (qrilka) +- fix: avro_to_arrow: Handle avro nested nullable struct (union) [#7663](https://github.com/apache/arrow-datafusion/pull/7663) (Samrose-Ahmed) +- Rename `SessionContext::with_config_rt` to `SessionContext::new_with_config_from_rt`, etc [#7631](https://github.com/apache/arrow-datafusion/pull/7631) (alamb) +- Rename `bounded_order_preserving_variants` config to `prefer_exising_sort` and update docs [#7723](https://github.com/apache/arrow-datafusion/pull/7723) (alamb) +- Optimize "ORDER BY + LIMIT" queries for speed / memory with special TopK operator [#7721](https://github.com/apache/arrow-datafusion/pull/7721) (Dandandan) +- Minor: Improve crate docs [#7740](https://github.com/apache/arrow-datafusion/pull/7740) (alamb) +- [MINOR]: Resolve linter errors in the main [#7753](https://github.com/apache/arrow-datafusion/pull/7753) (mustafasrepo) +- Minor: Build concat_internal() with ListArray construction instead of ArrayData [#7748](https://github.com/apache/arrow-datafusion/pull/7748) (jayzhan211) +- Minor: Add comment on input_schema from AggregateExec [#7727](https://github.com/apache/arrow-datafusion/pull/7727) (viirya) +- Fix column name for COUNT(\*) set by AggregateStatistics [#7757](https://github.com/apache/arrow-datafusion/pull/7757) (qrilka) +- Add documentation about type signatures, and export `TIMEZONE_WILDCARD` [#7726](https://github.com/apache/arrow-datafusion/pull/7726) (alamb) +- [feat] Support cache ListFiles result cache in session level [#7620](https://github.com/apache/arrow-datafusion/pull/7620) (Ted-Jiang) +- Support `SHOW ALL VERBOSE` to show settings description [#7735](https://github.com/apache/arrow-datafusion/pull/7735) (comphead) diff --git a/dev/changelog/33.0.0.md b/dev/changelog/33.0.0.md new file mode 100644 index 000000000000..17862a64a951 --- /dev/null +++ b/dev/changelog/33.0.0.md @@ -0,0 +1,292 @@ + + +## [33.0.0](https://github.com/apache/arrow-datafusion/tree/33.0.0) (2023-11-12) + +[Full Changelog](https://github.com/apache/arrow-datafusion/compare/32.0.0...33.0.0) + +**Breaking changes:** + +- Refactor Statistics, introduce precision estimates (`Exact`, `Inexact`, `Absent`) [#7793](https://github.com/apache/arrow-datafusion/pull/7793) (berkaysynnada) +- Remove redundant unwrap in `ScalarValue::new_primitive`, return a `Result` [#7830](https://github.com/apache/arrow-datafusion/pull/7830) (maruschin) +- Add `parquet` feature flag, enabled by default, and make parquet conditional [#7745](https://github.com/apache/arrow-datafusion/pull/7745) (ongchi) +- Change input for `to_timestamp` function to be seconds rather than nanoseconds, add `to_timestamp_nanos` [#7844](https://github.com/apache/arrow-datafusion/pull/7844) (comphead) +- Percent Decode URL Paths (#8009) [#8012](https://github.com/apache/arrow-datafusion/pull/8012) (tustvold) +- chore: remove panics in datafusion-common::scalar by making more operations return `Result` [#7901](https://github.com/apache/arrow-datafusion/pull/7901) (junjunjd) +- Combine `Expr::Wildcard` and `Wxpr::QualifiedWildcard`, add `wildcard()` expr fn [#8105](https://github.com/apache/arrow-datafusion/pull/8105) (alamb) + +**Performance related:** + +- Add distinct union optimization [#7788](https://github.com/apache/arrow-datafusion/pull/7788) (maruschin) +- Fix join order for TPCH Q17 & Q18 by improving FilterExec statistics [#8126](https://github.com/apache/arrow-datafusion/pull/8126) (andygrove) +- feat: add column statistics into explain [#8112](https://github.com/apache/arrow-datafusion/pull/8112) (NGA-TRAN) + +**Implemented enhancements:** + +- Support InsertInto Sorted ListingTable [#7743](https://github.com/apache/arrow-datafusion/pull/7743) (devinjdangelo) +- External Table Primary key support [#7755](https://github.com/apache/arrow-datafusion/pull/7755) (mustafasrepo) +- add interval arithmetic for timestamp types [#7758](https://github.com/apache/arrow-datafusion/pull/7758) (mhilton) +- Interval Arithmetic NegativeExpr Support [#7804](https://github.com/apache/arrow-datafusion/pull/7804) (berkaysynnada) +- Exactness Indicator of Parameters: Precision [#7809](https://github.com/apache/arrow-datafusion/pull/7809) (berkaysynnada) +- Implement GetIndexedField for map-typed columns [#7825](https://github.com/apache/arrow-datafusion/pull/7825) (swgillespie) +- Fix precision loss when coercing date_part utf8 argument [#7846](https://github.com/apache/arrow-datafusion/pull/7846) (Dandandan) +- Support `Binary`/`LargeBinary` --> `Utf8`/`LargeUtf8` in ilike and string functions [#7840](https://github.com/apache/arrow-datafusion/pull/7840) (alamb) +- Support Decimal256 on AVG aggregate expression [#7853](https://github.com/apache/arrow-datafusion/pull/7853) (viirya) +- Support Decimal256 column in create external table [#7866](https://github.com/apache/arrow-datafusion/pull/7866) (viirya) +- Support Decimal256 in Min/Max aggregate expressions [#7881](https://github.com/apache/arrow-datafusion/pull/7881) (viirya) +- Implement Hive-Style Partitioned Write Support [#7801](https://github.com/apache/arrow-datafusion/pull/7801) (devinjdangelo) +- feat: support `Decimal256` for the `abs` function [#7904](https://github.com/apache/arrow-datafusion/pull/7904) (jonahgao) +- Parallelize Serialization of Columns within Parquet RowGroups [#7655](https://github.com/apache/arrow-datafusion/pull/7655) (devinjdangelo) +- feat: Use bloom filter when reading parquet to skip row groups [#7821](https://github.com/apache/arrow-datafusion/pull/7821) (hengfeiyang) +- Support Partitioning Data by Dictionary Encoded String Array Types [#7896](https://github.com/apache/arrow-datafusion/pull/7896) (devinjdangelo) +- Read only enough bytes to infer Arrow IPC file schema via stream [#7962](https://github.com/apache/arrow-datafusion/pull/7962) (Jefffrey) +- feat: Support determining extensions from names like `foo.parquet.snappy` as well as `foo.parquet` [#7972](https://github.com/apache/arrow-datafusion/pull/7972) (Weijun-H) +- feat: Protobuf serde for Json file sink [#8062](https://github.com/apache/arrow-datafusion/pull/8062) (Jefffrey) +- feat: support target table alias in update statement [#8080](https://github.com/apache/arrow-datafusion/pull/8080) (jonahgao) +- feat: support UDAF in substrait producer/consumer [#8119](https://github.com/apache/arrow-datafusion/pull/8119) (waynexia) + +**Fixed bugs:** + +- fix: preserve column qualifier for `DataFrame::with_column` [#7792](https://github.com/apache/arrow-datafusion/pull/7792) (jonahgao) +- fix: don't push down volatile predicates in projection [#7909](https://github.com/apache/arrow-datafusion/pull/7909) (haohuaijin) +- fix: generate logical plan for `UPDATE SET FROM` statement [#7984](https://github.com/apache/arrow-datafusion/pull/7984) (jonahgao) +- fix: single_distinct_aggretation_to_group_by fail [#7997](https://github.com/apache/arrow-datafusion/pull/7997) (haohuaijin) +- fix: clippy warnings from nightly rust 1.75 [#8025](https://github.com/apache/arrow-datafusion/pull/8025) (waynexia) +- fix: DataFusion suggests invalid functions [#8083](https://github.com/apache/arrow-datafusion/pull/8083) (jonahgao) +- fix: add encode/decode to protobuf encoding [#8089](https://github.com/apache/arrow-datafusion/pull/8089) (Syleechan) + +**Documentation updates:** + +- Minor: Improve TableProvider document, and add ascii art [#7759](https://github.com/apache/arrow-datafusion/pull/7759) (alamb) +- Expose arrow-schema `serde` crate feature flag [#7829](https://github.com/apache/arrow-datafusion/pull/7829) (lewiszlw) +- doc: fix ExecutionContext to SessionContext in custom-table-providers.md [#7903](https://github.com/apache/arrow-datafusion/pull/7903) (ZENOTME) +- Minor: Document `parquet` crate feature [#7927](https://github.com/apache/arrow-datafusion/pull/7927) (alamb) +- Add some initial content about creating logical plans [#7952](https://github.com/apache/arrow-datafusion/pull/7952) (andygrove) +- Minor: Add implementation examples to ExecutionPlan::execute [#8013](https://github.com/apache/arrow-datafusion/pull/8013) (tustvold) +- Minor: Improve documentation for Filter Pushdown [#8023](https://github.com/apache/arrow-datafusion/pull/8023) (alamb) +- Minor: Improve `ExecutionPlan` documentation [#8019](https://github.com/apache/arrow-datafusion/pull/8019) (alamb) +- Improve comments for `PartitionSearchMode` struct [#8047](https://github.com/apache/arrow-datafusion/pull/8047) (ozankabak) +- Prepare 33.0.0 Release [#8057](https://github.com/apache/arrow-datafusion/pull/8057) (andygrove) +- Improve documentation for calculate_prune_length method in `SymmetricHashJoin` [#8125](https://github.com/apache/arrow-datafusion/pull/8125) (Asura7969) +- docs: show creation of DFSchema [#8132](https://github.com/apache/arrow-datafusion/pull/8132) (wjones127) +- Improve documentation site to make it easier to find communication on Slack/Discord [#8138](https://github.com/apache/arrow-datafusion/pull/8138) (alamb) + +**Merged pull requests:** + +- Minor: Improve TableProvider document, and add ascii art [#7759](https://github.com/apache/arrow-datafusion/pull/7759) (alamb) +- Prepare 32.0.0 Release [#7769](https://github.com/apache/arrow-datafusion/pull/7769) (andygrove) +- Minor: Change all file links to GitHub in document [#7768](https://github.com/apache/arrow-datafusion/pull/7768) (ongchi) +- Minor: Improve `PruningPredicate` documentation [#7738](https://github.com/apache/arrow-datafusion/pull/7738) (alamb) +- Support InsertInto Sorted ListingTable [#7743](https://github.com/apache/arrow-datafusion/pull/7743) (devinjdangelo) +- Minor: improve documentation to `stagger_batch` [#7754](https://github.com/apache/arrow-datafusion/pull/7754) (alamb) +- External Table Primary key support [#7755](https://github.com/apache/arrow-datafusion/pull/7755) (mustafasrepo) +- Minor: Build array_array() with ListArray construction instead of ArrayData [#7780](https://github.com/apache/arrow-datafusion/pull/7780) (jayzhan211) +- Minor: Remove unnecessary `#[cfg(feature = "avro")]` [#7773](https://github.com/apache/arrow-datafusion/pull/7773) (sarutak) +- add interval arithmetic for timestamp types [#7758](https://github.com/apache/arrow-datafusion/pull/7758) (mhilton) +- Minor: make tests deterministic [#7771](https://github.com/apache/arrow-datafusion/pull/7771) (Weijun-H) +- Minor: Improve `Interval` Docs [#7782](https://github.com/apache/arrow-datafusion/pull/7782) (alamb) +- `DataSink` additions [#7778](https://github.com/apache/arrow-datafusion/pull/7778) (Dandandan) +- Update substrait requirement from 0.15.0 to 0.16.0 [#7783](https://github.com/apache/arrow-datafusion/pull/7783) (dependabot[bot]) +- Move nested union optimization from plan builder to logical optimizer [#7695](https://github.com/apache/arrow-datafusion/pull/7695) (maruschin) +- Minor: comments that explain the schema used in simply_expressions [#7747](https://github.com/apache/arrow-datafusion/pull/7747) (alamb) +- Update regex-syntax requirement from 0.7.1 to 0.8.0 [#7784](https://github.com/apache/arrow-datafusion/pull/7784) (dependabot[bot]) +- Minor: Add sql test for `UNION` / `UNION ALL` + plans [#7787](https://github.com/apache/arrow-datafusion/pull/7787) (alamb) +- fix: preserve column qualifier for `DataFrame::with_column` [#7792](https://github.com/apache/arrow-datafusion/pull/7792) (jonahgao) +- Interval Arithmetic NegativeExpr Support [#7804](https://github.com/apache/arrow-datafusion/pull/7804) (berkaysynnada) +- Exactness Indicator of Parameters: Precision [#7809](https://github.com/apache/arrow-datafusion/pull/7809) (berkaysynnada) +- add `LogicalPlanBuilder::join_on` [#7805](https://github.com/apache/arrow-datafusion/pull/7805) (haohuaijin) +- Fix SortPreservingRepartition with no existing ordering. [#7811](https://github.com/apache/arrow-datafusion/pull/7811) (mustafasrepo) +- Update zstd requirement from 0.12 to 0.13 [#7806](https://github.com/apache/arrow-datafusion/pull/7806) (dependabot[bot]) +- [Minor]: Remove input_schema field from window executor [#7810](https://github.com/apache/arrow-datafusion/pull/7810) (mustafasrepo) +- refactor(7181): move streaming_merge() into separate mod from the merge node [#7799](https://github.com/apache/arrow-datafusion/pull/7799) (wiedld) +- Improve update error [#7777](https://github.com/apache/arrow-datafusion/pull/7777) (lewiszlw) +- Minor: Update LogicalPlan::join_on API, use it more [#7814](https://github.com/apache/arrow-datafusion/pull/7814) (alamb) +- Add distinct union optimization [#7788](https://github.com/apache/arrow-datafusion/pull/7788) (maruschin) +- Make CI fail on any occurrence of rust-tomlfmt failed [#7774](https://github.com/apache/arrow-datafusion/pull/7774) (ongchi) +- Encode all join conditions in a single expression field [#7612](https://github.com/apache/arrow-datafusion/pull/7612) (nseekhao) +- Update substrait requirement from 0.16.0 to 0.17.0 [#7808](https://github.com/apache/arrow-datafusion/pull/7808) (dependabot[bot]) +- Minor: include `sort` expressions in `SortPreservingRepartitionExec` explain plan [#7796](https://github.com/apache/arrow-datafusion/pull/7796) (alamb) +- minor: add more document to Wildcard expr [#7822](https://github.com/apache/arrow-datafusion/pull/7822) (waynexia) +- Minor: Move `Monotonicity` to `expr` crate [#7820](https://github.com/apache/arrow-datafusion/pull/7820) (2010YOUY01) +- Use code block for better formatting of rustdoc for PhysicalGroupBy [#7823](https://github.com/apache/arrow-datafusion/pull/7823) (qrilka) +- Update explain plan to show `TopK` operator [#7826](https://github.com/apache/arrow-datafusion/pull/7826) (haohuaijin) +- Extract ReceiverStreamBuilder [#7817](https://github.com/apache/arrow-datafusion/pull/7817) (tustvold) +- Extend backtrace coverage for `DatafusionError::Plan` errors errors [#7803](https://github.com/apache/arrow-datafusion/pull/7803) (comphead) +- Add documentation and usability for prepared parameters [#7785](https://github.com/apache/arrow-datafusion/pull/7785) (alamb) +- Implement GetIndexedField for map-typed columns [#7825](https://github.com/apache/arrow-datafusion/pull/7825) (swgillespie) +- Minor: Assert `streaming_merge` has non empty sort exprs [#7795](https://github.com/apache/arrow-datafusion/pull/7795) (alamb) +- Minor: Upgrade docs for `PhysicalExpr::{propagate_constraints, evaluate_bounds}` [#7812](https://github.com/apache/arrow-datafusion/pull/7812) (alamb) +- Change ScalarValue::List to store ArrayRef [#7629](https://github.com/apache/arrow-datafusion/pull/7629) (jayzhan211) +- [MINOR]:Do not introduce unnecessary repartition when row count is 1. [#7832](https://github.com/apache/arrow-datafusion/pull/7832) (mustafasrepo) +- Minor: Add tests for binary / utf8 coercion [#7839](https://github.com/apache/arrow-datafusion/pull/7839) (alamb) +- Avoid panics on error while encoding/decoding ListValue::Array as protobuf [#7837](https://github.com/apache/arrow-datafusion/pull/7837) (alamb) +- Refactor Statistics, introduce precision estimates (`Exact`, `Inexact`, `Absent`) [#7793](https://github.com/apache/arrow-datafusion/pull/7793) (berkaysynnada) +- Remove redundant unwrap in `ScalarValue::new_primitive`, return a `Result` [#7830](https://github.com/apache/arrow-datafusion/pull/7830) (maruschin) +- Fix precision loss when coercing date_part utf8 argument [#7846](https://github.com/apache/arrow-datafusion/pull/7846) (Dandandan) +- Add operator section to user guide, Add `std::ops` operations to `prelude`, and add `not()` expr_fn [#7732](https://github.com/apache/arrow-datafusion/pull/7732) (ongchi) +- Expose arrow-schema `serde` crate feature flag [#7829](https://github.com/apache/arrow-datafusion/pull/7829) (lewiszlw) +- Improve `ContextProvider` naming: rename` get_table_provider` --> `get_table_source`, deprecate `get_table_provider` [#7831](https://github.com/apache/arrow-datafusion/pull/7831) (lewiszlw) +- DataSink Dynamic Execution Time Demux [#7791](https://github.com/apache/arrow-datafusion/pull/7791) (devinjdangelo) +- Add small column on empty projection [#7833](https://github.com/apache/arrow-datafusion/pull/7833) (ch-sc) +- feat(7849): coerce TIMESTAMP to TIMESTAMPTZ [#7850](https://github.com/apache/arrow-datafusion/pull/7850) (mhilton) +- Support `Binary`/`LargeBinary` --> `Utf8`/`LargeUtf8` in ilike and string functions [#7840](https://github.com/apache/arrow-datafusion/pull/7840) (alamb) +- Minor: fix typo in comments [#7856](https://github.com/apache/arrow-datafusion/pull/7856) (haohuaijin) +- Minor: improve `join` / `join_on` docs [#7813](https://github.com/apache/arrow-datafusion/pull/7813) (alamb) +- Support Decimal256 on AVG aggregate expression [#7853](https://github.com/apache/arrow-datafusion/pull/7853) (viirya) +- Minor: fix typo in comments [#7861](https://github.com/apache/arrow-datafusion/pull/7861) (alamb) +- Minor: fix typo in GreedyMemoryPool documentation [#7864](https://github.com/apache/arrow-datafusion/pull/7864) (avh4) +- Minor: fix multiple typos [#7863](https://github.com/apache/arrow-datafusion/pull/7863) (Smoothieewastaken) +- Minor: Fix docstring typos [#7873](https://github.com/apache/arrow-datafusion/pull/7873) (alamb) +- Add CursorValues Decoupling Cursor Data from Cursor Position [#7855](https://github.com/apache/arrow-datafusion/pull/7855) (tustvold) +- Support Decimal256 column in create external table [#7866](https://github.com/apache/arrow-datafusion/pull/7866) (viirya) +- Support Decimal256 in Min/Max aggregate expressions [#7881](https://github.com/apache/arrow-datafusion/pull/7881) (viirya) +- Implement Hive-Style Partitioned Write Support [#7801](https://github.com/apache/arrow-datafusion/pull/7801) (devinjdangelo) +- Minor: fix config typo [#7874](https://github.com/apache/arrow-datafusion/pull/7874) (alamb) +- Add Decimal256 sqllogictests for SUM, MEDIAN and COUNT aggregate expressions [#7889](https://github.com/apache/arrow-datafusion/pull/7889) (viirya) +- [test] add fuzz test for topk [#7772](https://github.com/apache/arrow-datafusion/pull/7772) (Tangruilin) +- Allow Setting Minimum Parallelism with RowCount Based Demuxer [#7841](https://github.com/apache/arrow-datafusion/pull/7841) (devinjdangelo) +- Drop single quotes to make warnings for parquet options not confusing [#7902](https://github.com/apache/arrow-datafusion/pull/7902) (qrilka) +- Add multi-column topk fuzz tests [#7898](https://github.com/apache/arrow-datafusion/pull/7898) (alamb) +- Change `FileScanConfig.table_partition_cols` from `(String, DataType)` to `Field`s [#7890](https://github.com/apache/arrow-datafusion/pull/7890) (NGA-TRAN) +- Maintain time zone in `ScalarValue::new_list` [#7899](https://github.com/apache/arrow-datafusion/pull/7899) (Dandandan) +- [MINOR]: Move joinside struct to common [#7908](https://github.com/apache/arrow-datafusion/pull/7908) (mustafasrepo) +- doc: fix ExecutionContext to SessionContext in custom-table-providers.md [#7903](https://github.com/apache/arrow-datafusion/pull/7903) (ZENOTME) +- Update arrow 48.0.0 [#7854](https://github.com/apache/arrow-datafusion/pull/7854) (tustvold) +- feat: support `Decimal256` for the `abs` function [#7904](https://github.com/apache/arrow-datafusion/pull/7904) (jonahgao) +- [MINOR] Simplify Aggregate, and Projection output_partitioning implementation [#7907](https://github.com/apache/arrow-datafusion/pull/7907) (mustafasrepo) +- Bump actions/setup-node from 3 to 4 [#7915](https://github.com/apache/arrow-datafusion/pull/7915) (dependabot[bot]) +- [Bug Fix]: Fix bug, first last reverse [#7914](https://github.com/apache/arrow-datafusion/pull/7914) (mustafasrepo) +- Minor: provide default implementation for ExecutionPlan::statistics [#7911](https://github.com/apache/arrow-datafusion/pull/7911) (alamb) +- Update substrait requirement from 0.17.0 to 0.18.0 [#7916](https://github.com/apache/arrow-datafusion/pull/7916) (dependabot[bot]) +- Minor: Remove unnecessary clone in datafusion_proto [#7921](https://github.com/apache/arrow-datafusion/pull/7921) (ongchi) +- [MINOR]: Simplify code, change requirement from PhysicalSortExpr to PhysicalSortRequirement [#7913](https://github.com/apache/arrow-datafusion/pull/7913) (mustafasrepo) +- [Minor] Move combine_join util to under equivalence.rs [#7917](https://github.com/apache/arrow-datafusion/pull/7917) (mustafasrepo) +- support scan empty projection [#7920](https://github.com/apache/arrow-datafusion/pull/7920) (haohuaijin) +- Cleanup logical optimizer rules. [#7919](https://github.com/apache/arrow-datafusion/pull/7919) (mustafasrepo) +- Parallelize Serialization of Columns within Parquet RowGroups [#7655](https://github.com/apache/arrow-datafusion/pull/7655) (devinjdangelo) +- feat: Use bloom filter when reading parquet to skip row groups [#7821](https://github.com/apache/arrow-datafusion/pull/7821) (hengfeiyang) +- fix: don't push down volatile predicates in projection [#7909](https://github.com/apache/arrow-datafusion/pull/7909) (haohuaijin) +- Add `parquet` feature flag, enabled by default, and make parquet conditional [#7745](https://github.com/apache/arrow-datafusion/pull/7745) (ongchi) +- [MINOR]: Simplify enforce_distribution, minor changes [#7924](https://github.com/apache/arrow-datafusion/pull/7924) (mustafasrepo) +- Add simple window query to sqllogictest [#7928](https://github.com/apache/arrow-datafusion/pull/7928) (Jefffrey) +- ci: upgrade node to version 20 [#7918](https://github.com/apache/arrow-datafusion/pull/7918) (crepererum) +- Change input for `to_timestamp` function to be seconds rather than nanoseconds, add `to_timestamp_nanos` [#7844](https://github.com/apache/arrow-datafusion/pull/7844) (comphead) +- Minor: Document `parquet` crate feature [#7927](https://github.com/apache/arrow-datafusion/pull/7927) (alamb) +- Minor: reduce some `#cfg(feature = "parquet")` [#7929](https://github.com/apache/arrow-datafusion/pull/7929) (alamb) +- Minor: reduce use of `#cfg(feature = "parquet")` in tests [#7930](https://github.com/apache/arrow-datafusion/pull/7930) (alamb) +- Fix CI failures on `to_timestamp()` calls [#7941](https://github.com/apache/arrow-datafusion/pull/7941) (comphead) +- minor: add a datatype casting for the updated value [#7922](https://github.com/apache/arrow-datafusion/pull/7922) (jonahgao) +- Minor:add `avro` feature in datafusion-examples to make `avro_sql` run [#7946](https://github.com/apache/arrow-datafusion/pull/7946) (haohuaijin) +- Add simple exclude all columns test to sqllogictest [#7945](https://github.com/apache/arrow-datafusion/pull/7945) (Jefffrey) +- Support Partitioning Data by Dictionary Encoded String Array Types [#7896](https://github.com/apache/arrow-datafusion/pull/7896) (devinjdangelo) +- Minor: Remove array() in array_expression [#7961](https://github.com/apache/arrow-datafusion/pull/7961) (jayzhan211) +- Minor: simplify update code [#7943](https://github.com/apache/arrow-datafusion/pull/7943) (alamb) +- Add some initial content about creating logical plans [#7952](https://github.com/apache/arrow-datafusion/pull/7952) (andygrove) +- Minor: Change from `&mut SessionContext` to `&SessionContext` in substrait [#7965](https://github.com/apache/arrow-datafusion/pull/7965) (my-vegetable-has-exploded) +- Fix crate READMEs [#7964](https://github.com/apache/arrow-datafusion/pull/7964) (Jefffrey) +- Minor: Improve `HashJoinExec` documentation [#7953](https://github.com/apache/arrow-datafusion/pull/7953) (alamb) +- chore: clean useless clone baesd on clippy [#7973](https://github.com/apache/arrow-datafusion/pull/7973) (Weijun-H) +- Add README.md to `core`, `execution` and `physical-plan` crates [#7970](https://github.com/apache/arrow-datafusion/pull/7970) (alamb) +- Move source repartitioning into `ExecutionPlan::repartition` [#7936](https://github.com/apache/arrow-datafusion/pull/7936) (alamb) +- minor: fix broken links in README.md [#7986](https://github.com/apache/arrow-datafusion/pull/7986) (jonahgao) +- Minor: Upate the `sqllogictest` crate README [#7971](https://github.com/apache/arrow-datafusion/pull/7971) (alamb) +- Improve MemoryCatalogProvider default impl block placement [#7975](https://github.com/apache/arrow-datafusion/pull/7975) (lewiszlw) +- Fix `ScalarValue` handling of NULL values for ListArray [#7969](https://github.com/apache/arrow-datafusion/pull/7969) (viirya) +- Refactor of Ordering and Prunability Traversals and States [#7985](https://github.com/apache/arrow-datafusion/pull/7985) (berkaysynnada) +- Keep output as scalar for scalar function if all inputs are scalar [#7967](https://github.com/apache/arrow-datafusion/pull/7967) (viirya) +- Fix crate READMEs for core, execution, physical-plan [#7990](https://github.com/apache/arrow-datafusion/pull/7990) (Jefffrey) +- Update sqlparser requirement from 0.38.0 to 0.39.0 [#7983](https://github.com/apache/arrow-datafusion/pull/7983) (jackwener) +- Fix panic in multiple distinct aggregates by fixing `ScalarValue::new_list` [#7989](https://github.com/apache/arrow-datafusion/pull/7989) (alamb) +- Minor: Add `MemoryReservation::consumer` getter [#8000](https://github.com/apache/arrow-datafusion/pull/8000) (milenkovicm) +- fix: generate logical plan for `UPDATE SET FROM` statement [#7984](https://github.com/apache/arrow-datafusion/pull/7984) (jonahgao) +- Create temporary files for reading or writing [#8005](https://github.com/apache/arrow-datafusion/pull/8005) (smallzhongfeng) +- Minor: fix comment on SortExec::with_fetch method [#8011](https://github.com/apache/arrow-datafusion/pull/8011) (westonpace) +- Fix: dataframe_subquery example Optimizer rule `common_sub_expression_eliminate` failed [#8016](https://github.com/apache/arrow-datafusion/pull/8016) (smallzhongfeng) +- Percent Decode URL Paths (#8009) [#8012](https://github.com/apache/arrow-datafusion/pull/8012) (tustvold) +- Minor: Extract common deps into workspace [#7982](https://github.com/apache/arrow-datafusion/pull/7982) (lewiszlw) +- minor: change some plan_err to exec_err [#7996](https://github.com/apache/arrow-datafusion/pull/7996) (waynexia) +- Minor: error on unsupported RESPECT NULLs syntax [#7998](https://github.com/apache/arrow-datafusion/pull/7998) (alamb) +- Break GroupedHashAggregateStream spill batch into smaller chunks [#8004](https://github.com/apache/arrow-datafusion/pull/8004) (milenkovicm) +- Minor: Add implementation examples to ExecutionPlan::execute [#8013](https://github.com/apache/arrow-datafusion/pull/8013) (tustvold) +- Minor: Extend wrap_into_list_array to accept multiple args [#7993](https://github.com/apache/arrow-datafusion/pull/7993) (jayzhan211) +- GroupedHashAggregateStream should register spillable consumer [#8002](https://github.com/apache/arrow-datafusion/pull/8002) (milenkovicm) +- fix: single_distinct_aggretation_to_group_by fail [#7997](https://github.com/apache/arrow-datafusion/pull/7997) (haohuaijin) +- Read only enough bytes to infer Arrow IPC file schema via stream [#7962](https://github.com/apache/arrow-datafusion/pull/7962) (Jefffrey) +- Minor: remove a strange char [#8030](https://github.com/apache/arrow-datafusion/pull/8030) (haohuaijin) +- Minor: Improve documentation for Filter Pushdown [#8023](https://github.com/apache/arrow-datafusion/pull/8023) (alamb) +- Minor: Improve `ExecutionPlan` documentation [#8019](https://github.com/apache/arrow-datafusion/pull/8019) (alamb) +- fix: clippy warnings from nightly rust 1.75 [#8025](https://github.com/apache/arrow-datafusion/pull/8025) (waynexia) +- Minor: Avoid recomputing compute_array_ndims in align_array_dimensions [#7963](https://github.com/apache/arrow-datafusion/pull/7963) (jayzhan211) +- Minor: fix doc and fmt CI check [#8037](https://github.com/apache/arrow-datafusion/pull/8037) (alamb) +- Minor: remove uncessary #cfg test [#8036](https://github.com/apache/arrow-datafusion/pull/8036) (alamb) +- Minor: Improve documentation for `PartitionStream` and `StreamingTableExec` [#8035](https://github.com/apache/arrow-datafusion/pull/8035) (alamb) +- Combine Equivalence and Ordering equivalence to simplify state [#8006](https://github.com/apache/arrow-datafusion/pull/8006) (mustafasrepo) +- Encapsulate `ProjectionMapping` as a struct [#8033](https://github.com/apache/arrow-datafusion/pull/8033) (alamb) +- Minor: Fix bugs in docs for `to_timestamp`, `to_timestamp_seconds`, ... [#8040](https://github.com/apache/arrow-datafusion/pull/8040) (alamb) +- Improve comments for `PartitionSearchMode` struct [#8047](https://github.com/apache/arrow-datafusion/pull/8047) (ozankabak) +- General approach for Array replace [#8050](https://github.com/apache/arrow-datafusion/pull/8050) (jayzhan211) +- Minor: Remove the irrelevant note from the Expression API doc [#8053](https://github.com/apache/arrow-datafusion/pull/8053) (ongchi) +- Minor: Add more documentation about Partitioning [#8022](https://github.com/apache/arrow-datafusion/pull/8022) (alamb) +- Minor: improve documentation for IsNotNull, DISTINCT, etc [#8052](https://github.com/apache/arrow-datafusion/pull/8052) (alamb) +- Prepare 33.0.0 Release [#8057](https://github.com/apache/arrow-datafusion/pull/8057) (andygrove) +- Minor: improve error message by adding types to message [#8065](https://github.com/apache/arrow-datafusion/pull/8065) (alamb) +- Minor: Remove redundant BuiltinScalarFunction::supports_zero_argument() [#8059](https://github.com/apache/arrow-datafusion/pull/8059) (2010YOUY01) +- Add example to ci [#8060](https://github.com/apache/arrow-datafusion/pull/8060) (smallzhongfeng) +- Update substrait requirement from 0.18.0 to 0.19.0 [#8076](https://github.com/apache/arrow-datafusion/pull/8076) (dependabot[bot]) +- Fix incorrect results in COUNT(\*) queries with LIMIT [#8049](https://github.com/apache/arrow-datafusion/pull/8049) (msirek) +- feat: Support determining extensions from names like `foo.parquet.snappy` as well as `foo.parquet` [#7972](https://github.com/apache/arrow-datafusion/pull/7972) (Weijun-H) +- Use FairSpillPool for TaskContext with spillable config [#8072](https://github.com/apache/arrow-datafusion/pull/8072) (viirya) +- Minor: Improve HashJoinStream docstrings [#8070](https://github.com/apache/arrow-datafusion/pull/8070) (alamb) +- Fixing broken link [#8085](https://github.com/apache/arrow-datafusion/pull/8085) (edmondop) +- fix: DataFusion suggests invalid functions [#8083](https://github.com/apache/arrow-datafusion/pull/8083) (jonahgao) +- Replace macro with function for `array_repeat` [#8071](https://github.com/apache/arrow-datafusion/pull/8071) (jayzhan211) +- Minor: remove unnecessary projection in `single_distinct_to_group_by` rule [#8061](https://github.com/apache/arrow-datafusion/pull/8061) (haohuaijin) +- minor: Remove duplicate version numbers for arrow, object_store, and parquet dependencies [#8095](https://github.com/apache/arrow-datafusion/pull/8095) (andygrove) +- fix: add encode/decode to protobuf encoding [#8089](https://github.com/apache/arrow-datafusion/pull/8089) (Syleechan) +- feat: Protobuf serde for Json file sink [#8062](https://github.com/apache/arrow-datafusion/pull/8062) (Jefffrey) +- Minor: use `Expr::alias` in a few places to make the code more concise [#8097](https://github.com/apache/arrow-datafusion/pull/8097) (alamb) +- Minor: Cleanup BuiltinScalarFunction::return_type() [#8088](https://github.com/apache/arrow-datafusion/pull/8088) (2010YOUY01) +- Update sqllogictest requirement from 0.17.0 to 0.18.0 [#8102](https://github.com/apache/arrow-datafusion/pull/8102) (dependabot[bot]) +- Projection Pushdown in PhysicalPlan [#8073](https://github.com/apache/arrow-datafusion/pull/8073) (berkaysynnada) +- Push limit into aggregation for DISTINCT ... LIMIT queries [#8038](https://github.com/apache/arrow-datafusion/pull/8038) (msirek) +- Bug-fix in Filter and Limit statistics [#8094](https://github.com/apache/arrow-datafusion/pull/8094) (berkaysynnada) +- feat: support target table alias in update statement [#8080](https://github.com/apache/arrow-datafusion/pull/8080) (jonahgao) +- Minor: Simlify downcast functions in cast.rs. [#8103](https://github.com/apache/arrow-datafusion/pull/8103) (Weijun-H) +- Fix ArrayAgg schema mismatch issue [#8055](https://github.com/apache/arrow-datafusion/pull/8055) (jayzhan211) +- Minor: Support `nulls` in `array_replace`, avoid a copy [#8054](https://github.com/apache/arrow-datafusion/pull/8054) (alamb) +- Minor: Improve the document format of JoinHashMap [#8090](https://github.com/apache/arrow-datafusion/pull/8090) (Asura7969) +- Simplify ProjectionPushdown and make it more general [#8109](https://github.com/apache/arrow-datafusion/pull/8109) (alamb) +- Minor: clean up the code regarding clippy [#8122](https://github.com/apache/arrow-datafusion/pull/8122) (Weijun-H) +- Support remaining functions in protobuf serialization, add `expr_fn` for `StructFunction` [#8100](https://github.com/apache/arrow-datafusion/pull/8100) (JacobOgle) +- Minor: Cleanup BuiltinScalarFunction's phys-expr creation [#8114](https://github.com/apache/arrow-datafusion/pull/8114) (2010YOUY01) +- rewrite `array_append/array_prepend` to remove deplicate codes [#8108](https://github.com/apache/arrow-datafusion/pull/8108) (Veeupup) +- Implementation of `array_intersect` [#8081](https://github.com/apache/arrow-datafusion/pull/8081) (Veeupup) +- Minor: fix ci break [#8136](https://github.com/apache/arrow-datafusion/pull/8136) (haohuaijin) +- Improve documentation for calculate_prune_length method in `SymmetricHashJoin` [#8125](https://github.com/apache/arrow-datafusion/pull/8125) (Asura7969) +- Minor: remove duplicated `array_replace` tests [#8066](https://github.com/apache/arrow-datafusion/pull/8066) (alamb) +- Minor: Fix temporary files created but not deleted during testing [#8115](https://github.com/apache/arrow-datafusion/pull/8115) (2010YOUY01) +- chore: remove panics in datafusion-common::scalar by making more operations return `Result` [#7901](https://github.com/apache/arrow-datafusion/pull/7901) (junjunjd) +- Fix join order for TPCH Q17 & Q18 by improving FilterExec statistics [#8126](https://github.com/apache/arrow-datafusion/pull/8126) (andygrove) +- Fix: Do not try and preserve order when there is no order to preserve in RepartitionExec [#8127](https://github.com/apache/arrow-datafusion/pull/8127) (alamb) +- feat: add column statistics into explain [#8112](https://github.com/apache/arrow-datafusion/pull/8112) (NGA-TRAN) +- Add subtrait support for `IS NULL` and `IS NOT NULL` [#8093](https://github.com/apache/arrow-datafusion/pull/8093) (tgujar) +- Combine `Expr::Wildcard` and `Wxpr::QualifiedWildcard`, add `wildcard()` expr fn [#8105](https://github.com/apache/arrow-datafusion/pull/8105) (alamb) +- docs: show creation of DFSchema [#8132](https://github.com/apache/arrow-datafusion/pull/8132) (wjones127) +- feat: support UDAF in substrait producer/consumer [#8119](https://github.com/apache/arrow-datafusion/pull/8119) (waynexia) +- Improve documentation site to make it easier to find communication on Slack/Discord [#8138](https://github.com/apache/arrow-datafusion/pull/8138) (alamb) diff --git a/dev/changelog/34.0.0.md b/dev/changelog/34.0.0.md new file mode 100644 index 000000000000..c5526f60531c --- /dev/null +++ b/dev/changelog/34.0.0.md @@ -0,0 +1,273 @@ + + +## [34.0.0](https://github.com/apache/arrow-datafusion/tree/34.0.0) (2023-12-11) + +[Full Changelog](https://github.com/apache/arrow-datafusion/compare/33.0.0...34.0.0) + +**Breaking changes:** + +- Implement `DISTINCT ON` from Postgres [#7981](https://github.com/apache/arrow-datafusion/pull/7981) (gruuya) +- Encapsulate `EquivalenceClass` into a struct [#8034](https://github.com/apache/arrow-datafusion/pull/8034) (alamb) +- Make fields of `ScalarUDF` , `AggregateUDF` and `WindowUDF` non `pub` [#8079](https://github.com/apache/arrow-datafusion/pull/8079) (alamb) +- Implement StreamTable and StreamTableProvider (#7994) [#8021](https://github.com/apache/arrow-datafusion/pull/8021) (tustvold) +- feat: make FixedSizeList scalar also an ArrayRef [#8221](https://github.com/apache/arrow-datafusion/pull/8221) (wjones127) +- Remove FileWriterMode and ListingTableInsertMode (#7994) [#8017](https://github.com/apache/arrow-datafusion/pull/8017) (tustvold) +- Refactor: Unify `Expr::ScalarFunction` and `Expr::ScalarUDF`, introduce unresolved functions by name [#8258](https://github.com/apache/arrow-datafusion/pull/8258) (2010YOUY01) +- Refactor aggregate function handling [#8358](https://github.com/apache/arrow-datafusion/pull/8358) (Weijun-H) +- Move `PartitionSearchMode` into datafusion_physical_plan, rename to `InputOrderMode` [#8364](https://github.com/apache/arrow-datafusion/pull/8364) (alamb) +- Split `EmptyExec` into `PlaceholderRowExec` [#8446](https://github.com/apache/arrow-datafusion/pull/8446) (razeghi71) + +**Implemented enhancements:** + +- feat: show statistics in explain verbose [#8113](https://github.com/apache/arrow-datafusion/pull/8113) (NGA-TRAN) +- feat:implement postgres style 'overlay' string function [#8117](https://github.com/apache/arrow-datafusion/pull/8117) (Syleechan) +- feat: fill missing values with NULLs while inserting [#8146](https://github.com/apache/arrow-datafusion/pull/8146) (jonahgao) +- feat: to_array_of_size for ScalarValue::FixedSizeList [#8225](https://github.com/apache/arrow-datafusion/pull/8225) (wjones127) +- feat:implement calcite style 'levenshtein' string function [#8168](https://github.com/apache/arrow-datafusion/pull/8168) (Syleechan) +- feat: roundtrip FixedSizeList Scalar to protobuf [#8239](https://github.com/apache/arrow-datafusion/pull/8239) (wjones127) +- feat: impl the basic `string_agg` function [#8148](https://github.com/apache/arrow-datafusion/pull/8148) (haohuaijin) +- feat: support simplifying BinaryExpr with arbitrary guarantees in GuaranteeRewriter [#8256](https://github.com/apache/arrow-datafusion/pull/8256) (wjones127) +- feat: support customizing column default values for inserting [#8283](https://github.com/apache/arrow-datafusion/pull/8283) (jonahgao) +- feat:implement sql style 'substr_index' string function [#8272](https://github.com/apache/arrow-datafusion/pull/8272) (Syleechan) +- feat:implement sql style 'find_in_set' string function [#8328](https://github.com/apache/arrow-datafusion/pull/8328) (Syleechan) +- feat: support `LargeList` in `array_empty` [#8321](https://github.com/apache/arrow-datafusion/pull/8321) (Weijun-H) +- feat: support `LargeList` in `make_array` and `array_length` [#8121](https://github.com/apache/arrow-datafusion/pull/8121) (Weijun-H) +- feat: ScalarValue from String [#8411](https://github.com/apache/arrow-datafusion/pull/8411) (QuenKar) +- feat: support `LargeList` for `array_has`, `array_has_all` and `array_has_any` [#8322](https://github.com/apache/arrow-datafusion/pull/8322) (Weijun-H) +- feat: customize column default values for external tables [#8415](https://github.com/apache/arrow-datafusion/pull/8415) (jonahgao) +- feat: Support `array_sort`(`list_sort`) [#8279](https://github.com/apache/arrow-datafusion/pull/8279) (Asura7969) +- feat: support `InterleaveExecNode` in the proto [#8460](https://github.com/apache/arrow-datafusion/pull/8460) (liukun4515) +- feat: improve string statistics display in datafusion-cli `parquet_metadata` function [#8535](https://github.com/apache/arrow-datafusion/pull/8535) (asimsedhain) + +**Fixed bugs:** + +- fix: Timestamp with timezone not considered `join on` [#8150](https://github.com/apache/arrow-datafusion/pull/8150) (ACking-you) +- fix: wrong result of range function [#8313](https://github.com/apache/arrow-datafusion/pull/8313) (smallzhongfeng) +- fix: make `ntile` work in some corner cases [#8371](https://github.com/apache/arrow-datafusion/pull/8371) (haohuaijin) +- fix: Changed labeler.yml to latest format [#8431](https://github.com/apache/arrow-datafusion/pull/8431) (viirya) +- fix: Literal in `ORDER BY` window definition should not be an ordinal referring to relation column [#8419](https://github.com/apache/arrow-datafusion/pull/8419) (viirya) +- fix: ORDER BY window definition should work on null literal [#8444](https://github.com/apache/arrow-datafusion/pull/8444) (viirya) +- fix: RANGE frame for corner cases with empty ORDER BY clause should be treated as constant sort [#8445](https://github.com/apache/arrow-datafusion/pull/8445) (viirya) +- fix: don't unifies projection if expr is non-trival [#8454](https://github.com/apache/arrow-datafusion/pull/8454) (haohuaijin) +- fix: support uppercase when parsing `Interval` [#8478](https://github.com/apache/arrow-datafusion/pull/8478) (QuenKar) +- fix: incorrect set preserve_partitioning in SortExec [#8485](https://github.com/apache/arrow-datafusion/pull/8485) (haohuaijin) +- fix: Pull stats in `IdentVisitor`/`GraphvizVisitor` only when requested [#8514](https://github.com/apache/arrow-datafusion/pull/8514) (vrongmeal) +- fix: volatile expressions should not be target of common subexpt elimination [#8520](https://github.com/apache/arrow-datafusion/pull/8520) (viirya) + +**Documentation updates:** + +- Library Guide: Add Using the DataFrame API [#8319](https://github.com/apache/arrow-datafusion/pull/8319) (Veeupup) +- Minor: Add installation link to README.md [#8389](https://github.com/apache/arrow-datafusion/pull/8389) (Weijun-H) +- Prepare version 34.0.0 [#8508](https://github.com/apache/arrow-datafusion/pull/8508) (andygrove) + +**Merged pull requests:** + +- Fix typo in partitioning.rs [#8134](https://github.com/apache/arrow-datafusion/pull/8134) (lewiszlw) +- Implement `DISTINCT ON` from Postgres [#7981](https://github.com/apache/arrow-datafusion/pull/7981) (gruuya) +- Prepare 33.0.0-rc2 [#8144](https://github.com/apache/arrow-datafusion/pull/8144) (andygrove) +- Avoid concat in `array_append` [#8137](https://github.com/apache/arrow-datafusion/pull/8137) (jayzhan211) +- Replace macro with function for array_remove [#8106](https://github.com/apache/arrow-datafusion/pull/8106) (jayzhan211) +- Implement `array_union` [#7897](https://github.com/apache/arrow-datafusion/pull/7897) (edmondop) +- Minor: Document `ExecutionPlan::equivalence_properties` more thoroughly [#8128](https://github.com/apache/arrow-datafusion/pull/8128) (alamb) +- feat: show statistics in explain verbose [#8113](https://github.com/apache/arrow-datafusion/pull/8113) (NGA-TRAN) +- feat:implement postgres style 'overlay' string function [#8117](https://github.com/apache/arrow-datafusion/pull/8117) (Syleechan) +- Minor: Encapsulate `LeftJoinData` into a struct (rather than anonymous enum) and add comments [#8153](https://github.com/apache/arrow-datafusion/pull/8153) (alamb) +- Update sqllogictest requirement from 0.18.0 to 0.19.0 [#8163](https://github.com/apache/arrow-datafusion/pull/8163) (dependabot[bot]) +- feat: fill missing values with NULLs while inserting [#8146](https://github.com/apache/arrow-datafusion/pull/8146) (jonahgao) +- Introduce return type for aggregate sum [#8141](https://github.com/apache/arrow-datafusion/pull/8141) (jayzhan211) +- implement range/generate_series func [#8140](https://github.com/apache/arrow-datafusion/pull/8140) (Veeupup) +- Encapsulate `EquivalenceClass` into a struct [#8034](https://github.com/apache/arrow-datafusion/pull/8034) (alamb) +- Revert "Minor: remove unnecessary projection in `single_distinct_to_g… [#8176](https://github.com/apache/arrow-datafusion/pull/8176) (NGA-TRAN) +- Preserve all of the valid orderings during merging. [#8169](https://github.com/apache/arrow-datafusion/pull/8169) (mustafasrepo) +- Make fields of `ScalarUDF` , `AggregateUDF` and `WindowUDF` non `pub` [#8079](https://github.com/apache/arrow-datafusion/pull/8079) (alamb) +- Fix logical conflicts [#8187](https://github.com/apache/arrow-datafusion/pull/8187) (tustvold) +- Minor: Update JoinHashMap comment example to make it clearer [#8154](https://github.com/apache/arrow-datafusion/pull/8154) (alamb) +- Implement StreamTable and StreamTableProvider (#7994) [#8021](https://github.com/apache/arrow-datafusion/pull/8021) (tustvold) +- [MINOR]: Remove unused Results [#8189](https://github.com/apache/arrow-datafusion/pull/8189) (mustafasrepo) +- Minor: clean up the code based on clippy [#8179](https://github.com/apache/arrow-datafusion/pull/8179) (Weijun-H) +- Minor: simplify filter statistics code [#8174](https://github.com/apache/arrow-datafusion/pull/8174) (alamb) +- Replace macro with function for `array_position` and `array_positions` [#8170](https://github.com/apache/arrow-datafusion/pull/8170) (jayzhan211) +- Add Library Guide for User Defined Functions: Window/Aggregate [#8171](https://github.com/apache/arrow-datafusion/pull/8171) (Veeupup) +- Add more stream docs [#8192](https://github.com/apache/arrow-datafusion/pull/8192) (tustvold) +- Implement func `array_pop_front` [#8142](https://github.com/apache/arrow-datafusion/pull/8142) (Veeupup) +- Moving arrow_files SQL tests to sqllogictest [#8217](https://github.com/apache/arrow-datafusion/pull/8217) (edmondop) +- fix regression in the use of name in ProjectionPushdown [#8219](https://github.com/apache/arrow-datafusion/pull/8219) (alamb) +- [MINOR]: Fix column indices in the planning tests [#8191](https://github.com/apache/arrow-datafusion/pull/8191) (mustafasrepo) +- Remove unnecessary reassignment [#8232](https://github.com/apache/arrow-datafusion/pull/8232) (qrilka) +- Update itertools requirement from 0.11 to 0.12 [#8233](https://github.com/apache/arrow-datafusion/pull/8233) (crepererum) +- Port tests in subqueries.rs to sqllogictest [#8231](https://github.com/apache/arrow-datafusion/pull/8231) (PsiACE) +- feat: make FixedSizeList scalar also an ArrayRef [#8221](https://github.com/apache/arrow-datafusion/pull/8221) (wjones127) +- Add versions to datafusion dependencies [#8238](https://github.com/apache/arrow-datafusion/pull/8238) (andygrove) +- feat: to_array_of_size for ScalarValue::FixedSizeList [#8225](https://github.com/apache/arrow-datafusion/pull/8225) (wjones127) +- feat:implement calcite style 'levenshtein' string function [#8168](https://github.com/apache/arrow-datafusion/pull/8168) (Syleechan) +- feat: roundtrip FixedSizeList Scalar to protobuf [#8239](https://github.com/apache/arrow-datafusion/pull/8239) (wjones127) +- Update prost-build requirement from =0.12.1 to =0.12.2 [#8244](https://github.com/apache/arrow-datafusion/pull/8244) (dependabot[bot]) +- Minor: Port tests in `displayable.rs` to sqllogictest [#8246](https://github.com/apache/arrow-datafusion/pull/8246) (Weijun-H) +- Minor: add `with_estimated_selectivity ` to Precision [#8177](https://github.com/apache/arrow-datafusion/pull/8177) (alamb) +- fix: Timestamp with timezone not considered `join on` [#8150](https://github.com/apache/arrow-datafusion/pull/8150) (ACking-you) +- Replace macro in array_array to remove duplicate codes [#8252](https://github.com/apache/arrow-datafusion/pull/8252) (Veeupup) +- Port tests in projection.rs to sqllogictest [#8240](https://github.com/apache/arrow-datafusion/pull/8240) (PsiACE) +- Introduce `array_except` function [#8135](https://github.com/apache/arrow-datafusion/pull/8135) (jayzhan211) +- Port tests in `describe.rs` to sqllogictest [#8242](https://github.com/apache/arrow-datafusion/pull/8242) (Asura7969) +- Remove FileWriterMode and ListingTableInsertMode (#7994) [#8017](https://github.com/apache/arrow-datafusion/pull/8017) (tustvold) +- Minor: clean up the code based on Clippy [#8257](https://github.com/apache/arrow-datafusion/pull/8257) (Weijun-H) +- Update arrow 49.0.0 and object_store 0.8.0 [#8029](https://github.com/apache/arrow-datafusion/pull/8029) (tustvold) +- feat: impl the basic `string_agg` function [#8148](https://github.com/apache/arrow-datafusion/pull/8148) (haohuaijin) +- Minor: Make schema of grouping set columns nullable [#8248](https://github.com/apache/arrow-datafusion/pull/8248) (markusa380) +- feat: support simplifying BinaryExpr with arbitrary guarantees in GuaranteeRewriter [#8256](https://github.com/apache/arrow-datafusion/pull/8256) (wjones127) +- Making stream joins extensible: A new Trait implementation for SHJ [#8234](https://github.com/apache/arrow-datafusion/pull/8234) (metesynnada) +- Don't Canonicalize Filesystem Paths in ListingTableUrl / support new external tables for files that do not (yet) exist [#8014](https://github.com/apache/arrow-datafusion/pull/8014) (tustvold) +- Minor: Add sql level test for inserting into non-existent directory [#8278](https://github.com/apache/arrow-datafusion/pull/8278) (alamb) +- Replace `array_has/array_has_all/array_has_any` macro to remove duplicate code [#8263](https://github.com/apache/arrow-datafusion/pull/8263) (Veeupup) +- Fix bug in field level metadata matching code [#8286](https://github.com/apache/arrow-datafusion/pull/8286) (alamb) +- Refactor Interval Arithmetic Updates [#8276](https://github.com/apache/arrow-datafusion/pull/8276) (berkaysynnada) +- [MINOR]: Remove unecessary orderings from the final plan [#8289](https://github.com/apache/arrow-datafusion/pull/8289) (mustafasrepo) +- consistent logical & physical `NTILE` return types [#8270](https://github.com/apache/arrow-datafusion/pull/8270) (korowa) +- make `array_union`/`array_except`/`array_intersect` handle empty/null arrays rightly [#8269](https://github.com/apache/arrow-datafusion/pull/8269) (Veeupup) +- improve file path validation when reading parquet [#8267](https://github.com/apache/arrow-datafusion/pull/8267) (Weijun-H) +- [Benchmarks] Make `partitions` default to number of cores instead of 2 [#8292](https://github.com/apache/arrow-datafusion/pull/8292) (andygrove) +- Update prost-build requirement from =0.12.2 to =0.12.3 [#8298](https://github.com/apache/arrow-datafusion/pull/8298) (dependabot[bot]) +- Fix Display for List [#8261](https://github.com/apache/arrow-datafusion/pull/8261) (jayzhan211) +- feat: support customizing column default values for inserting [#8283](https://github.com/apache/arrow-datafusion/pull/8283) (jonahgao) +- support `LargeList` for `arrow_cast`, support `ScalarValue::LargeList` [#8290](https://github.com/apache/arrow-datafusion/pull/8290) (Weijun-H) +- Minor: remove useless clone based on Clippy [#8300](https://github.com/apache/arrow-datafusion/pull/8300) (Weijun-H) +- Calculate ordering equivalence for expressions (rather than just columns) [#8281](https://github.com/apache/arrow-datafusion/pull/8281) (mustafasrepo) +- Fix sqllogictests link in contributor-guide/index.md [#8314](https://github.com/apache/arrow-datafusion/pull/8314) (qrilka) +- Refactor: Unify `Expr::ScalarFunction` and `Expr::ScalarUDF`, introduce unresolved functions by name [#8258](https://github.com/apache/arrow-datafusion/pull/8258) (2010YOUY01) +- Support no distinct aggregate sum/min/max in `single_distinct_to_group_by` rule [#8266](https://github.com/apache/arrow-datafusion/pull/8266) (haohuaijin) +- feat:implement sql style 'substr_index' string function [#8272](https://github.com/apache/arrow-datafusion/pull/8272) (Syleechan) +- Fixing issues with for timestamp literals [#8193](https://github.com/apache/arrow-datafusion/pull/8193) (comphead) +- Projection Pushdown over StreamingTableExec [#8299](https://github.com/apache/arrow-datafusion/pull/8299) (berkaysynnada) +- minor: fix documentation [#8323](https://github.com/apache/arrow-datafusion/pull/8323) (comphead) +- fix: wrong result of range function [#8313](https://github.com/apache/arrow-datafusion/pull/8313) (smallzhongfeng) +- Minor: rename parquet.rs to parquet/mod.rs [#8301](https://github.com/apache/arrow-datafusion/pull/8301) (alamb) +- refactor: output ordering [#8304](https://github.com/apache/arrow-datafusion/pull/8304) (QuenKar) +- Update substrait requirement from 0.19.0 to 0.20.0 [#8339](https://github.com/apache/arrow-datafusion/pull/8339) (dependabot[bot]) +- Port tests in `aggregates.rs` to sqllogictest [#8316](https://github.com/apache/arrow-datafusion/pull/8316) (edmondop) +- Library Guide: Add Using the DataFrame API [#8319](https://github.com/apache/arrow-datafusion/pull/8319) (Veeupup) +- Port tests in limit.rs to sqllogictest [#8315](https://github.com/apache/arrow-datafusion/pull/8315) (zhangxffff) +- move array function unit_tests to sqllogictest [#8332](https://github.com/apache/arrow-datafusion/pull/8332) (Veeupup) +- NTH_VALUE reverse support [#8327](https://github.com/apache/arrow-datafusion/pull/8327) (mustafasrepo) +- Optimize Projections during Logical Plan [#8340](https://github.com/apache/arrow-datafusion/pull/8340) (mustafasrepo) +- [MINOR]: Move merge projections tests to under optimize projections [#8352](https://github.com/apache/arrow-datafusion/pull/8352) (mustafasrepo) +- Add `quote` and `escape` attributes to create csv external table [#8351](https://github.com/apache/arrow-datafusion/pull/8351) (Asura7969) +- Minor: Add DataFrame test [#8341](https://github.com/apache/arrow-datafusion/pull/8341) (alamb) +- Minor: clean up the code based on Clippy [#8359](https://github.com/apache/arrow-datafusion/pull/8359) (Weijun-H) +- Minor: Make it easier to work with Expr::ScalarFunction [#8350](https://github.com/apache/arrow-datafusion/pull/8350) (alamb) +- Minor: Move some datafusion-optimizer::utils down to datafusion-expr::utils [#8354](https://github.com/apache/arrow-datafusion/pull/8354) (Jesse-Bakker) +- Minor: Make `BuiltInScalarFunction::alias` a method [#8349](https://github.com/apache/arrow-datafusion/pull/8349) (alamb) +- Extract parquet statistics to its own module, add tests [#8294](https://github.com/apache/arrow-datafusion/pull/8294) (alamb) +- feat:implement sql style 'find_in_set' string function [#8328](https://github.com/apache/arrow-datafusion/pull/8328) (Syleechan) +- Support LargeUtf8 to Temporal Coercion [#8357](https://github.com/apache/arrow-datafusion/pull/8357) (jayzhan211) +- Refactor aggregate function handling [#8358](https://github.com/apache/arrow-datafusion/pull/8358) (Weijun-H) +- Implement Aliases for ScalarUDF [#8360](https://github.com/apache/arrow-datafusion/pull/8360) (Veeupup) +- Minor: Remove unnecessary name field in `ScalarFunctionDefintion` [#8365](https://github.com/apache/arrow-datafusion/pull/8365) (alamb) +- feat: support `LargeList` in `array_empty` [#8321](https://github.com/apache/arrow-datafusion/pull/8321) (Weijun-H) +- Double type argument for to_timestamp function [#8159](https://github.com/apache/arrow-datafusion/pull/8159) (spaydar) +- Support User Defined Table Function [#8306](https://github.com/apache/arrow-datafusion/pull/8306) (Veeupup) +- Document timestamp input limits [#8369](https://github.com/apache/arrow-datafusion/pull/8369) (comphead) +- fix: make `ntile` work in some corner cases [#8371](https://github.com/apache/arrow-datafusion/pull/8371) (haohuaijin) +- Minor: Refactor array_union function to use a generic union_arrays function [#8381](https://github.com/apache/arrow-datafusion/pull/8381) (Weijun-H) +- Minor: Refactor function argument handling in `ScalarFunctionDefinition` [#8387](https://github.com/apache/arrow-datafusion/pull/8387) (Weijun-H) +- Materialize dictionaries in group keys [#8291](https://github.com/apache/arrow-datafusion/pull/8291) (qrilka) +- Rewrite `array_ndims` to fix List(Null) handling [#8320](https://github.com/apache/arrow-datafusion/pull/8320) (jayzhan211) +- Docs: Improve the documentation on `ScalarValue` [#8378](https://github.com/apache/arrow-datafusion/pull/8378) (alamb) +- Avoid concat for `array_replace` [#8337](https://github.com/apache/arrow-datafusion/pull/8337) (jayzhan211) +- add a summary table to benchmark compare output [#8399](https://github.com/apache/arrow-datafusion/pull/8399) (razeghi71) +- Refactors on TreeNode Implementations [#8395](https://github.com/apache/arrow-datafusion/pull/8395) (berkaysynnada) +- feat: support `LargeList` in `make_array` and `array_length` [#8121](https://github.com/apache/arrow-datafusion/pull/8121) (Weijun-H) +- remove `unalias` TableScan filters when create Physical Filter [#8404](https://github.com/apache/arrow-datafusion/pull/8404) (jackwener) +- Update custom-table-providers.md [#8409](https://github.com/apache/arrow-datafusion/pull/8409) (nickpoorman) +- fix transforming `LogicalPlan::Explain` use `TreeNode::transform` fails [#8400](https://github.com/apache/arrow-datafusion/pull/8400) (haohuaijin) +- Docs: Fix `array_except` documentation example error [#8407](https://github.com/apache/arrow-datafusion/pull/8407) (Asura7969) +- Support named query parameters [#8384](https://github.com/apache/arrow-datafusion/pull/8384) (Asura7969) +- Minor: Add installation link to README.md [#8389](https://github.com/apache/arrow-datafusion/pull/8389) (Weijun-H) +- Update code comment for the cases of regularized RANGE frame and add tests for ORDER BY cases with RANGE frame [#8410](https://github.com/apache/arrow-datafusion/pull/8410) (viirya) +- Minor: Add example with parameters to LogicalPlan [#8418](https://github.com/apache/arrow-datafusion/pull/8418) (alamb) +- Minor: Improve `PruningPredicate` documentation [#8394](https://github.com/apache/arrow-datafusion/pull/8394) (alamb) +- feat: ScalarValue from String [#8411](https://github.com/apache/arrow-datafusion/pull/8411) (QuenKar) +- Bump actions/labeler from 4.3.0 to 5.0.0 [#8422](https://github.com/apache/arrow-datafusion/pull/8422) (dependabot[bot]) +- Update sqlparser requirement from 0.39.0 to 0.40.0 [#8338](https://github.com/apache/arrow-datafusion/pull/8338) (dependabot[bot]) +- feat: support `LargeList` for `array_has`, `array_has_all` and `array_has_any` [#8322](https://github.com/apache/arrow-datafusion/pull/8322) (Weijun-H) +- Union `schema` can't be a subset of the child schema [#8408](https://github.com/apache/arrow-datafusion/pull/8408) (jackwener) +- Move `PartitionSearchMode` into datafusion_physical_plan, rename to `InputOrderMode` [#8364](https://github.com/apache/arrow-datafusion/pull/8364) (alamb) +- Make filter selectivity for statistics configurable [#8243](https://github.com/apache/arrow-datafusion/pull/8243) (edmondop) +- fix: Changed labeler.yml to latest format [#8431](https://github.com/apache/arrow-datafusion/pull/8431) (viirya) +- Minor: Use `ScalarValue::from` impl for strings [#8429](https://github.com/apache/arrow-datafusion/pull/8429) (alamb) +- Support crossjoin in substrait. [#8427](https://github.com/apache/arrow-datafusion/pull/8427) (my-vegetable-has-exploded) +- Fix ambiguous reference when aliasing in combination with `ORDER BY` [#8425](https://github.com/apache/arrow-datafusion/pull/8425) (Asura7969) +- Minor: convert marcro `list-slice` and `slice` to function [#8424](https://github.com/apache/arrow-datafusion/pull/8424) (Weijun-H) +- Remove macro in iter_to_array for List [#8414](https://github.com/apache/arrow-datafusion/pull/8414) (jayzhan211) +- fix: Literal in `ORDER BY` window definition should not be an ordinal referring to relation column [#8419](https://github.com/apache/arrow-datafusion/pull/8419) (viirya) +- feat: customize column default values for external tables [#8415](https://github.com/apache/arrow-datafusion/pull/8415) (jonahgao) +- feat: Support `array_sort`(`list_sort`) [#8279](https://github.com/apache/arrow-datafusion/pull/8279) (Asura7969) +- Bugfix: Remove df-cli specific SQL statment options before executing with DataFusion [#8426](https://github.com/apache/arrow-datafusion/pull/8426) (devinjdangelo) +- Detect when filters on unique constraints make subqueries scalar [#8312](https://github.com/apache/arrow-datafusion/pull/8312) (Jesse-Bakker) +- Add alias check to optimize projections merge [#8438](https://github.com/apache/arrow-datafusion/pull/8438) (mustafasrepo) +- Fix PartialOrd for ScalarValue::List/FixSizeList/LargeList [#8253](https://github.com/apache/arrow-datafusion/pull/8253) (jayzhan211) +- Support parquet_metadata for datafusion-cli [#8413](https://github.com/apache/arrow-datafusion/pull/8413) (Veeupup) +- Fix bug in optimizing a nested count [#8459](https://github.com/apache/arrow-datafusion/pull/8459) (Dandandan) +- Bump actions/setup-python from 4 to 5 [#8449](https://github.com/apache/arrow-datafusion/pull/8449) (dependabot[bot]) +- fix: ORDER BY window definition should work on null literal [#8444](https://github.com/apache/arrow-datafusion/pull/8444) (viirya) +- flx clippy warnings [#8455](https://github.com/apache/arrow-datafusion/pull/8455) (waynexia) +- fix: RANGE frame for corner cases with empty ORDER BY clause should be treated as constant sort [#8445](https://github.com/apache/arrow-datafusion/pull/8445) (viirya) +- Preserve `dict_id` on `Field` during serde roundtrip [#8457](https://github.com/apache/arrow-datafusion/pull/8457) (avantgardnerio) +- feat: support `InterleaveExecNode` in the proto [#8460](https://github.com/apache/arrow-datafusion/pull/8460) (liukun4515) +- [BUG FIX]: Proper Empty Batch handling in window execution [#8466](https://github.com/apache/arrow-datafusion/pull/8466) (mustafasrepo) +- Minor: update `cast` [#8458](https://github.com/apache/arrow-datafusion/pull/8458) (Weijun-H) +- fix: don't unifies projection if expr is non-trival [#8454](https://github.com/apache/arrow-datafusion/pull/8454) (haohuaijin) +- Minor: Add new bloom filter predicate tests [#8433](https://github.com/apache/arrow-datafusion/pull/8433) (alamb) +- Add PRIMARY KEY Aggregate support to dataframe API [#8356](https://github.com/apache/arrow-datafusion/pull/8356) (mustafasrepo) +- Minor: refactor `data_trunc` to reduce duplicated code [#8430](https://github.com/apache/arrow-datafusion/pull/8430) (Weijun-H) +- Support array_distinct function. [#8268](https://github.com/apache/arrow-datafusion/pull/8268) (my-vegetable-has-exploded) +- Add primary key support to stream table [#8467](https://github.com/apache/arrow-datafusion/pull/8467) (mustafasrepo) +- Add `evaluate_demo` and `range_analysis_demo` to Expr examples [#8377](https://github.com/apache/arrow-datafusion/pull/8377) (alamb) +- Minor: fix function name typo [#8473](https://github.com/apache/arrow-datafusion/pull/8473) (Weijun-H) +- Minor: Fix comment typo in table.rs: s/indentical/identical/ [#8469](https://github.com/apache/arrow-datafusion/pull/8469) (KeunwooLee-at) +- Remove `define_array_slice` and reuse `array_slice` for `array_pop_front/back` [#8401](https://github.com/apache/arrow-datafusion/pull/8401) (jayzhan211) +- Minor: refactor `trim` to clean up duplicated code [#8434](https://github.com/apache/arrow-datafusion/pull/8434) (Weijun-H) +- Split `EmptyExec` into `PlaceholderRowExec` [#8446](https://github.com/apache/arrow-datafusion/pull/8446) (razeghi71) +- Enable non-uniform field type for structs created in DataFusion [#8463](https://github.com/apache/arrow-datafusion/pull/8463) (dlovell) +- Minor: Add multi ordering test for array agg order [#8439](https://github.com/apache/arrow-datafusion/pull/8439) (jayzhan211) +- Sort filenames when reading parquet to ensure consistent schema [#6629](https://github.com/apache/arrow-datafusion/pull/6629) (thomas-k-cameron) +- Minor: Improve comments in EnforceDistribution tests [#8474](https://github.com/apache/arrow-datafusion/pull/8474) (alamb) +- fix: support uppercase when parsing `Interval` [#8478](https://github.com/apache/arrow-datafusion/pull/8478) (QuenKar) +- Better Equivalence (ordering and exact equivalence) Propagation through ProjectionExec [#8484](https://github.com/apache/arrow-datafusion/pull/8484) (mustafasrepo) +- Add `today` alias for `current_date` [#8423](https://github.com/apache/arrow-datafusion/pull/8423) (smallzhongfeng) +- Minor: remove useless clone in `array_expression` [#8495](https://github.com/apache/arrow-datafusion/pull/8495) (Weijun-H) +- fix: incorrect set preserve_partitioning in SortExec [#8485](https://github.com/apache/arrow-datafusion/pull/8485) (haohuaijin) +- Explicitly mark parquet for tests in datafusion-common [#8497](https://github.com/apache/arrow-datafusion/pull/8497) (Dennis40816) +- Minor/Doc: Clarify DataFrame::write_table Documentation [#8519](https://github.com/apache/arrow-datafusion/pull/8519) (devinjdangelo) +- fix: Pull stats in `IdentVisitor`/`GraphvizVisitor` only when requested [#8514](https://github.com/apache/arrow-datafusion/pull/8514) (vrongmeal) +- Change display of RepartitionExec from SortPreservingRepartitionExec to RepartitionExec preserve_order=true [#8521](https://github.com/apache/arrow-datafusion/pull/8521) (JacobOgle) +- Fix `DataFrame::cache` errors with `Plan("Mismatch between schema and batches")` [#8510](https://github.com/apache/arrow-datafusion/pull/8510) (Asura7969) +- Minor: update pbjson_dependency [#8470](https://github.com/apache/arrow-datafusion/pull/8470) (alamb) +- Minor: Update prost-derive dependency [#8471](https://github.com/apache/arrow-datafusion/pull/8471) (alamb) +- Minor/Doc: Add DataFrame::write_table to DataFrame user guide [#8527](https://github.com/apache/arrow-datafusion/pull/8527) (devinjdangelo) +- Minor: Add repartition_file.slt end to end test for repartitioning files, and supporting tweaks [#8505](https://github.com/apache/arrow-datafusion/pull/8505) (alamb) +- Prepare version 34.0.0 [#8508](https://github.com/apache/arrow-datafusion/pull/8508) (andygrove) +- refactor: use ExprBuilder to consume substrait expr and use macro to generate error [#8515](https://github.com/apache/arrow-datafusion/pull/8515) (waynexia) +- [MINOR]: Make some slt tests deterministic [#8525](https://github.com/apache/arrow-datafusion/pull/8525) (mustafasrepo) +- fix: volatile expressions should not be target of common subexpt elimination [#8520](https://github.com/apache/arrow-datafusion/pull/8520) (viirya) +- Minor: Add LakeSoul to the list of Known Users [#8536](https://github.com/apache/arrow-datafusion/pull/8536) (xuchen-plus) +- Fix regression with Incorrect results when reading parquet files with different schemas and statistics [#8533](https://github.com/apache/arrow-datafusion/pull/8533) (alamb) +- feat: improve string statistics display in datafusion-cli `parquet_metadata` function [#8535](https://github.com/apache/arrow-datafusion/pull/8535) (asimsedhain) +- Defer file creation to write [#8539](https://github.com/apache/arrow-datafusion/pull/8539) (tustvold) +- Minor: Improve error handling in sqllogictest runner [#8544](https://github.com/apache/arrow-datafusion/pull/8544) (alamb) diff --git a/dev/release/README.md b/dev/release/README.md index b44259ad560b..53487678aa69 100644 --- a/dev/release/README.md +++ b/dev/release/README.md @@ -82,7 +82,7 @@ You will need a GitHub Personal Access Token for the following steps. Follow [these instructions](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token) to generate one if you do not already have one. -The changelog is generated using a Python script. There is a depency on `PyGitHub`, which can be installed using pip: +The changelog is generated using a Python script. There is a dependency on `PyGitHub`, which can be installed using pip: ```bash pip3 install PyGitHub diff --git a/dev/release/generate-changelog.py b/dev/release/generate-changelog.py index ff9e8d4754b2..f419bdb3a1ac 100755 --- a/dev/release/generate-changelog.py +++ b/dev/release/generate-changelog.py @@ -57,6 +57,7 @@ def generate_changelog(repo, repo_name, tag1, tag2): bugs = [] docs = [] enhancements = [] + performance = [] # categorize the pull requests based on GitHub labels print("Categorizing pull requests", file=sys.stderr) @@ -79,6 +80,8 @@ def generate_changelog(repo, repo_name, tag1, tag2): breaking.append((pull, commit)) elif 'bug' in labels or cc_type == 'fix': bugs.append((pull, commit)) + elif 'performance' in labels or cc_type == 'perf': + performance.append((pull, commit)) elif 'enhancement' in labels or cc_type == 'feat': enhancements.append((pull, commit)) elif 'documentation' in labels or cc_type == 'docs': @@ -87,6 +90,7 @@ def generate_changelog(repo, repo_name, tag1, tag2): # produce the changelog content print("Generating changelog content", file=sys.stderr) print_pulls(repo_name, "Breaking changes", breaking) + print_pulls(repo_name, "Performance related", performance) print_pulls(repo_name, "Implemented enhancements", enhancements) print_pulls(repo_name, "Fixed bugs", bugs) print_pulls(repo_name, "Documentation updates", docs) diff --git a/dev/update_datafusion_versions.py b/dev/update_datafusion_versions.py index d9433915f7e2..19701b813671 100755 --- a/dev/update_datafusion_versions.py +++ b/dev/update_datafusion_versions.py @@ -35,12 +35,15 @@ 'datafusion-execution': 'datafusion/execution/Cargo.toml', 'datafusion-optimizer': 'datafusion/optimizer/Cargo.toml', 'datafusion-physical-expr': 'datafusion/physical-expr/Cargo.toml', + 'datafusion-physical-plan': 'datafusion/physical-plan/Cargo.toml', 'datafusion-proto': 'datafusion/proto/Cargo.toml', 'datafusion-substrait': 'datafusion/substrait/Cargo.toml', 'datafusion-sql': 'datafusion/sql/Cargo.toml', 'datafusion-sqllogictest': 'datafusion/sqllogictest/Cargo.toml', + 'datafusion-wasmtest': 'datafusion/wasmtest/Cargo.toml', 'datafusion-benchmarks': 'benchmarks/Cargo.toml', 'datafusion-examples': 'datafusion-examples/Cargo.toml', + 'datafusion-docs': 'docs/Cargo.toml', } def update_workspace_version(new_version: str): diff --git a/docs/Cargo.toml b/docs/Cargo.toml new file mode 100644 index 000000000000..813335e30f77 --- /dev/null +++ b/docs/Cargo.toml @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion-docs-tests" +description = "DataFusion Documentation Tests" +publish = false +version = { workspace = true } +edition = { workspace = true } +readme = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +license = { workspace = true } +authors = { workspace = true } +rust-version = "1.70" + +[dependencies] +datafusion = { path = "../datafusion/core", version = "34.0.0", default-features = false } diff --git a/docs/logos/DataFUSION-Logo-Dark.svg b/docs/logos/DataFUSION-Logo-Dark.svg new file mode 100644 index 000000000000..e16f244430e6 --- /dev/null +++ b/docs/logos/DataFUSION-Logo-Dark.svg @@ -0,0 +1 @@ +DataFUSION-Logo-Dark \ No newline at end of file diff --git a/docs/logos/DataFUSION-Logo-Dark@2x.png b/docs/logos/DataFUSION-Logo-Dark@2x.png new file mode 100644 index 000000000000..cc60f12a0e4f Binary files /dev/null and b/docs/logos/DataFUSION-Logo-Dark@2x.png differ diff --git a/docs/logos/DataFUSION-Logo-Dark@4x.png b/docs/logos/DataFUSION-Logo-Dark@4x.png new file mode 100644 index 000000000000..0503c216ac84 Binary files /dev/null and b/docs/logos/DataFUSION-Logo-Dark@4x.png differ diff --git a/docs/logos/DataFUSION-Logo-Light.svg b/docs/logos/DataFUSION-Logo-Light.svg new file mode 100644 index 000000000000..b3bef2193dde --- /dev/null +++ b/docs/logos/DataFUSION-Logo-Light.svg @@ -0,0 +1 @@ +DataFUSION-Logo-Light \ No newline at end of file diff --git a/docs/logos/DataFUSION-Logo-Light@2x.png b/docs/logos/DataFUSION-Logo-Light@2x.png new file mode 100644 index 000000000000..8992213b0e60 Binary files /dev/null and b/docs/logos/DataFUSION-Logo-Light@2x.png differ diff --git a/docs/logos/DataFUSION-Logo-Light@4x.png b/docs/logos/DataFUSION-Logo-Light@4x.png new file mode 100644 index 000000000000..bd329ca21956 Binary files /dev/null and b/docs/logos/DataFUSION-Logo-Light@4x.png differ diff --git a/docs/logos/DataFusion-LogoAndColorPaletteExploration_v01.pdf b/docs/logos/DataFusion-LogoAndColorPaletteExploration_v01.pdf new file mode 100644 index 000000000000..4594c50f9044 Binary files /dev/null and b/docs/logos/DataFusion-LogoAndColorPaletteExploration_v01.pdf differ diff --git a/docs/source/_static/theme_overrides.css b/docs/source/_static/theme_overrides.css index 838eab067afc..3b1b86daac6a 100644 --- a/docs/source/_static/theme_overrides.css +++ b/docs/source/_static/theme_overrides.css @@ -49,7 +49,7 @@ code { } /* This is the bootstrap CSS style for "table-striped". Since the theme does -not yet provide an easy way to configure this globaly, it easier to simply +not yet provide an easy way to configure this globally, it easier to simply include this snippet here than updating each table in all rst files to add ":class: table-striped" */ @@ -59,7 +59,7 @@ add ":class: table-striped" */ /* Limit the max height of the sidebar navigation section. Because in our -custimized template, there is more content above the navigation, i.e. +customized template, there is more content above the navigation, i.e. larger logo: if we don't decrease the max-height, it will overlap with the footer. Details: 8rem for search box etc*/ diff --git a/docs/source/_templates/layout.html b/docs/source/_templates/layout.html index a9d0f30bcf8e..9f7880049856 100644 --- a/docs/source/_templates/layout.html +++ b/docs/source/_templates/layout.html @@ -3,3 +3,24 @@ {# Silence the navbar #} {% block docs_navbar %} {% endblock %} + + +{% block footer %} + +
+
+ {% for footer_item in theme_footer_items %} + + {% endfor %} + +
+
+ +{% endblock %} diff --git a/docs/source/conf.py b/docs/source/conf.py index 9aa84d49bc0a..becece330d1a 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -33,9 +33,9 @@ # -- Project information ----------------------------------------------------- -project = 'Arrow DataFusion' -copyright = '2023, Apache Software Foundation' -author = 'Arrow DataFusion Authors' +project = 'Apache Arrow DataFusion' +copyright = '2019-2024, Apache Software Foundation' +author = 'Apache Software Foundation' # -- General configuration --------------------------------------------------- @@ -118,4 +118,4 @@ myst_heading_anchors = 3 # enable nice rendering of checkboxes for the task lists -myst_enable_extensions = [ "tasklist"] +myst_enable_extensions = ["colon_fence", "deflist", "tasklist"] diff --git a/docs/source/contributor-guide/communication.md b/docs/source/contributor-guide/communication.md index 11e0e4e0f0ea..8678aa534baf 100644 --- a/docs/source/contributor-guide/communication.md +++ b/docs/source/contributor-guide/communication.md @@ -26,15 +26,25 @@ All participation in the Apache Arrow DataFusion project is governed by the Apache Software Foundation's [code of conduct](https://www.apache.org/foundation/policies/conduct.html). +## GitHub + The vast majority of communication occurs in the open on our -[github repository](https://github.com/apache/arrow-datafusion). +[github repository](https://github.com/apache/arrow-datafusion) in the form of tickets, issues, discussions, and Pull Requests. + +## Slack and Discord -## Questions? +We use the Slack and Discord platforms for informal discussions and coordination. These are great places to +meet other contributors and get guidance on where to contribute. It is important to note that any technical designs and +decisions are made fully in the open, on GitHub. -### Mailing list +Most of us use the `#arrow-datafusion` and `#arrow-rust` channels in the [ASF Slack workspace](https://s.apache.org/slack-invite) . +Unfortunately, due to spammers, the ASF Slack workspace requires an invitation to join. To get an invitation, +request one in the `Arrow Rust` channel of the [Arrow Rust Discord server](https://discord.gg/Qw5gKqHxUM). -We use arrow.apache.org's `dev@` mailing list for project management, release -coordination and design discussions +## Mailing list + +We also use arrow.apache.org's `dev@` mailing list for release coordination and occasional design discussions. Other +than the the release process, most DataFusion mailing list traffic will link to a GitHub issue or PR for discussion. ([subscribe](mailto:dev-subscribe@arrow.apache.org), [unsubscribe](mailto:dev-unsubscribe@arrow.apache.org), [archives](https://lists.apache.org/list.html?dev@arrow.apache.org)). @@ -42,33 +52,3 @@ coordination and design discussions When emailing the dev list, please make sure to prefix the subject line with a `[DataFusion]` tag, e.g. `"[DataFusion] New API for remote data sources"`, so that the appropriate people in the Apache Arrow community notice the message. - -### Slack and Discord - -We use the official [ASF](https://s.apache.org/slack-invite) Slack workspace -for informal discussions and coordination. This is a great place to meet other -contributors and get guidance on where to contribute. Join us in the -`#arrow-rust` channel. - -We also have a backup Arrow Rust Discord -server ([invite link](https://discord.gg/Qw5gKqHxUM)) in case you are not able -to join the Slack workspace. If you need an invite to the Slack workspace, you -can also ask for one in our Discord server. - -### Sync up video calls - -We have biweekly sync calls every other Thursdays at both 04:00 UTC -and 16:00 UTC (starting September 30, 2021) depending on if there are -items on the agenda to discuss and someone being willing to host. - -Please see the [agenda](https://docs.google.com/document/d/1atCVnoff5SR4eM4Lwf2M1BBJTY6g3_HUNR6qswYJW_U/edit) -for the video call link, add topics and to see what others plan to discuss. - -The goals of these calls are: - -1. Help "put a face to the name" of some of other contributors we are working with -2. Discuss / synchronize on the goals and major initiatives from different stakeholders to identify areas where more alignment is needed - -No decisions are made on the call and anything of substance will be discussed on the mailing list or in github issues / google docs. - -We will send a summary of all sync ups to the dev@arrow.apache.org mailing list. diff --git a/docs/source/contributor-guide/index.md b/docs/source/contributor-guide/index.md index e42ab0dee07a..8d69ade83d72 100644 --- a/docs/source/contributor-guide/index.md +++ b/docs/source/contributor-guide/index.md @@ -151,7 +151,7 @@ Tests for code in an individual module are defined in the same source file with ### sqllogictests Tests -DataFusion's SQL implementation is tested using [sqllogictest](https://github.com/apache/arrow-datafusion/tree/main/datafusion/core/tests/sqllogictests) which are run like any other Rust test using `cargo test --test sqllogictests`. +DataFusion's SQL implementation is tested using [sqllogictest](https://github.com/apache/arrow-datafusion/tree/main/datafusion/sqllogictest) which are run like any other Rust test using `cargo test --test sqllogictests`. `sqllogictests` tests may be less convenient for new contributors who are familiar with writing `.rs` tests as they require learning another tool. However, `sqllogictest` based tests are much easier to develop and maintain as they 1) do not require a slow recompile/link cycle and 2) can be automatically updated via `cargo test --test sqllogictests -- --complete`. @@ -221,8 +221,8 @@ Below is a checklist of what you need to do to add a new scalar function to Data - a new line in `signature` with the signature of the function (number and types of its arguments) - a new line in `create_physical_expr`/`create_physical_fun` mapping the built-in to the implementation - tests to the function. -- In [core/tests/sqllogictests/test_files](../../../datafusion/core/tests/sqllogictests/test_files), add new `sqllogictest` integration tests where the function is called through SQL against well known data and returns the expected result. - - Documentation for `sqllogictest` [here](../../../datafusion/core/tests/sqllogictests/README.md) +- In [sqllogictest/test_files](../../../datafusion/sqllogictest/test_files), add new `sqllogictest` integration tests where the function is called through SQL against well known data and returns the expected result. + - Documentation for `sqllogictest` [here](../../../datafusion/sqllogictest/README.md) - In [expr/src/expr_fn.rs](../../../datafusion/expr/src/expr_fn.rs), add: - a new entry of the `unary_scalar_expr!` macro for the new function. - Add SQL reference documentation [here](../../../docs/source/user-guide/sql/scalar_functions.md) @@ -243,8 +243,8 @@ Below is a checklist of what you need to do to add a new aggregate function to D - a new line in `signature` with the signature of the function (number and types of its arguments) - a new line in `create_aggregate_expr` mapping the built-in to the implementation - tests to the function. -- In [core/tests/sqllogictests/test_files](../../../datafusion/core/tests/sqllogictests/test_files), add new `sqllogictest` integration tests where the function is called through SQL against well known data and returns the expected result. - - Documentation for `sqllogictest` [here](../../../datafusion/core/tests/sqllogictests/README.md) +- In [sqllogictest/test_files](../../../datafusion/sqllogictest/test_files), add new `sqllogictest` integration tests where the function is called through SQL against well known data and returns the expected result. + - Documentation for `sqllogictest` [here](../../../datafusion/sqllogictest/README.md) - Add SQL reference documentation [here](../../../docs/source/user-guide/sql/aggregate_functions.md) ### How to display plans graphically diff --git a/docs/source/index.rst b/docs/source/index.rst index bb8e2127f1e7..385371661716 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -43,11 +43,12 @@ community. The `example usage`_ section in the user guide and the `datafusion-examples`_ code in the crate contain information on using DataFusion. -The `developer’s guide`_ contains information on how to contribute. +Please see the `developer’s guide`_ for contributing and `communication`_ for getting in touch with us. .. _example usage: user-guide/example-usage.html .. _datafusion-examples: https://github.com/apache/arrow-datafusion/tree/master/datafusion-examples .. _developer’s guide: contributor-guide/index.html#developer-s-guide +.. _communication: contributor-guide/communication.html .. _toc.links: .. toctree:: diff --git a/docs/source/library-user-guide/adding-udfs.md b/docs/source/library-user-guide/adding-udfs.md index d3f31bd45aee..1f687f978f30 100644 --- a/docs/source/library-user-guide/adding-udfs.md +++ b/docs/source/library-user-guide/adding-udfs.md @@ -17,17 +17,18 @@ under the License. --> -# Adding User Defined Functions: Scalar/Window/Aggregate +# Adding User Defined Functions: Scalar/Window/Aggregate/Table Functions User Defined Functions (UDFs) are functions that can be used in the context of DataFusion execution. This page covers how to add UDFs to DataFusion. In particular, it covers how to add Scalar, Window, and Aggregate UDFs. -| UDF Type | Description | Example | -| --------- | ---------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------- | -| Scalar | A function that takes a row of data and returns a single value. | [simple_udf.rs](../../../datafusion-examples/examples/simple_udf.rs) | -| Window | A function that takes a row of data and returns a single value, but also has access to the rows around it. | [simple_udwf.rs](../../../datafusion-examples/examples/simple_udwf.rs) | -| Aggregate | A function that takes a group of rows and returns a single value. | [simple_udaf.rs](../../../datafusion-examples/examples/simple_udaf.rs) | +| UDF Type | Description | Example | +| --------- | ---------------------------------------------------------------------------------------------------------- | ------------------- | +| Scalar | A function that takes a row of data and returns a single value. | [simple_udf.rs][1] | +| Window | A function that takes a row of data and returns a single value, but also has access to the rows around it. | [simple_udwf.rs][2] | +| Aggregate | A function that takes a group of rows and returns a single value. | [simple_udaf.rs][3] | +| Table | A function that takes parameters and returns a `TableProvider` to be used in an query plan. | [simple_udtf.rs][4] | First we'll talk about adding an Scalar UDF end-to-end, then we'll talk about the differences between the different types of UDFs. @@ -38,7 +39,7 @@ A Scalar UDF is a function that takes a row of data and returns a single value. ```rust use std::sync::Arc; -use arrow::array::{ArrayRef, Int64Array}; +use datafusion::arrow::array::{ArrayRef, Int64Array}; use datafusion::common::Result; use datafusion::common::cast::as_int64_array; @@ -75,9 +76,16 @@ The challenge however is that DataFusion doesn't know about this function. We ne ### Registering a Scalar UDF -To register a Scalar UDF, you need to wrap the function implementation in a `ScalarUDF` struct and then register it with the `SessionContext`. DataFusion provides the `create_udf` and `make_scalar_function` helper functions to make this easier. +To register a Scalar UDF, you need to wrap the function implementation in a [`ScalarUDF`] struct and then register it with the `SessionContext`. +DataFusion provides the [`create_udf`] and helper functions to make this easier. +There is a lower level API with more functionality but is more complex, that is documented in [`advanced_udf.rs`]. ```rust +use datafusion::logical_expr::{Volatility, create_udf}; +use datafusion::physical_plan::functions::make_scalar_function; +use datafusion::arrow::datatypes::DataType; +use std::sync::Arc; + let udf = create_udf( "add_one", vec![DataType::Int64], @@ -87,6 +95,11 @@ let udf = create_udf( ); ``` +[`scalarudf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.ScalarUDF.html +[`create_udf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udf.html +[`make_scalar_function`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/functions/fn.make_scalar_function.html +[`advanced_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs + A few things to note: - The first argument is the name of the function. This is the name that will be used in SQL queries. @@ -98,6 +111,8 @@ A few things to note: That gives us a `ScalarUDF` that we can register with the `SessionContext`: ```rust +use datafusion::execution::context::SessionContext; + let mut ctx = SessionContext::new(); ctx.register_udf(udf); @@ -115,10 +130,415 @@ let df = ctx.sql(&sql).await.unwrap(); Scalar UDFs are functions that take a row of data and return a single value. Window UDFs are similar, but they also have access to the rows around them. Access to the the proximal rows is helpful, but adds some complexity to the implementation. -Body coming soon. +For example, we will declare a user defined window function that computes a moving average. + +```rust +use datafusion::arrow::{array::{ArrayRef, Float64Array, AsArray}, datatypes::Float64Type}; +use datafusion::logical_expr::{PartitionEvaluator}; +use datafusion::common::ScalarValue; +use datafusion::error::Result; +/// This implements the lowest level evaluation for a window function +/// +/// It handles calculating the value of the window function for each +/// distinct values of `PARTITION BY` +#[derive(Clone, Debug)] +struct MyPartitionEvaluator {} + +impl MyPartitionEvaluator { + fn new() -> Self { + Self {} + } +} + +/// Different evaluation methods are called depending on the various +/// settings of WindowUDF. This example uses the simplest and most +/// general, `evaluate`. See `PartitionEvaluator` for the other more +/// advanced uses. +impl PartitionEvaluator for MyPartitionEvaluator { + /// Tell DataFusion the window function varies based on the value + /// of the window frame. + fn uses_window_frame(&self) -> bool { + true + } + + /// This function is called once per input row. + /// + /// `range`specifies which indexes of `values` should be + /// considered for the calculation. + /// + /// Note this is the SLOWEST, but simplest, way to evaluate a + /// window function. It is much faster to implement + /// evaluate_all or evaluate_all_with_rank, if possible + fn evaluate( + &mut self, + values: &[ArrayRef], + range: &std::ops::Range, + ) -> Result { + // Again, the input argument is an array of floating + // point numbers to calculate a moving average + let arr: &Float64Array = values[0].as_ref().as_primitive::(); + + let range_len = range.end - range.start; + + // our smoothing function will average all the values in the + let output = if range_len > 0 { + let sum: f64 = arr.values().iter().skip(range.start).take(range_len).sum(); + Some(sum / range_len as f64) + } else { + None + }; + + Ok(ScalarValue::Float64(output)) + } +} + +/// Create a `PartitionEvalutor` to evaluate this function on a new +/// partition. +fn make_partition_evaluator() -> Result> { + Ok(Box::new(MyPartitionEvaluator::new())) +} +``` + +### Registering a Window UDF + +To register a Window UDF, you need to wrap the function implementation in a [`WindowUDF`] struct and then register it with the `SessionContext`. DataFusion provides the [`create_udwf`] helper functions to make this easier. +There is a lower level API with more functionality but is more complex, that is documented in [`advanced_udwf.rs`]. + +```rust +use datafusion::logical_expr::{Volatility, create_udwf}; +use datafusion::arrow::datatypes::DataType; +use std::sync::Arc; + +// here is where we define the UDWF. We also declare its signature: +let smooth_it = create_udwf( + "smooth_it", + DataType::Float64, + Arc::new(DataType::Float64), + Volatility::Immutable, + Arc::new(make_partition_evaluator), +); +``` + +[`windowudf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.WindowUDF.html +[`create_udwf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udwf.html +[`advanced_udwf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udwf.rs + +The `create_udwf` has five arguments to check: + +- The first argument is the name of the function. This is the name that will be used in SQL queries. +- **The second argument** is the `DataType` of input array (attention: this is not a list of arrays). I.e. in this case, the function accepts `Float64` as argument. +- The third argument is the return type of the function. I.e. in this case, the function returns an `Float64`. +- The fourth argument is the volatility of the function. In short, this is used to determine if the function's performance can be optimized in some situations. In this case, the function is `Immutable` because it always returns the same value for the same input. A random number generator would be `Volatile` because it returns a different value for the same input. +- **The fifth argument** is the function implementation. This is the function that we defined above. + +That gives us a `WindowUDF` that we can register with the `SessionContext`: + +```rust +use datafusion::execution::context::SessionContext; + +let ctx = SessionContext::new(); + +ctx.register_udwf(smooth_it); +``` + +At this point, you can use the `smooth_it` function in your query: + +For example, if we have a [`cars.csv`](https://github.com/apache/arrow-datafusion/blob/main/datafusion/core/tests/data/cars.csv) whose contents like + +```csv +car,speed,time +red,20.0,1996-04-12T12:05:03.000000000 +red,20.3,1996-04-12T12:05:04.000000000 +green,10.0,1996-04-12T12:05:03.000000000 +green,10.3,1996-04-12T12:05:04.000000000 +... +``` + +Then, we can query like below: + +```rust +use datafusion::datasource::file_format::options::CsvReadOptions; +// register csv table first +let csv_path = "cars.csv".to_string(); +ctx.register_csv("cars", &csv_path, CsvReadOptions::default().has_header(true)).await?; +// do query with smooth_it +let df = ctx + .sql( + "SELECT \ + car, \ + speed, \ + smooth_it(speed) OVER (PARTITION BY car ORDER BY time) as smooth_speed,\ + time \ + from cars \ + ORDER BY \ + car", + ) + .await?; +// print the results +df.show().await?; +``` + +the output will be like: + +```csv ++-------+-------+--------------------+---------------------+ +| car | speed | smooth_speed | time | ++-------+-------+--------------------+---------------------+ +| green | 10.0 | 10.0 | 1996-04-12T12:05:03 | +| green | 10.3 | 10.15 | 1996-04-12T12:05:04 | +| green | 10.4 | 10.233333333333334 | 1996-04-12T12:05:05 | +| green | 10.5 | 10.3 | 1996-04-12T12:05:06 | +| green | 11.0 | 10.440000000000001 | 1996-04-12T12:05:07 | +| green | 12.0 | 10.700000000000001 | 1996-04-12T12:05:08 | +| green | 14.0 | 11.171428571428573 | 1996-04-12T12:05:09 | +| green | 15.0 | 11.65 | 1996-04-12T12:05:10 | +| green | 15.1 | 12.033333333333333 | 1996-04-12T12:05:11 | +| green | 15.2 | 12.35 | 1996-04-12T12:05:12 | +| green | 8.0 | 11.954545454545455 | 1996-04-12T12:05:13 | +| green | 2.0 | 11.125 | 1996-04-12T12:05:14 | +| red | 20.0 | 20.0 | 1996-04-12T12:05:03 | +| red | 20.3 | 20.15 | 1996-04-12T12:05:04 | +... +``` ## Adding an Aggregate UDF Aggregate UDFs are functions that take a group of rows and return a single value. These are akin to SQL's `SUM` or `COUNT` functions. -Body coming soon. +For example, we will declare a single-type, single return type UDAF that computes the geometric mean. + +```rust +use datafusion::arrow::array::ArrayRef; +use datafusion::scalar::ScalarValue; +use datafusion::{error::Result, physical_plan::Accumulator}; + +/// A UDAF has state across multiple rows, and thus we require a `struct` with that state. +#[derive(Debug)] +struct GeometricMean { + n: u32, + prod: f64, +} + +impl GeometricMean { + // how the struct is initialized + pub fn new() -> Self { + GeometricMean { n: 0, prod: 1.0 } + } +} + +// UDAFs are built using the trait `Accumulator`, that offers DataFusion the necessary functions +// to use them. +impl Accumulator for GeometricMean { + // This function serializes our state to `ScalarValue`, which DataFusion uses + // to pass this state between execution stages. + // Note that this can be arbitrary data. + fn state(&self) -> Result> { + Ok(vec![ + ScalarValue::from(self.prod), + ScalarValue::from(self.n), + ]) + } + + // DataFusion expects this function to return the final value of this aggregator. + // in this case, this is the formula of the geometric mean + fn evaluate(&self) -> Result { + let value = self.prod.powf(1.0 / self.n as f64); + Ok(ScalarValue::from(value)) + } + + // DataFusion calls this function to update the accumulator's state for a batch + // of inputs rows. In this case the product is updated with values from the first column + // and the count is updated based on the row count + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + let arr = &values[0]; + (0..arr.len()).try_for_each(|index| { + let v = ScalarValue::try_from_array(arr, index)?; + + if let ScalarValue::Float64(Some(value)) = v { + self.prod *= value; + self.n += 1; + } else { + unreachable!("") + } + Ok(()) + }) + } + + // Optimization hint: this trait also supports `update_batch` and `merge_batch`, + // that can be used to perform these operations on arrays instead of single values. + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + } + let arr = &states[0]; + (0..arr.len()).try_for_each(|index| { + let v = states + .iter() + .map(|array| ScalarValue::try_from_array(array, index)) + .collect::>>()?; + if let (ScalarValue::Float64(Some(prod)), ScalarValue::UInt32(Some(n))) = (&v[0], &v[1]) + { + self.prod *= prod; + self.n += n; + } else { + unreachable!("") + } + Ok(()) + }) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } +} +``` + +### registering an Aggregate UDF + +To register a Aggreate UDF, you need to wrap the function implementation in a `AggregateUDF` struct and then register it with the `SessionContext`. DataFusion provides the `create_udaf` helper functions to make this easier. + +```rust +use datafusion::logical_expr::{Volatility, create_udaf}; +use datafusion::arrow::datatypes::DataType; +use std::sync::Arc; + +// here is where we define the UDAF. We also declare its signature: +let geometric_mean = create_udaf( + // the name; used to represent it in plan descriptions and in the registry, to use in SQL. + "geo_mean", + // the input type; DataFusion guarantees that the first entry of `values` in `update` has this type. + vec![DataType::Float64], + // the return type; DataFusion expects this to match the type returned by `evaluate`. + Arc::new(DataType::Float64), + Volatility::Immutable, + // This is the accumulator factory; DataFusion uses it to create new accumulators. + Arc::new(|_| Ok(Box::new(GeometricMean::new()))), + // This is the description of the state. `state()` must match the types here. + Arc::new(vec![DataType::Float64, DataType::UInt32]), +); +``` + +The `create_udaf` has six arguments to check: + +- The first argument is the name of the function. This is the name that will be used in SQL queries. +- The second argument is a vector of `DataType`s. This is the list of argument types that the function accepts. I.e. in this case, the function accepts a single `Float64` argument. +- The third argument is the return type of the function. I.e. in this case, the function returns an `Int64`. +- The fourth argument is the volatility of the function. In short, this is used to determine if the function's performance can be optimized in some situations. In this case, the function is `Immutable` because it always returns the same value for the same input. A random number generator would be `Volatile` because it returns a different value for the same input. +- The fifth argument is the function implementation. This is the function that we defined above. +- The sixth argument is the description of the state, which will by passed between execution stages. + +That gives us a `AggregateUDF` that we can register with the `SessionContext`: + +```rust +use datafusion::execution::context::SessionContext; + +let ctx = SessionContext::new(); + +ctx.register_udaf(geometric_mean); +``` + +Then, we can query like below: + +```rust +let df = ctx.sql("SELECT geo_mean(a) FROM t").await?; +``` + +## Adding a User-Defined Table Function + +A User-Defined Table Function (UDTF) is a function that takes parameters and returns a `TableProvider`. + +Because we're returning a `TableProvider`, in this example we'll use the `MemTable` data source to represent a table. This is a simple struct that holds a set of RecordBatches in memory and treats them as a table. In your case, this would be replaced with your own struct that implements `TableProvider`. + +While this is a simple example for illustrative purposes, UDTFs have a lot of potential use cases. And can be particularly useful for reading data from external sources and interactive analysis. For example, see the [example][4] for a working example that reads from a CSV file. As another example, you could use the built-in UDTF `parquet_metadata` in the CLI to read the metadata from a Parquet file. + +```console +❯ select filename, row_group_id, row_group_num_rows, row_group_bytes, stats_min, stats_max from parquet_metadata('./benchmarks/data/hits.parquet') where column_id = 17 limit 10; ++--------------------------------+--------------+--------------------+-----------------+-----------+-----------+ +| filename | row_group_id | row_group_num_rows | row_group_bytes | stats_min | stats_max | ++--------------------------------+--------------+--------------------+-----------------+-----------+-----------+ +| ./benchmarks/data/hits.parquet | 0 | 450560 | 188921521 | 0 | 73256 | +| ./benchmarks/data/hits.parquet | 1 | 612174 | 210338885 | 0 | 109827 | +| ./benchmarks/data/hits.parquet | 2 | 344064 | 161242466 | 0 | 122484 | +| ./benchmarks/data/hits.parquet | 3 | 606208 | 235549898 | 0 | 121073 | +| ./benchmarks/data/hits.parquet | 4 | 335872 | 137103898 | 0 | 108996 | +| ./benchmarks/data/hits.parquet | 5 | 311296 | 145453612 | 0 | 108996 | +| ./benchmarks/data/hits.parquet | 6 | 303104 | 138833963 | 0 | 108996 | +| ./benchmarks/data/hits.parquet | 7 | 303104 | 191140113 | 0 | 73256 | +| ./benchmarks/data/hits.parquet | 8 | 573440 | 208038598 | 0 | 95823 | +| ./benchmarks/data/hits.parquet | 9 | 344064 | 147838157 | 0 | 73256 | ++--------------------------------+--------------+--------------------+-----------------+-----------+-----------+ +``` + +### Writing the UDTF + +The simple UDTF used here takes a single `Int64` argument and returns a table with a single column with the value of the argument. To create a function in DataFusion, you need to implement the `TableFunctionImpl` trait. This trait has a single method, `call`, that takes a slice of `Expr`s and returns a `Result>`. + +In the `call` method, you parse the input `Expr`s and return a `TableProvider`. You might also want to do some validation of the input `Expr`s, e.g. checking that the number of arguments is correct. + +```rust +use datafusion::common::plan_err; +use datafusion::datasource::function::TableFunctionImpl; +// Other imports here + +/// A table function that returns a table provider with the value as a single column +#[derive(Default)] +pub struct EchoFunction {} + +impl TableFunctionImpl for EchoFunction { + fn call(&self, exprs: &[Expr]) -> Result> { + let Some(Expr::Literal(ScalarValue::Int64(Some(value)))) = exprs.get(0) else { + return plan_err!("First argument must be an integer"); + }; + + // Create the schema for the table + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + + // Create a single RecordBatch with the value as a single column + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int64Array::from(vec![*value]))], + )?; + + // Create a MemTable plan that returns the RecordBatch + let provider = MemTable::try_new(schema, vec![vec![batch]])?; + + Ok(Arc::new(provider)) + } +} +``` + +### Registering and Using the UDTF + +With the UDTF implemented, you can register it with the `SessionContext`: + +```rust +use datafusion::execution::context::SessionContext; + +let ctx = SessionContext::new(); + +ctx.register_udtf("echo", Arc::new(EchoFunction::default())); +``` + +And if all goes well, you can use it in your query: + +```rust +use datafusion::arrow::util::pretty; + +let df = ctx.sql("SELECT * FROM echo(1)").await?; + +let results = df.collect().await?; +pretty::print_batches(&results)?; +// +---+ +// | a | +// +---+ +// | 1 | +// +---+ +``` + +[1]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udf.rs +[2]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udwf.rs +[3]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udaf.rs +[4]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udtf.rs diff --git a/docs/source/library-user-guide/building-logical-plans.md b/docs/source/library-user-guide/building-logical-plans.md index 406f4881129c..fe922d8eaeb1 100644 --- a/docs/source/library-user-guide/building-logical-plans.md +++ b/docs/source/library-user-guide/building-logical-plans.md @@ -19,4 +19,131 @@ # Building Logical Plans -Coming Soon +A logical plan is a structured representation of a database query that describes the high-level operations and +transformations needed to retrieve data from a database or data source. It abstracts away specific implementation +details and focuses on the logical flow of the query, including operations like filtering, sorting, and joining tables. + +This logical plan serves as an intermediate step before generating an optimized physical execution plan. This is +explained in more detail in the [Query Planning and Execution Overview] section of the [Architecture Guide]. + +## Building Logical Plans Manually + +DataFusion's [LogicalPlan] is an enum containing variants representing all the supported operators, and also +contains an `Extension` variant that allows projects building on DataFusion to add custom logical operators. + +It is possible to create logical plans by directly creating instances of the [LogicalPlan] enum as follows, but is is +much easier to use the [LogicalPlanBuilder], which is described in the next section. + +Here is an example of building a logical plan directly: + + + +```rust +// create a logical table source +let schema = Schema::new(vec![ + Field::new("id", DataType::Int32, true), + Field::new("name", DataType::Utf8, true), +]); +let table_source = LogicalTableSource::new(SchemaRef::new(schema)); + +// create a TableScan plan +let projection = None; // optional projection +let filters = vec![]; // optional filters to push down +let fetch = None; // optional LIMIT +let table_scan = LogicalPlan::TableScan(TableScan::try_new( + "person", + Arc::new(table_source), + projection, + filters, + fetch, +)?); + +// create a Filter plan that evaluates `id > 500` that wraps the TableScan +let filter_expr = col("id").gt(lit(500)); +let plan = LogicalPlan::Filter(Filter::try_new(filter_expr, Arc::new(table_scan))?); + +// print the plan +println!("{}", plan.display_indent_schema()); +``` + +This example produces the following plan: + +``` +Filter: person.id > Int32(500) [id:Int32;N, name:Utf8;N] + TableScan: person [id:Int32;N, name:Utf8;N] +``` + +## Building Logical Plans with LogicalPlanBuilder + +DataFusion logical plans can be created using the [LogicalPlanBuilder] struct. There is also a [DataFrame] API which is +a higher-level API that delegates to [LogicalPlanBuilder]. + +The following associated functions can be used to create a new builder: + +- `empty` - create an empty plan with no fields +- `values` - create a plan from a set of literal values +- `scan` - create a plan representing a table scan +- `scan_with_filters` - create a plan representing a table scan with filters + +Once the builder is created, transformation methods can be called to declare that further operations should be +performed on the plan. Note that all we are doing at this stage is building up the logical plan structure. No query +execution will be performed. + +Here are some examples of transformation methods, but for a full list, refer to the [LogicalPlanBuilder] API documentation. + +- `filter` +- `limit` +- `sort` +- `distinct` +- `join` + +The following example demonstrates building the same simple query plan as the previous example, with a table scan followed by a filter. + + + +```rust +// create a logical table source +let schema = Schema::new(vec![ + Field::new("id", DataType::Int32, true), + Field::new("name", DataType::Utf8, true), +]); +let table_source = LogicalTableSource::new(SchemaRef::new(schema)); + +// optional projection +let projection = None; + +// create a LogicalPlanBuilder for a table scan +let builder = LogicalPlanBuilder::scan("person", Arc::new(table_source), projection)?; + +// perform a filter operation and build the plan +let plan = builder + .filter(col("id").gt(lit(500)))? // WHERE id > 500 + .build()?; + +// print the plan +println!("{}", plan.display_indent_schema()); +``` + +This example produces the following plan: + +``` +Filter: person.id > Int32(500) [id:Int32;N, name:Utf8;N] + TableScan: person [id:Int32;N, name:Utf8;N] +``` + +## Table Sources + +The previous example used a [LogicalTableSource], which is used for tests and documentation in DataFusion, and is also +suitable if you are using DataFusion to build logical plans but do not use DataFusion's physical planner. However, if you +want to use a [TableSource] that can be executed in DataFusion then you will need to use [DefaultTableSource], which is a +wrapper for a [TableProvider]. + +[query planning and execution overview]: https://docs.rs/datafusion/latest/datafusion/index.html#query-planning-and-execution-overview +[architecture guide]: https://docs.rs/datafusion/latest/datafusion/index.html#architecture +[logicalplan]: https://docs.rs/datafusion-expr/latest/datafusion_expr/logical_plan/enum.LogicalPlan.html +[logicalplanbuilder]: https://docs.rs/datafusion-expr/latest/datafusion_expr/logical_plan/builder/struct.LogicalPlanBuilder.html +[dataframe]: using-the-dataframe-api.md +[logicaltablesource]: https://docs.rs/datafusion-expr/latest/datafusion_expr/logical_plan/builder/struct.LogicalTableSource.html +[defaulttablesource]: https://docs.rs/datafusion/latest/datafusion/datasource/default_table_source/struct.DefaultTableSource.html +[tableprovider]: https://docs.rs/datafusion/latest/datafusion/datasource/provider/trait.TableProvider.html +[tablesource]: https://docs.rs/datafusion-expr/latest/datafusion_expr/trait.TableSource.html diff --git a/docs/source/library-user-guide/catalogs.md b/docs/source/library-user-guide/catalogs.md index 1dd235f0a2d2..e53d16366350 100644 --- a/docs/source/library-user-guide/catalogs.md +++ b/docs/source/library-user-guide/catalogs.md @@ -19,7 +19,7 @@ # Catalogs, Schemas, and Tables -This section describes how to create and manage catalogs, schemas, and tables in DataFusion. For those wanting to dive into the code quickly please see the [example](../../../datafusion-examples/examples/catalog.rs). +This section describes how to create and manage catalogs, schemas, and tables in DataFusion. For those wanting to dive into the code quickly please see the [example](https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/catalog.rs). ## General Concepts diff --git a/docs/source/library-user-guide/custom-table-providers.md b/docs/source/library-user-guide/custom-table-providers.md index 30721d6a5ba6..9da207da68f3 100644 --- a/docs/source/library-user-guide/custom-table-providers.md +++ b/docs/source/library-user-guide/custom-table-providers.md @@ -25,7 +25,7 @@ This section will also touch on how to have DataFusion use the new `TableProvide ## Table Provider and Scan -The `scan` method on the `TableProvider` is likely its most important. It returns an `ExecutionPlan` that DataFusion will use to read the actual data during execution o the query. +The `scan` method on the `TableProvider` is likely its most important. It returns an `ExecutionPlan` that DataFusion will use to read the actual data during execution of the query. ### Scan @@ -121,12 +121,28 @@ impl TableProvider for CustomDataSource { With this, and the implementation of the omitted methods, we can now use the `CustomDataSource` as a `TableProvider` in DataFusion. +##### Additional `TableProvider` Methods + +`scan` has no default implementation, so it needed to be written. There are other methods on the `TableProvider` that have default implementations, but can be overridden if needed to provide additional functionality. + +###### `supports_filters_pushdown` + +The `supports_filters_pushdown` method can be overridden to indicate which filter expressions support being pushed down to the data source and within that the specificity of the pushdown. + +This returns a `Vec` of `TableProviderFilterPushDown` enums where each enum represents a filter that can be pushed down. The `TableProviderFilterPushDown` enum has three variants: + +- `TableProviderFilterPushDown::Unsupported` - the filter cannot be pushed down +- `TableProviderFilterPushDown::Exact` - the filter can be pushed down and the data source can guarantee that the filter will be applied completely to all rows. This is the highest performance option. +- `TableProviderFilterPushDown::Inexact` - the filter can be pushed down, but the data source cannot guarantee that the filter will be applied to all rows. DataFusion will apply `Inexact` filters again after the scan to ensure correctness. + +For filters that can be pushed down, they'll be passed to the `scan` method as the `filters` parameter and they can be made use of there. + ## Using the Custom Table Provider -In order to use the custom table provider, we need to register it with DataFusion. This is done by creating a `TableProvider` and registering it with the `ExecutionContext`. +In order to use the custom table provider, we need to register it with DataFusion. This is done by creating a `TableProvider` and registering it with the `SessionContext`. ```rust -let mut ctx = ExecutionContext::new(); +let mut ctx = SessionContext::new(); let custom_table_provider = CustomDataSource::new(); ctx.register_table("custom_table", Arc::new(custom_table_provider)); @@ -144,7 +160,7 @@ To recap, in order to implement a custom table provider, you need to: 1. Implement the `TableProvider` trait 2. Implement the `ExecutionPlan` trait -3. Register the `TableProvider` with the `ExecutionContext` +3. Register the `TableProvider` with the `SessionContext` ## Next Steps diff --git a/docs/source/library-user-guide/using-the-dataframe-api.md b/docs/source/library-user-guide/using-the-dataframe-api.md index fdf309980dc2..c4f4ecd4f137 100644 --- a/docs/source/library-user-guide/using-the-dataframe-api.md +++ b/docs/source/library-user-guide/using-the-dataframe-api.md @@ -19,4 +19,129 @@ # Using the DataFrame API -Coming Soon +## What is a DataFrame + +`DataFrame` in `DataFrame` is modeled after the Pandas DataFrame interface, and is a thin wrapper over LogicalPlan that adds functionality for building and executing those plans. + +```rust +pub struct DataFrame { + session_state: SessionState, + plan: LogicalPlan, +} +``` + +You can build up `DataFrame`s using its methods, similarly to building `LogicalPlan`s using `LogicalPlanBuilder`: + +```rust +let df = ctx.table("users").await?; + +// Create a new DataFrame sorted by `id`, `bank_account` +let new_df = df.select(vec![col("id"), col("bank_account")])? + .sort(vec![col("id")])?; + +// Build the same plan using the LogicalPlanBuilder +let plan = LogicalPlanBuilder::from(&df.to_logical_plan()) + .project(vec![col("id"), col("bank_account")])? + .sort(vec![col("id")])? + .build()?; +``` + +You can use `collect` or `execute_stream` to execute the query. + +## How to generate a DataFrame + +You can directly use the `DataFrame` API or generate a `DataFrame` from a SQL query. + +For example, to use `sql` to construct `DataFrame`: + +```rust +let ctx = SessionContext::new(); +// Register the in-memory table containing the data +ctx.register_table("users", Arc::new(create_memtable()?))?; +let dataframe = ctx.sql("SELECT * FROM users;").await?; +``` + +To construct `DataFrame` using the API: + +```rust +let ctx = SessionContext::new(); +// Register the in-memory table containing the data +ctx.register_table("users", Arc::new(create_memtable()?))?; +let dataframe = ctx + .table("users") + .filter(col("a").lt_eq(col("b")))? + .sort(vec![col("a").sort(true, true), col("b").sort(false, false)])?; +``` + +## Collect / Streaming Exec + +DataFusion `DataFrame`s are "lazy", meaning they do not do any processing until they are executed, which allows for additional optimizations. + +When you have a `DataFrame`, you can run it in one of three ways: + +1. `collect` which executes the query and buffers all the output into a `Vec` +2. `streaming_exec`, which begins executions and returns a `SendableRecordBatchStream` which incrementally computes output on each call to `next()` +3. `cache` which executes the query and buffers the output into a new in memory DataFrame. + +You can just collect all outputs once like: + +```rust +let ctx = SessionContext::new(); +let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; +let batches = df.collect().await?; +``` + +You can also use stream output to incrementally generate output one `RecordBatch` at a time + +```rust +let ctx = SessionContext::new(); +let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; +let mut stream = df.execute_stream().await?; +while let Some(rb) = stream.next().await { + println!("{rb:?}"); +} +``` + +# Write DataFrame to Files + +You can also serialize `DataFrame` to a file. For now, `Datafusion` supports write `DataFrame` to `csv`, `json` and `parquet`. + +When writing a file, DataFusion will execute the DataFrame and stream the results to a file. + +For example, to write a csv_file + +```rust +let ctx = SessionContext::new(); +// Register the in-memory table containing the data +ctx.register_table("users", Arc::new(mem_table))?; +let dataframe = ctx.sql("SELECT * FROM users;").await?; + +dataframe + .write_csv("user_dataframe.csv", DataFrameWriteOptions::default(), None) + .await; +``` + +and the file will look like (Example Output): + +``` +id,bank_account +1,9000 +``` + +## Transform between LogicalPlan and DataFrame + +As shown above, `DataFrame` is just a very thin wrapper of `LogicalPlan`, so you can easily go back and forth between them. + +```rust +// Just combine LogicalPlan with SessionContext and you get a DataFrame +let ctx = SessionContext::new(); +// Register the in-memory table containing the data +ctx.register_table("users", Arc::new(mem_table))?; +let dataframe = ctx.sql("SELECT * FROM users;").await?; + +// get LogicalPlan in dataframe +let plan = dataframe.logical_plan().clone(); + +// construct a DataFrame with LogicalPlan +let new_df = DataFrame::new(ctx.state(), plan); +``` diff --git a/docs/source/library-user-guide/working-with-exprs.md b/docs/source/library-user-guide/working-with-exprs.md index 507e984acb0b..96be8ef7f1ae 100644 --- a/docs/source/library-user-guide/working-with-exprs.md +++ b/docs/source/library-user-guide/working-with-exprs.md @@ -17,7 +17,7 @@ under the License. --> -# Working with Exprs +# Working with `Expr`s @@ -48,12 +48,11 @@ As another example, the SQL expression `a + b * c` would be represented as an `E └────────────────────┘ └────────────────────┘ ``` -As the writer of a library, you may want to use or create `Expr`s to represent computations that you want to perform. This guide will walk you through how to make your own scalar UDF as an `Expr` and how to rewrite `Expr`s to inline the simple UDF. +As the writer of a library, you can use `Expr`s to represent computations that you want to perform. This guide will walk you through how to make your own scalar UDF as an `Expr` and how to rewrite `Expr`s to inline the simple UDF. -There are also executable examples for working with `Expr`s: +## Creating and Evaluating `Expr`s -- [rewrite_expr.rs](../../../datafusion-examples/examples/catalog.rs) -- [expr_api.rs](../../../datafusion-examples/examples/expr_api.rs) +Please see [expr_api.rs](https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/expr_api.rs) for well commented code for creating, evaluating, simplifying, and analyzing `Expr`s. ## A Scalar UDF Example @@ -79,7 +78,9 @@ let expr = add_one_udf.call(vec![col("my_column")]); If you'd like to learn more about `Expr`s, before we get into the details of creating and rewriting them, you can read the [expression user-guide](./../user-guide/expressions.md). -## Rewriting Exprs +## Rewriting `Expr`s + +[rewrite_expr.rs](https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/rewrite_expr.rs) contains example code for rewriting `Expr`s. Rewriting Expressions is the process of taking an `Expr` and transforming it into another `Expr`. This is useful for a number of reasons, including: diff --git a/docs/source/user-guide/cli.md b/docs/source/user-guide/cli.md index 05b4165e612f..525ab090ce51 100644 --- a/docs/source/user-guide/cli.md +++ b/docs/source/user-guide/cli.md @@ -31,7 +31,9 @@ The easiest way to install DataFusion CLI a spin is via `cargo install datafusio ### Install and run using Homebrew (on MacOS) -DataFusion CLI can also be installed via Homebrew (on MacOS). Install it as any other pre-built software like this: +DataFusion CLI can also be installed via Homebrew (on MacOS). If you don't have Homebrew installed, you can check how to install it [here](https://docs.brew.sh/Installation). + +Install it as any other pre-built software like this: ```bash brew install datafusion @@ -46,6 +48,34 @@ brew install datafusion datafusion-cli ``` +### Install and run using PyPI + +DataFusion CLI can also be installed via PyPI. You can check how to install PyPI [here](https://pip.pypa.io/en/latest/installation/). + +Install it as any other pre-built software like this: + +```bash +pip3 install datafusion +# Defaulting to user installation because normal site-packages is not writeable +# Collecting datafusion +# Downloading datafusion-33.0.0-cp38-abi3-macosx_11_0_arm64.whl.metadata (9.6 kB) +# Collecting pyarrow>=11.0.0 (from datafusion) +# Downloading pyarrow-14.0.1-cp39-cp39-macosx_11_0_arm64.whl.metadata (3.0 kB) +# Requirement already satisfied: numpy>=1.16.6 in /Users/Library/Python/3.9/lib/python/site-packages (from pyarrow>=11.0.0->datafusion) (1.23.4) +# Downloading datafusion-33.0.0-cp38-abi3-macosx_11_0_arm64.whl (13.5 MB) +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 13.5/13.5 MB 3.6 MB/s eta 0:00:00 +# Downloading pyarrow-14.0.1-cp39-cp39-macosx_11_0_arm64.whl (24.0 MB) +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 24.0/24.0 MB 36.4 MB/s eta 0:00:00 +# Installing collected packages: pyarrow, datafusion +# Attempting uninstall: pyarrow +# Found existing installation: pyarrow 10.0.1 +# Uninstalling pyarrow-10.0.1: +# Successfully uninstalled pyarrow-10.0.1 +# Successfully installed datafusion-33.0.0 pyarrow-14.0.1 + +datafusion-cli +``` + ### Run using Docker There is no officially published Docker image for the DataFusion CLI, so it is necessary to build from source @@ -397,11 +427,13 @@ Available commands inside DataFusion CLI are: - Show configuration options +`SHOW ALL [VERBOSE]` + ```SQL > show all; +-------------------------------------------------+---------+ -| name | setting | +| name | value | +-------------------------------------------------+---------+ | datafusion.execution.batch_size | 8192 | | datafusion.execution.coalesce_batches | true | @@ -414,6 +446,21 @@ Available commands inside DataFusion CLI are: ``` +- Show specific configuration option + +`SHOW xyz.abc.qwe [VERBOSE]` + +```SQL +> show datafusion.execution.batch_size; + ++-------------------------------------------------+---------+ +| name | value | ++-------------------------------------------------+---------+ +| datafusion.execution.batch_size | 8192 | ++-------------------------------------------------+---------+ + +``` + - Set configuration options ```SQL @@ -432,12 +479,12 @@ For example, to set `datafusion.execution.batch_size` to `1024` you would set the `DATAFUSION_EXECUTION_BATCH_SIZE` environment variable appropriately: -```shell +```SQL $ DATAFUSION_EXECUTION_BATCH_SIZE=1024 datafusion-cli DataFusion CLI v12.0.0 ❯ show all; +-------------------------------------------------+---------+ -| name | setting | +| name | value | +-------------------------------------------------+---------+ | datafusion.execution.batch_size | 1024 | | datafusion.execution.coalesce_batches | true | @@ -452,13 +499,13 @@ DataFusion CLI v12.0.0 You can change the configuration options using `SET` statement as well -```shell +```SQL $ datafusion-cli DataFusion CLI v13.0.0 ❯ show datafusion.execution.batch_size; +---------------------------------+---------+ -| name | setting | +| name | value | +---------------------------------+---------+ | datafusion.execution.batch_size | 8192 | +---------------------------------+---------+ @@ -469,7 +516,7 @@ DataFusion CLI v13.0.0 ❯ show datafusion.execution.batch_size; +---------------------------------+---------+ -| name | setting | +| name | value | +---------------------------------+---------+ | datafusion.execution.batch_size | 1024 | +---------------------------------+---------+ diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 7fe229b4d3c6..0a5c221c5034 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -35,67 +35,75 @@ Values are parsed according to the [same rules used in casts from Utf8](https:// If the value in the environment variable cannot be cast to the type of the configuration option, the default value will be used instead and a warning emitted. Environment variables are read during `SessionConfig` initialisation so they must be set beforehand and will not affect running sessions. -| key | default | description | -| ---------------------------------------------------------- | ------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| datafusion.catalog.create_default_catalog_and_schema | true | Whether the default catalog and schema should be created automatically. | -| datafusion.catalog.default_catalog | datafusion | The default catalog name - this impacts what SQL queries use if not specified | -| datafusion.catalog.default_schema | public | The default schema name - this impacts what SQL queries use if not specified | -| datafusion.catalog.information_schema | false | Should DataFusion provide access to `information_schema` virtual tables for displaying schema information | -| datafusion.catalog.location | NULL | Location scanned to load tables for `default` schema | -| datafusion.catalog.format | NULL | Type of `TableProvider` to use when loading `default` schema | -| datafusion.catalog.has_header | false | If the file has a header | -| datafusion.execution.batch_size | 8192 | Default batch size while creating new batches, it's especially useful for buffer-in-memory batches since creating tiny batches would result in too much metadata memory consumption | -| datafusion.execution.coalesce_batches | true | When set to true, record batches will be examined between each operator and small batches will be coalesced into larger batches. This is helpful when there are highly selective filters or joins that could produce tiny output batches. The target batch size is determined by the configuration setting | -| datafusion.execution.collect_statistics | false | Should DataFusion collect statistics after listing files | -| datafusion.execution.target_partitions | 0 | Number of partitions for query execution. Increasing partitions can increase concurrency. Defaults to the number of CPU cores on the system | -| datafusion.execution.time_zone | +00:00 | The default time zone Some functions, e.g. `EXTRACT(HOUR from SOME_TIME)`, shift the underlying datetime according to this time zone, and then extract the hour | -| datafusion.execution.parquet.enable_page_index | true | If true, reads the Parquet data page level metadata (the Page Index), if present, to reduce the I/O and number of rows decoded. | -| datafusion.execution.parquet.pruning | true | If true, the parquet reader attempts to skip entire row groups based on the predicate in the query and the metadata (min/max values) stored in the parquet file | -| datafusion.execution.parquet.skip_metadata | true | If true, the parquet reader skip the optional embedded metadata that may be in the file Schema. This setting can help avoid schema conflicts when querying multiple parquet files with schemas containing compatible types but different metadata | -| datafusion.execution.parquet.metadata_size_hint | NULL | If specified, the parquet reader will try and fetch the last `size_hint` bytes of the parquet file optimistically. If not specified, two reads are required: One read to fetch the 8-byte parquet footer and another to fetch the metadata length encoded in the footer | -| datafusion.execution.parquet.pushdown_filters | false | If true, filter expressions are be applied during the parquet decoding operation to reduce the number of rows decoded | -| datafusion.execution.parquet.reorder_filters | false | If true, filter expressions evaluated during the parquet decoding operation will be reordered heuristically to minimize the cost of evaluation. If false, the filters are applied in the same order as written in the query | -| datafusion.execution.parquet.data_pagesize_limit | 1048576 | Sets best effort maximum size of data page in bytes | -| datafusion.execution.parquet.write_batch_size | 1024 | Sets write_batch_size in bytes | -| datafusion.execution.parquet.writer_version | 1.0 | Sets parquet writer version valid values are "1.0" and "2.0" | -| datafusion.execution.parquet.compression | NULL | Sets default parquet compression codec Valid values are: uncompressed, snappy, gzip(level), lzo, brotli(level), lz4, zstd(level), and lz4_raw. These values are not case sensitive. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.dictionary_enabled | NULL | Sets if dictionary encoding is enabled. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.dictionary_page_size_limit | 1048576 | Sets best effort maximum dictionary page size, in bytes | -| datafusion.execution.parquet.statistics_enabled | NULL | Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.max_statistics_size | NULL | Sets max statistics size for any column. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.max_row_group_size | 1048576 | Sets maximum number of rows in a row group | -| datafusion.execution.parquet.created_by | datafusion version 31.0.0 | Sets "created by" property | -| datafusion.execution.parquet.column_index_truncate_length | NULL | Sets column index trucate length | -| datafusion.execution.parquet.data_page_row_count_limit | 18446744073709551615 | Sets best effort maximum number of rows in data page | -| datafusion.execution.parquet.encoding | NULL | Sets default encoding for any column Valid values are: plain, plain_dictionary, rle, bit_packed, delta_binary_packed, delta_length_byte_array, delta_byte_array, rle_dictionary, and byte_stream_split. These values are not case sensitive. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.bloom_filter_enabled | false | Sets if bloom filter is enabled for any column | -| datafusion.execution.parquet.bloom_filter_fpp | NULL | Sets bloom filter false positive probability. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.bloom_filter_ndv | NULL | Sets bloom filter number of distinct values. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.allow_single_file_parallelism | false | Controls whether DataFusion will attempt to speed up writing large parquet files by first writing multiple smaller files and then stitching them together into a single large file. This will result in faster write speeds, but higher memory usage. Also currently unsupported are bloom filters and column indexes when single_file_parallelism is enabled. | -| datafusion.execution.aggregate.scalar_update_factor | 10 | Specifies the threshold for using `ScalarValue`s to update accumulators during high-cardinality aggregations for each input batch. The aggregation is considered high-cardinality if the number of affected groups is greater than or equal to `batch_size / scalar_update_factor`. In such cases, `ScalarValue`s are utilized for updating accumulators, rather than the default batch-slice approach. This can lead to performance improvements. By adjusting the `scalar_update_factor`, you can balance the trade-off between more efficient accumulator updates and the number of groups affected. | -| datafusion.execution.planning_concurrency | 0 | Fan-out during initial physical planning. This is mostly use to plan `UNION` children in parallel. Defaults to the number of CPU cores on the system | -| datafusion.execution.sort_spill_reservation_bytes | 10485760 | Specifies the reserved memory for each spillable sort operation to facilitate an in-memory merge. When a sort operation spills to disk, the in-memory data must be sorted and merged before being written to a file. This setting reserves a specific amount of memory for that in-memory sort/merge process. Note: This setting is irrelevant if the sort operation cannot spill (i.e., if there's no `DiskManager` configured). | -| datafusion.execution.sort_in_place_threshold_bytes | 1048576 | When sorting, below what size should data be concatenated and sorted in a single RecordBatch rather than sorted in batches and merged. | -| datafusion.execution.meta_fetch_concurrency | 32 | Number of files to read in parallel when inferring schema and statistics | -| datafusion.optimizer.enable_round_robin_repartition | true | When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores | -| datafusion.optimizer.enable_topk_aggregation | true | When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible | -| datafusion.optimizer.filter_null_join_keys | false | When set to true, the optimizer will insert filters before a join between a nullable and non-nullable column to filter out nulls on the nullable side. This filter can add additional overhead when the file format does not fully support predicate push down. | -| datafusion.optimizer.repartition_aggregations | true | Should DataFusion repartition data using the aggregate keys to execute aggregates in parallel using the provided `target_partitions` level | -| datafusion.optimizer.repartition_file_min_size | 10485760 | Minimum total files size in bytes to perform file scan repartitioning. | -| datafusion.optimizer.repartition_joins | true | Should DataFusion repartition data using the join keys to execute joins in parallel using the provided `target_partitions` level | -| datafusion.optimizer.allow_symmetric_joins_without_pruning | true | Should DataFusion allow symmetric hash joins for unbounded data sources even when its inputs do not have any ordering or filtering If the flag is not enabled, the SymmetricHashJoin operator will be unable to prune its internal buffers, resulting in certain join types - such as Full, Left, LeftAnti, LeftSemi, Right, RightAnti, and RightSemi - being produced only at the end of the execution. This is not typical in stream processing. Additionally, without proper design for long runner execution, all types of joins may encounter out-of-memory errors. | -| datafusion.optimizer.repartition_file_scans | true | When set to `true`, file groups will be repartitioned to achieve maximum parallelism. Currently Parquet and CSV formats are supported. If set to `true`, all files will be repartitioned evenly (i.e., a single large file might be partitioned into smaller chunks) for parallel scanning. If set to `false`, different files will be read in parallel, but repartitioning won't happen within a single file. | -| datafusion.optimizer.repartition_windows | true | Should DataFusion repartition data using the partitions keys to execute window functions in parallel using the provided `target_partitions` level | -| datafusion.optimizer.repartition_sorts | true | Should DataFusion execute sorts in a per-partition fashion and merge afterwards instead of coalescing first and sorting globally. With this flag is enabled, plans in the form below `text "SortExec: [a@0 ASC]", " CoalescePartitionsExec", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", ` would turn into the plan below which performs better in multithreaded environments `text "SortPreservingMergeExec: [a@0 ASC]", " SortExec: [a@0 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", ` | -| datafusion.optimizer.bounded_order_preserving_variants | false | When true, DataFusion will opportunistically remove sorts by replacing `RepartitionExec` with `SortPreservingRepartitionExec`, and `CoalescePartitionsExec` with `SortPreservingMergeExec`, even when the query is bounded. | -| datafusion.optimizer.skip_failed_rules | false | When set to true, the logical plan optimizer will produce warning messages if any optimization rules produce errors and then proceed to the next rule. When set to false, any rules that produce errors will cause the query to fail | -| datafusion.optimizer.max_passes | 3 | Number of times that the optimizer will attempt to optimize the plan | -| datafusion.optimizer.top_down_join_key_reordering | true | When set to true, the physical plan optimizer will run a top down process to reorder the join keys | -| datafusion.optimizer.prefer_hash_join | true | When set to true, the physical plan optimizer will prefer HashJoin over SortMergeJoin. HashJoin can work more efficiently than SortMergeJoin but consumes more memory | -| datafusion.optimizer.hash_join_single_partition_threshold | 1048576 | The maximum estimated size in bytes for one input side of a HashJoin will be collected into a single partition | -| datafusion.explain.logical_plan_only | false | When set to true, the explain statement will only print logical plans | -| datafusion.explain.physical_plan_only | false | When set to true, the explain statement will only print physical plans | -| datafusion.explain.show_statistics | false | When set to true, the explain statement will print operator statistics for physical plans | -| datafusion.sql_parser.parse_float_as_decimal | false | When set to true, SQL parser will parse float as decimal type | -| datafusion.sql_parser.enable_ident_normalization | true | When set to true, SQL parser will normalize ident (convert ident to lowercase when not quoted) | -| datafusion.sql_parser.dialect | generic | Configure the SQL dialect used by DataFusion's parser; supported values include: Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, and Ansi. | +| key | default | description | +| ----------------------------------------------------------------------- | ------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| datafusion.catalog.create_default_catalog_and_schema | true | Whether the default catalog and schema should be created automatically. | +| datafusion.catalog.default_catalog | datafusion | The default catalog name - this impacts what SQL queries use if not specified | +| datafusion.catalog.default_schema | public | The default schema name - this impacts what SQL queries use if not specified | +| datafusion.catalog.information_schema | false | Should DataFusion provide access to `information_schema` virtual tables for displaying schema information | +| datafusion.catalog.location | NULL | Location scanned to load tables for `default` schema | +| datafusion.catalog.format | NULL | Type of `TableProvider` to use when loading `default` schema | +| datafusion.catalog.has_header | false | If the file has a header | +| datafusion.execution.batch_size | 8192 | Default batch size while creating new batches, it's especially useful for buffer-in-memory batches since creating tiny batches would result in too much metadata memory consumption | +| datafusion.execution.coalesce_batches | true | When set to true, record batches will be examined between each operator and small batches will be coalesced into larger batches. This is helpful when there are highly selective filters or joins that could produce tiny output batches. The target batch size is determined by the configuration setting | +| datafusion.execution.collect_statistics | false | Should DataFusion collect statistics after listing files | +| datafusion.execution.target_partitions | 0 | Number of partitions for query execution. Increasing partitions can increase concurrency. Defaults to the number of CPU cores on the system | +| datafusion.execution.time_zone | +00:00 | The default time zone Some functions, e.g. `EXTRACT(HOUR from SOME_TIME)`, shift the underlying datetime according to this time zone, and then extract the hour | +| datafusion.execution.parquet.enable_page_index | true | If true, reads the Parquet data page level metadata (the Page Index), if present, to reduce the I/O and number of rows decoded. | +| datafusion.execution.parquet.pruning | true | If true, the parquet reader attempts to skip entire row groups based on the predicate in the query and the metadata (min/max values) stored in the parquet file | +| datafusion.execution.parquet.skip_metadata | true | If true, the parquet reader skip the optional embedded metadata that may be in the file Schema. This setting can help avoid schema conflicts when querying multiple parquet files with schemas containing compatible types but different metadata | +| datafusion.execution.parquet.metadata_size_hint | NULL | If specified, the parquet reader will try and fetch the last `size_hint` bytes of the parquet file optimistically. If not specified, two reads are required: One read to fetch the 8-byte parquet footer and another to fetch the metadata length encoded in the footer | +| datafusion.execution.parquet.pushdown_filters | false | If true, filter expressions are be applied during the parquet decoding operation to reduce the number of rows decoded | +| datafusion.execution.parquet.reorder_filters | false | If true, filter expressions evaluated during the parquet decoding operation will be reordered heuristically to minimize the cost of evaluation. If false, the filters are applied in the same order as written in the query | +| datafusion.execution.parquet.data_pagesize_limit | 1048576 | Sets best effort maximum size of data page in bytes | +| datafusion.execution.parquet.write_batch_size | 1024 | Sets write_batch_size in bytes | +| datafusion.execution.parquet.writer_version | 1.0 | Sets parquet writer version valid values are "1.0" and "2.0" | +| datafusion.execution.parquet.compression | zstd(3) | Sets default parquet compression codec Valid values are: uncompressed, snappy, gzip(level), lzo, brotli(level), lz4, zstd(level), and lz4_raw. These values are not case sensitive. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.dictionary_enabled | NULL | Sets if dictionary encoding is enabled. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.dictionary_page_size_limit | 1048576 | Sets best effort maximum dictionary page size, in bytes | +| datafusion.execution.parquet.statistics_enabled | NULL | Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.max_statistics_size | NULL | Sets max statistics size for any column. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.max_row_group_size | 1048576 | Sets maximum number of rows in a row group | +| datafusion.execution.parquet.created_by | datafusion version 34.0.0 | Sets "created by" property | +| datafusion.execution.parquet.column_index_truncate_length | NULL | Sets column index truncate length | +| datafusion.execution.parquet.data_page_row_count_limit | 18446744073709551615 | Sets best effort maximum number of rows in data page | +| datafusion.execution.parquet.encoding | NULL | Sets default encoding for any column Valid values are: plain, plain_dictionary, rle, bit_packed, delta_binary_packed, delta_length_byte_array, delta_byte_array, rle_dictionary, and byte_stream_split. These values are not case sensitive. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.bloom_filter_enabled | false | Sets if bloom filter is enabled for any column | +| datafusion.execution.parquet.bloom_filter_fpp | NULL | Sets bloom filter false positive probability. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.bloom_filter_ndv | NULL | Sets bloom filter number of distinct values. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.allow_single_file_parallelism | true | Controls whether DataFusion will attempt to speed up writing parquet files by serializing them in parallel. Each column in each row group in each output file are serialized in parallel leveraging a maximum possible core count of n_files*n_row_groups*n_columns. | +| datafusion.execution.parquet.maximum_parallel_row_group_writers | 1 | By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. | +| datafusion.execution.parquet.maximum_buffered_record_batches_per_stream | 2 | By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. | +| datafusion.execution.aggregate.scalar_update_factor | 10 | Specifies the threshold for using `ScalarValue`s to update accumulators during high-cardinality aggregations for each input batch. The aggregation is considered high-cardinality if the number of affected groups is greater than or equal to `batch_size / scalar_update_factor`. In such cases, `ScalarValue`s are utilized for updating accumulators, rather than the default batch-slice approach. This can lead to performance improvements. By adjusting the `scalar_update_factor`, you can balance the trade-off between more efficient accumulator updates and the number of groups affected. | +| datafusion.execution.planning_concurrency | 0 | Fan-out during initial physical planning. This is mostly use to plan `UNION` children in parallel. Defaults to the number of CPU cores on the system | +| datafusion.execution.sort_spill_reservation_bytes | 10485760 | Specifies the reserved memory for each spillable sort operation to facilitate an in-memory merge. When a sort operation spills to disk, the in-memory data must be sorted and merged before being written to a file. This setting reserves a specific amount of memory for that in-memory sort/merge process. Note: This setting is irrelevant if the sort operation cannot spill (i.e., if there's no `DiskManager` configured). | +| datafusion.execution.sort_in_place_threshold_bytes | 1048576 | When sorting, below what size should data be concatenated and sorted in a single RecordBatch rather than sorted in batches and merged. | +| datafusion.execution.meta_fetch_concurrency | 32 | Number of files to read in parallel when inferring schema and statistics | +| datafusion.execution.minimum_parallel_output_files | 4 | Guarantees a minimum level of output files running in parallel. RecordBatches will be distributed in round robin fashion to each parallel writer. Each writer is closed and a new file opened once soft_max_rows_per_output_file is reached. | +| datafusion.execution.soft_max_rows_per_output_file | 50000000 | Target number of rows in output files when writing multiple. This is a soft max, so it can be exceeded slightly. There also will be one file smaller than the limit if the total number of rows written is not roughly divisible by the soft max | +| datafusion.execution.max_buffered_batches_per_output_file | 2 | This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption | +| datafusion.execution.listing_table_ignore_subdirectory | true | Should sub directories be ignored when scanning directories for data files. Defaults to true (ignores subdirectories), consistent with Hive. Note that this setting does not affect reading partitioned tables (e.g. `/table/year=2021/month=01/data.parquet`). | +| datafusion.optimizer.enable_distinct_aggregation_soft_limit | true | When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. | +| datafusion.optimizer.enable_round_robin_repartition | true | When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores | +| datafusion.optimizer.enable_topk_aggregation | true | When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible | +| datafusion.optimizer.filter_null_join_keys | false | When set to true, the optimizer will insert filters before a join between a nullable and non-nullable column to filter out nulls on the nullable side. This filter can add additional overhead when the file format does not fully support predicate push down. | +| datafusion.optimizer.repartition_aggregations | true | Should DataFusion repartition data using the aggregate keys to execute aggregates in parallel using the provided `target_partitions` level | +| datafusion.optimizer.repartition_file_min_size | 10485760 | Minimum total files size in bytes to perform file scan repartitioning. | +| datafusion.optimizer.repartition_joins | true | Should DataFusion repartition data using the join keys to execute joins in parallel using the provided `target_partitions` level | +| datafusion.optimizer.allow_symmetric_joins_without_pruning | true | Should DataFusion allow symmetric hash joins for unbounded data sources even when its inputs do not have any ordering or filtering If the flag is not enabled, the SymmetricHashJoin operator will be unable to prune its internal buffers, resulting in certain join types - such as Full, Left, LeftAnti, LeftSemi, Right, RightAnti, and RightSemi - being produced only at the end of the execution. This is not typical in stream processing. Additionally, without proper design for long runner execution, all types of joins may encounter out-of-memory errors. | +| datafusion.optimizer.repartition_file_scans | true | When set to `true`, file groups will be repartitioned to achieve maximum parallelism. Currently Parquet and CSV formats are supported. If set to `true`, all files will be repartitioned evenly (i.e., a single large file might be partitioned into smaller chunks) for parallel scanning. If set to `false`, different files will be read in parallel, but repartitioning won't happen within a single file. | +| datafusion.optimizer.repartition_windows | true | Should DataFusion repartition data using the partitions keys to execute window functions in parallel using the provided `target_partitions` level | +| datafusion.optimizer.repartition_sorts | true | Should DataFusion execute sorts in a per-partition fashion and merge afterwards instead of coalescing first and sorting globally. With this flag is enabled, plans in the form below `text "SortExec: [a@0 ASC]", " CoalescePartitionsExec", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", ` would turn into the plan below which performs better in multithreaded environments `text "SortPreservingMergeExec: [a@0 ASC]", " SortExec: [a@0 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", ` | +| datafusion.optimizer.prefer_existing_sort | false | When true, DataFusion will opportunistically remove sorts when the data is already sorted, (i.e. setting `preserve_order` to true on `RepartitionExec` and using `SortPreservingMergeExec`) When false, DataFusion will maximize plan parallelism using `RepartitionExec` even if this requires subsequently resorting data using a `SortExec`. | +| datafusion.optimizer.skip_failed_rules | false | When set to true, the logical plan optimizer will produce warning messages if any optimization rules produce errors and then proceed to the next rule. When set to false, any rules that produce errors will cause the query to fail | +| datafusion.optimizer.max_passes | 3 | Number of times that the optimizer will attempt to optimize the plan | +| datafusion.optimizer.top_down_join_key_reordering | true | When set to true, the physical plan optimizer will run a top down process to reorder the join keys | +| datafusion.optimizer.prefer_hash_join | true | When set to true, the physical plan optimizer will prefer HashJoin over SortMergeJoin. HashJoin can work more efficiently than SortMergeJoin but consumes more memory | +| datafusion.optimizer.hash_join_single_partition_threshold | 1048576 | The maximum estimated size in bytes for one input side of a HashJoin will be collected into a single partition | +| datafusion.optimizer.default_filter_selectivity | 20 | The default filter selectivity used by Filter Statistics when an exact selectivity cannot be determined. Valid values are between 0 (no selectivity) and 100 (all rows are selected). | +| datafusion.explain.logical_plan_only | false | When set to true, the explain statement will only print logical plans | +| datafusion.explain.physical_plan_only | false | When set to true, the explain statement will only print physical plans | +| datafusion.explain.show_statistics | false | When set to true, the explain statement will print operator statistics for physical plans | +| datafusion.sql_parser.parse_float_as_decimal | false | When set to true, SQL parser will parse float as decimal type | +| datafusion.sql_parser.enable_ident_normalization | true | When set to true, SQL parser will normalize ident (convert ident to lowercase when not quoted) | +| datafusion.sql_parser.dialect | generic | Configure the SQL dialect used by DataFusion's parser; supported values include: Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, and Ansi. | diff --git a/docs/source/user-guide/dataframe.md b/docs/source/user-guide/dataframe.md index 4484b2c51019..c0210200a246 100644 --- a/docs/source/user-guide/dataframe.md +++ b/docs/source/user-guide/dataframe.md @@ -95,6 +95,7 @@ These methods execute the logical plan represented by the DataFrame and either c | write_csv | Execute this DataFrame and write the results to disk in CSV format. | | write_json | Execute this DataFrame and write the results to disk in JSON format. | | write_parquet | Execute this DataFrame and write the results to disk in Parquet format. | +| write_table | Execute this DataFrame and write the results via the insert_into method of the registered TableProvider | ## Other DataFrame Methods diff --git a/docs/source/user-guide/example-usage.md b/docs/source/user-guide/example-usage.md index adaf780558bc..a7557f9b0bc3 100644 --- a/docs/source/user-guide/example-usage.md +++ b/docs/source/user-guide/example-usage.md @@ -19,9 +19,9 @@ # Example Usage -In this example some simple processing is performed on the [`example.csv`](../../../datafusion/core/tests/data/example.csv) file. +In this example some simple processing is performed on the [`example.csv`](https://github.com/apache/arrow-datafusion/blob/main/datafusion/core/tests/data/example.csv) file. -Even [`more code examples`](../../../datafusion-examples) attached to the project +Even [`more code examples`](https://github.com/apache/arrow-datafusion/tree/main/datafusion-examples) attached to the project. ## Update `Cargo.toml` @@ -187,10 +187,6 @@ DataFusion is designed to be extensible at all points. To that end, you can prov - [x] User Defined `LogicalPlan` nodes - [x] User Defined `ExecutionPlan` nodes -## Rust Version Compatibility - -This crate is tested with the latest stable version of Rust. We do not currently test against other, older versions of the Rust compiler. - ## Optimized Configuration For an optimized build several steps are required. First, use the below in your `Cargo.toml`. It is diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index dbd8c814b46e..b8689e556741 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -22,60 +22,94 @@ DataFrame methods such as `select` and `filter` accept one or more logical expressions and there are many functions available for creating logical expressions. These are documented below. -Expressions can be chained together using a fluent-style API: +:::{tip} +Most functions and methods may receive and return an `Expr`, which can be chained together using a fluent-style API: ```rust // create the expression `(a > 6) AND (b < 7)` col("a").gt(lit(6)).and(col("b").lt(lit(7))) ``` +::: + ## Identifiers -| Function | Notes | -| -------- | -------------------------------------------- | -| col | Reference a column in a dataframe `col("a")` | +| Syntax | Description | +| ---------- | -------------------------------------------- | +| col(ident) | Reference a column in a dataframe `col("a")` | + +:::{note} +ident +: A type which implement `Into` trait +::: ## Literal Values -| Function | Notes | -| -------- | -------------------------------------------------- | -| lit | Literal value such as `lit(123)` or `lit("hello")` | +| Syntax | Description | +| ---------- | -------------------------------------------------- | +| lit(value) | Literal value such as `lit(123)` or `lit("hello")` | + +:::{note} +value +: A type which implement `Literal` +::: ## Boolean Expressions -| Function | Notes | -| -------- | ----------------------------------------- | -| and | `and(expr1, expr2)` or `expr1.and(expr2)` | -| or | `or(expr1, expr2)` or `expr1.or(expr2)` | -| not | `not(expr)` or `expr.not()` | +| Syntax | Description | +| ------------------- | ----------- | +| and(x, y), x.and(y) | Logical AND | +| or(x, y), x.or(y) | Logical OR | +| !x, not(x), x.not() | Logical NOT | + +:::{note} +`!` is a bitwise or logical complement operator in Rust, but it only works as a logical NOT in expression API. +::: + +:::{note} +Since `&&` and `||` are existed as logical operators in Rust, but those are not overloadable and not works with expression API. +::: -## Bitwise expressions +## Bitwise Expressions -| Function | Notes | -| ------------------- | ------------------------------------------------------------------------- | -| bitwise_and | `bitwise_and(expr1, expr2)` or `expr1.bitwise_and(expr2)` | -| bitwise_or | `bitwise_or(expr1, expr2)` or `expr1.bitwise_or(expr2)` | -| bitwise_xor | `bitwise_xor(expr1, expr2)` or `expr1.bitwise_xor(expr2)` | -| bitwise_shift_right | `bitwise_shift_right(expr1, expr2)` or `expr1.bitwise_shift_right(expr2)` | -| bitwise_shift_left | `bitwise_shift_left(expr1, expr2)` or `expr1.bitwise_shift_left(expr2)` | +| Syntax | Description | +| ------------------------------------------- | ----------- | +| x & y, bitwise_and(x, y), x.bitand(y) | AND | +| x \| y, bitwise_or(x, y), x.bitor(y) | OR | +| x ^ y, bitwise_xor(x, y), x.bitxor(y) | XOR | +| x << y, bitwise_shift_left(x, y), x.shl(y) | Left shift | +| x >> y, bitwise_shift_right(x, y), x.shr(y) | Right shift | ## Comparison Expressions -| Function | Notes | -| -------- | --------------------- | -| eq | `expr1.eq(expr2)` | -| gt | `expr1.gt(expr2)` | -| gt_eq | `expr1.gt_eq(expr2)` | -| lt | `expr1.lt(expr2)` | -| lt_eq | `expr1.lt_eq(expr2)` | -| not_eq | `expr1.not_eq(expr2)` | +| Syntax | Description | +| ----------- | --------------------- | +| x.eq(y) | Equal | +| x.not_eq(y) | Not Equal | +| x.gt(y) | Greater Than | +| x.gt_eq(y) | Greater Than or Equal | +| x.lt(y) | Less Than | +| x.lt_eq(y) | Less Than or Equal | + +:::{note} +Comparison operators (`<`, `<=`, `==`, `>=`, `>`) could be overloaded by the `PartialOrd` and `PartialEq` trait in Rust, +but these operators always return a `bool` which makes them not work with the expression API. +::: + +## Arithmetic Expressions + +| Syntax | Description | +| ---------------- | -------------- | +| x + y, x.add(y) | Addition | +| x - y, x.sub(y) | Subtraction | +| x \* y, x.mul(y) | Multiplication | +| x / y, x.div(y) | Division | +| x % y, x.rem(y) | Remainder | +| -x, x.neg() | Negation | ## Math Functions -In addition to the math functions listed here, some Rust operators are implemented for expressions, allowing -expressions such as `col("a") + col("b")` to be used. - -| Function | Notes | +| Syntax | Description | | --------------------- | ------------------------------------------------- | | abs(x) | absolute value | | acos(x) | inverse cosine | @@ -114,11 +148,10 @@ expressions such as `col("a") + col("b")` to be used. | tanh(x) | hyperbolic tangent | | trunc(x) | truncate toward zero | -### Math functions usage notes: - +:::{note} Unlike to some databases the math functions in Datafusion works the same way as Rust math functions, avoiding failing on corner cases e.g -``` +```sql ❯ select log(-1), log(0), sqrt(-1); +----------------+---------------+-----------------+ | log(Int64(-1)) | log(Int64(0)) | sqrt(Int64(-1)) | @@ -127,27 +160,19 @@ Unlike to some databases the math functions in Datafusion works the same way as +----------------+---------------+-----------------+ ``` -## Bitwise Operators - -| Operator | Notes | -| -------- | ----------------------------------------------- | -| & | Bitwise AND => `(expr1 & expr2)` | -| | | Bitwise OR => (expr1 | expr2) | -| # | Bitwise XOR => `(expr1 # expr2)` | -| << | Bitwise left shift => `(expr1 << expr2)` | -| >> | Bitwise right shift => `(expr1 << expr2)` | +::: ## Conditional Expressions -| Function | Notes | -| -------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| coalesce | Returns the first of its arguments that is not null. Null is returned only if all arguments are null. It is often used to substitute a default value for null values when data is retrieved for display. | -| case | CASE expression. The expression may chain multiple `when` expressions and end with an `end` or `otherwise` expression. Example:
case(col("a") % lit(3))
    .when(lit(0), lit("A"))
    .when(lit(1), lit("B"))
    .when(lit(2), lit("C"))
    .end()
or, end with `otherwise` to match any other conditions:
case(col("b").gt(lit(100)))
    .when(lit(true), lit("value > 100"))
    .otherwise(lit("value <= 100"))
| -| nullif | Returns a null value if `value1` equals `value2`; otherwise it returns `value1`. This can be used to perform the inverse operation of the `coalesce` expression. | +| Syntax | Description | +| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| coalesce([value, ...]) | Returns the first of its arguments that is not null. Null is returned only if all arguments are null. It is often used to substitute a default value for null values when data is retrieved for display. | +| case(expr)
    .when(expr)
    .end(),
case(expr)
    .when(expr)
    .otherwise(expr) | CASE expression. The expression may chain multiple `when` expressions and end with an `end` or `otherwise` expression. Example:
case(col("a") % lit(3))
    .when(lit(0), lit("A"))
    .when(lit(1), lit("B"))
    .when(lit(2), lit("C"))
    .end()
or, end with `otherwise` to match any other conditions:
case(col("b").gt(lit(100)))
    .when(lit(true), lit("value > 100"))
    .otherwise(lit("value <= 100"))
| +| nullif(value1, value2) | Returns a null value if `value1` equals `value2`; otherwise it returns `value1`. This can be used to perform the inverse operation of the `coalesce` expression. | ## String Expressions -| Function | Notes | +| Syntax | Description | | ---------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | ascii(character) | Returns a numeric representation of the character (`character`). Example: `ascii('a') -> 97` | | bit_length(text) | Returns the length of the string (`text`) in bits. Example: `bit_length('spider') -> 48` | @@ -182,7 +207,7 @@ Unlike to some databases the math functions in Datafusion works the same way as ## Array Expressions -| Function | Notes | +| Syntax | Description | | ------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | | array_append(array, element) | Appends an element to the end of an array. `array_append([1, 2, 3], 4) -> [1, 2, 3, 4]` | | array_concat(array[, ..., array_n]) | Concatenates arrays. `array_concat([1, 2, 3], [4, 5, 6]) -> [1, 2, 3, 4, 5, 6]` | @@ -190,10 +215,12 @@ Unlike to some databases the math functions in Datafusion works the same way as | array_has_all(array, sub-array) | Returns true if all elements of sub-array exist in array `array_has_all([1,2,3], [1,3]) -> true` | | array_has_any(array, sub-array) | Returns true if any elements exist in both arrays `array_has_any([1,2,3], [1,4]) -> true` | | array_dims(array) | Returns an array of the array's dimensions. `array_dims([[1, 2, 3], [4, 5, 6]]) -> [2, 3]` | +| array_distinct(array) | Returns distinct values from the array after removing duplicates. `array_distinct([1, 3, 2, 3, 1, 2, 4]) -> [1, 2, 3, 4]` | | array_element(array, index) | Extracts the element with the index n from the array `array_element([1, 2, 3, 4], 3) -> 3` | | flatten(array) | Converts an array of arrays to a flat array `flatten([[1], [2, 3], [4, 5, 6]]) -> [1, 2, 3, 4, 5, 6]` | | array_length(array, dimension) | Returns the length of the array dimension. `array_length([1, 2, 3, 4, 5]) -> 5` | | array_ndims(array) | Returns the number of dimensions of the array. `array_ndims([[1, 2, 3], [4, 5, 6]]) -> 2` | +| array_pop_front(array) | Returns the array without the first element. `array_pop_front([1, 2, 3]) -> [2, 3]` | | array_pop_back(array) | Returns the array without the last element. `array_pop_back([1, 2, 3]) -> [1, 2]` | | array_position(array, element) | Searches for an element in the array, returns first occurrence. `array_position([1, 2, 2, 3, 4], 2) -> 2` | | array_positions(array, element) | Searches for an element in the array, returns all occurrences. `array_positions([1, 2, 2, 3, 4], 2) -> [2, 3]` | @@ -207,20 +234,24 @@ Unlike to some databases the math functions in Datafusion works the same way as | array_replace_all(array, from, to) | Replaces all occurrences of the specified element with another specified element. `array_replace_all([1, 2, 2, 3, 2, 1, 4], 2, 5) -> [1, 5, 5, 3, 5, 1, 4]` | | array_slice(array, index) | Returns a slice of the array. `array_slice([1, 2, 3, 4, 5, 6, 7, 8], 3, 6) -> [3, 4, 5, 6]` | | array_to_string(array, delimiter) | Converts each element to its text representation. `array_to_string([1, 2, 3, 4], ',') -> 1,2,3,4` | +| array_intersect(array1, array2) | Returns an array of the elements in the intersection of array1 and array2. `array_intersect([1, 2, 3, 4], [5, 6, 3, 4]) -> [3, 4]` | +| array_union(array1, array2) | Returns an array of the elements in the union of array1 and array2 without duplicates. `array_union([1, 2, 3, 4], [5, 6, 3, 4]) -> [1, 2, 3, 4, 5, 6]` | +| array_except(array1, array2) | Returns an array of the elements that appear in the first array but not in the second. `array_except([1, 2, 3, 4], [5, 6, 3, 4]) -> [3, 4]` | | cardinality(array) | Returns the total number of elements in the array. `cardinality([[1, 2, 3], [4, 5, 6]]) -> 6` | | make_array(value1, [value2 [, ...]]) | Returns an Arrow array using the specified input expressions. `make_array(1, 2, 3) -> [1, 2, 3]` | +| range(start [, stop, step]) | Returns an Arrow array between start and stop with step. `SELECT range(2, 10, 3) -> [2, 5, 8]` | | trim_array(array, n) | Deprecated | ## Regular Expressions -| Function | Notes | +| Syntax | Description | | -------------- | ----------------------------------------------------------------------------- | | regexp_match | Matches a regular expression against a string and returns matched substrings. | | regexp_replace | Replaces strings that match a regular expression | ## Temporal Expressions -| Function | Notes | +| Syntax | Description | | -------------------- | ------------------------------------------------------ | | date_part | Extracts a subfield from the date. | | date_trunc | Truncates the date to a specified level of precision. | @@ -233,7 +264,7 @@ Unlike to some databases the math functions in Datafusion works the same way as ## Other Expressions -| Function | Notes | +| Syntax | Description | | ---------------------------- | ---------------------------------------------------------------------------------------------------------- | | array([value1, ...]) | Returns an array of fixed size with each argument (`[value1, ...]`) on it. | | in_list(expr, list, negated) | Returns `true` if (`expr`) belongs or not belongs (`negated`) to a list (`list`), otherwise returns false. | @@ -246,7 +277,7 @@ Unlike to some databases the math functions in Datafusion works the same way as ## Aggregate Functions -| Function | Notes | +| Syntax | Description | | ----------------------------------------------------------------- | --------------------------------------------------------------------------------------- | | avg(expr) | Сalculates the average value for `expr`. | | approx_distinct(expr) | Calculates an approximate count of the number of distinct values for `expr`. | @@ -270,7 +301,7 @@ Unlike to some databases the math functions in Datafusion works the same way as ## Subquery Expressions -| Function | Notes | +| Syntax | Description | | --------------- | --------------------------------------------------------------------------------------------- | | exists | Creates an `EXISTS` subquery expression | | in_subquery | `df1.filter(in_subquery(col("foo"), df2))?` is the equivalent of the SQL `WHERE foo IN ` | @@ -280,7 +311,7 @@ Unlike to some databases the math functions in Datafusion works the same way as ## User-Defined Function Expressions -| Function | Notes | +| Syntax | Description | | ----------- | ------------------------------------------------------------------------- | | create_udf | Creates a new UDF with a specific signature and specific return type. | | create_udaf | Creates a new UDAF with a specific signature, state type and return type. | diff --git a/docs/source/user-guide/introduction.md b/docs/source/user-guide/introduction.md index 272f534613cf..b737c3bab266 100644 --- a/docs/source/user-guide/introduction.md +++ b/docs/source/user-guide/introduction.md @@ -75,7 +75,7 @@ latency). Here are some example systems built using DataFusion: -- Specialized Analytical Database systems such as [CeresDB] and more general Apache Spark like system such a [Ballista]. +- Specialized Analytical Database systems such as [HoraeDB] and more general Apache Spark like system such a [Ballista]. - New query language engines such as [prql-query] and accelerators such as [VegaFusion] - Research platform for new Database Systems, such as [Flock] - SQL support to another library, such as [dask sql] @@ -96,7 +96,6 @@ Here are some active projects using DataFusion: - [Arroyo](https://github.com/ArroyoSystems/arroyo) Distributed stream processing engine in Rust - [Ballista](https://github.com/apache/arrow-ballista) Distributed SQL Query Engine -- [CeresDB](https://github.com/CeresDB/ceresdb) Distributed Time-Series Database - [CnosDB](https://github.com/cnosdb/cnosdb) Open Source Distributed Time Series Database - [Cube Store](https://github.com/cube-js/cube.js/tree/master/rust) - [Dask SQL](https://github.com/dask-contrib/dask-sql) Distributed SQL query engine in Python @@ -104,8 +103,11 @@ Here are some active projects using DataFusion: - [delta-rs](https://github.com/delta-io/delta-rs) Native Rust implementation of Delta Lake - [GreptimeDB](https://github.com/GreptimeTeam/greptimedb) Open Source & Cloud Native Distributed Time Series Database - [GlareDB](https://github.com/GlareDB/glaredb) Fast SQL database for querying and analyzing distributed data. +- [HoraeDB](https://github.com/apache/incubator-horaedb) Distributed Time-Series Database - [InfluxDB IOx](https://github.com/influxdata/influxdb_iox) Time Series Database - [Kamu](https://github.com/kamu-data/kamu-cli/) Planet-scale streaming data pipeline +- [LakeSoul](https://github.com/lakesoul-io/LakeSoul) Open source LakeHouse framework with native IO in Rust. +- [Lance](https://github.com/lancedb/lance) Modern columnar data format for ML - [Parseable](https://github.com/parseablehq/parseable) Log storage and observability platform - [qv](https://github.com/timvw/qv) Quickly view your data - [bdt](https://github.com/andygrove/bdt) Boring Data Tool @@ -126,7 +128,6 @@ Here are some less active projects that used DataFusion: [ballista]: https://github.com/apache/arrow-ballista [blaze]: https://github.com/blaze-init/blaze -[ceresdb]: https://github.com/CeresDB/ceresdb [cloudfuse buzz]: https://github.com/cloudfuse-io/buzz-rust [cnosdb]: https://github.com/cnosdb/cnosdb [cube store]: https://github.com/cube-js/cube.js/tree/master/rust @@ -136,6 +137,7 @@ Here are some less active projects that used DataFusion: [flock]: https://github.com/flock-lab/flock [kamu]: https://github.com/kamu-data/kamu-cli [greptime db]: https://github.com/GreptimeTeam/greptimedb +[horaedb]: https://github.com/apache/incubator-horaedb [influxdb iox]: https://github.com/influxdata/influxdb_iox [parseable]: https://github.com/parseablehq/parseable [prql-query]: https://github.com/prql/prql-query diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index d5717b9c2130..629a5f6ecb88 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -635,6 +635,10 @@ nullif(expression1, expression2) - [trim](#trim) - [upper](#upper) - [uuid](#uuid) +- [overlay](#overlay) +- [levenshtein](#levenshtein) +- [substr_index](#substr_index) +- [find_in_set](#find_in_set) ### `ascii` @@ -1120,6 +1124,67 @@ Returns UUID v4 string value which is unique per row. uuid() ``` +### `overlay` + +Returns the string which is replaced by another string from the specified position and specified count length. +For example, `overlay('Txxxxas' placing 'hom' from 2 for 4) → Thomas` + +``` +overlay(str PLACING substr FROM pos [FOR count]) +``` + +#### Arguments + +- **str**: String expression to operate on. +- **substr**: the string to replace part of str. +- **pos**: the start position to replace of str. +- **count**: the count of characters to be replaced from start position of str. If not specified, will use substr length instead. + +### `levenshtein` + +Returns the Levenshtein distance between the two given strings. +For example, `levenshtein('kitten', 'sitting') = 3` + +``` +levenshtein(str1, str2) +``` + +#### Arguments + +- **str1**: String expression to compute Levenshtein distance with str2. +- **str2**: String expression to compute Levenshtein distance with str1. + +### `substr_index` + +Returns the substring from str before count occurrences of the delimiter delim. +If count is positive, everything to the left of the final delimiter (counting from the left) is returned. +If count is negative, everything to the right of the final delimiter (counting from the right) is returned. +For example, `substr_index('www.apache.org', '.', 1) = www`, `substr_index('www.apache.org', '.', -1) = org` + +``` +substr_index(str, delim, count) +``` + +#### Arguments + +- **str**: String expression to operate on. +- **delim**: the string to find in str to split str. +- **count**: The number of times to search for the delimiter. Can be both a positive or negative number. + +### `find_in_set` + +Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings. +For example, `find_in_set('b', 'a,b,c,d') = 2` + +``` +find_in_set(str, strlist) +``` + +#### Arguments + +- **str**: String expression to find in strlist. +- **strlist**: A string list is a string composed of substrings separated by , characters. + ## Binary String Functions - [decode](#decode) @@ -1215,9 +1280,11 @@ regexp_replace(str, regexp, replacement, flags) - [datepart](#datepart) - [extract](#extract) - [to_timestamp](#to_timestamp) +- [today](#today) - [to_timestamp_millis](#to_timestamp_millis) - [to_timestamp_micros](#to_timestamp_micros) - [to_timestamp_seconds](#to_timestamp_seconds) +- [to_timestamp_nanos](#to_timestamp_nanos) - [from_unixtime](#from_unixtime) ### `now` @@ -1242,6 +1309,14 @@ no matter when in the query plan the function executes. current_date() ``` +#### Aliases + +- today + +### `today` + +_Alias of [current_date](#current_date)._ + ### `current_time` Returns the current UTC time. @@ -1335,6 +1410,7 @@ date_part(part, expression) The following date parts are supported: - year + - quarter _(emits value in inclusive range [1, 4] based on which quartile of the year the date is in)_ - month - week _(week of the year)_ - day _(day of the month)_ @@ -1346,6 +1422,7 @@ date_part(part, expression) - nanosecond - dow _(day of the week)_ - doy _(day of the year)_ + - epoch _(seconds since Unix epoch)_ - **expression**: Time expression to operate on. Can be a constant, column, or function. @@ -1373,6 +1450,7 @@ extract(field FROM source) The following date fields are supported: - year + - quarter _(emits value in inclusive range [1, 4] based on which quartile of the year the date is in)_ - month - week _(week of the year)_ - day _(day of the month)_ @@ -1384,16 +1462,21 @@ extract(field FROM source) - nanosecond - dow _(day of the week)_ - doy _(day of the year)_ + - epoch _(seconds since Unix epoch)_ - **source**: Source time expression to operate on. Can be a constant, column, or function. ### `to_timestamp` -Converts a value to RFC3339 nanosecond timestamp format (`YYYY-MM-DDT00:00:00.000000000Z`). -Supports timestamp, integer, and unsigned integer types as input. -Integers and unsigned integers are parsed as Unix nanosecond timestamps and -return the corresponding RFC3339 nanosecond timestamp. +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00Z`). +Supports strings, integer, unsigned integer, and double types as input. +Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') +Integers, unsigned integers, and doubles are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`) +return the corresponding timestamp. + +Note: `to_timestamp` returns `Timestamp(Nanosecond)`. The supported range for integer input is between `-9223372037` and `9223372036`. +Supported range for string input is between `1677-09-21T00:12:44.0` and `2262-04-11T23:47:16.0`. Please use `to_timestamp_seconds` for the input outside of supported bounds. ``` to_timestamp(expression) @@ -1406,10 +1489,11 @@ to_timestamp(expression) ### `to_timestamp_millis` -Converts a value to RFC3339 millisecond timestamp format (`YYYY-MM-DDT00:00:00.000Z`). -Supports timestamp, integer, and unsigned integer types as input. -Integers and unsigned integers are parsed as Unix nanosecond timestamps and -return the corresponding RFC3339 timestamp. +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000Z`). +Supports strings, integer, and unsigned integer types as input. +Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') +Integers and unsigned integers are interpreted as milliseconds since the unix epoch (`1970-01-01T00:00:00Z`) +return the corresponding timestamp. ``` to_timestamp_millis(expression) @@ -1422,13 +1506,26 @@ to_timestamp_millis(expression) ### `to_timestamp_micros` -Converts a value to RFC3339 microsecond timestamp format (`YYYY-MM-DDT00:00:00.000000Z`). -Supports timestamp, integer, and unsigned integer types as input. -Integers and unsigned integers are parsed as Unix nanosecond timestamps and -return the corresponding RFC3339 timestamp. +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000Z`). +Supports strings, integer, and unsigned integer types as input. +Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') +Integers and unsigned integers are interpreted as microseconds since the unix epoch (`1970-01-01T00:00:00Z`) +return the corresponding timestamp. + +``` +to_timestamp_nanos(expression) +``` + +### `to_timestamp_nanos` + +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000000Z`). +Supports strings, integer, and unsigned integer types as input. +Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') +Integers and unsigned integers are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`) +return the corresponding timestamp. ``` -to_timestamp_micros(expression) +to_timestamp_nanos(expression) ``` #### Arguments @@ -1438,10 +1535,11 @@ to_timestamp_micros(expression) ### `to_timestamp_seconds` -Converts a value to RFC3339 second timestamp format (`YYYY-MM-DDT00:00:00Z`). -Supports timestamp, integer, and unsigned integer types as input. -Integers and unsigned integers are parsed as Unix nanosecond timestamps and -return the corresponding RFC3339 timestamp. +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000Z`). +Supports strings, integer, and unsigned integer types as input. +Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') +Integers and unsigned integers are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`) +return the corresponding timestamp. ``` to_timestamp_seconds(expression) @@ -1455,8 +1553,8 @@ to_timestamp_seconds(expression) ### `from_unixtime` Converts an integer to RFC3339 timestamp format (`YYYY-MM-DDT00:00:00.000000000Z`). -Input is parsed as a Unix nanosecond timestamp and returns the corresponding -RFC3339 timestamp. +Integers and unsigned integers are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`) +return the corresponding timestamp. ``` from_unixtime(expression) @@ -1470,6 +1568,7 @@ from_unixtime(expression) ## Array Functions - [array_append](#array_append) +- [array_sort](#array_sort) - [array_cat](#array_cat) - [array_concat](#array_concat) - [array_contains](#array_contains) @@ -1481,6 +1580,7 @@ from_unixtime(expression) - [array_length](#array_length) - [array_ndims](#array_ndims) - [array_prepend](#array_prepend) +- [array_pop_front](#array_pop_front) - [array_pop_back](#array_pop_back) - [array_position](#array_position) - [array_positions](#array_positions) @@ -1498,6 +1598,7 @@ from_unixtime(expression) - [cardinality](#cardinality) - [empty](#empty) - [list_append](#list_append) +- [list_sort](#list_sort) - [list_cat](#list_cat) - [list_concat](#list_concat) - [list_dims](#list_dims) @@ -1526,6 +1627,7 @@ from_unixtime(expression) - [string_to_array](#string_to_array) - [string_to_list](#string_to_list) - [trim_array](#trim_array) +- [range](#range) ### `array_append` @@ -1558,6 +1660,36 @@ array_append(array, element) - list_append - list_push_back +### `array_sort` + +Sort array. + +``` +array_sort(array, desc, nulls_first) +``` + +#### Arguments + +- **array**: Array expression. + Can be a constant, column, or function, and any combination of array operators. +- **desc**: Whether to sort in descending order(`ASC` or `DESC`). +- **nulls_first**: Whether to sort nulls first(`NULLS FIRST` or `NULLS LAST`). + +#### Example + +``` +❯ select array_sort([3, 1, 2]); ++-----------------------------+ +| array_sort(List([3,1,2])) | ++-----------------------------+ +| [1, 2, 3] | ++-----------------------------+ +``` + +#### Aliases + +- list_sort + ### `array_cat` _Alias of [array_concat](#array_concat)._ @@ -1833,6 +1965,30 @@ array_prepend(element, array) - list_prepend - list_push_front +### `array_pop_front` + +Returns the array without the first element. + +``` +array_pop_first(array) +``` + +#### Arguments + +- **array**: Array expression. + Can be a constant, column, or function, and any combination of array operators. + +#### Example + +``` +❯ select array_pop_first([1, 2, 3]); ++-------------------------------+ +| array_pop_first(List([1,2,3])) | ++-------------------------------+ +| [2, 3] | ++-------------------------------+ +``` + ### `array_pop_back` Returns the array without the last element. @@ -2194,6 +2350,82 @@ array_to_string(array, delimiter) - list_join - list_to_string +### `array_union` + +Returns an array of elements that are present in both arrays (all elements from both arrays) with out duplicates. + +``` +array_union(array1, array2) +``` + +#### Arguments + +- **array1**: Array expression. + Can be a constant, column, or function, and any combination of array operators. +- **array2**: Array expression. + Can be a constant, column, or function, and any combination of array operators. + +#### Example + +``` +❯ select array_union([1, 2, 3, 4], [5, 6, 3, 4]); ++----------------------------------------------------+ +| array_union([1, 2, 3, 4], [5, 6, 3, 4]); | ++----------------------------------------------------+ +| [1, 2, 3, 4, 5, 6] | ++----------------------------------------------------+ +❯ select array_union([1, 2, 3, 4], [5, 6, 7, 8]); ++----------------------------------------------------+ +| array_union([1, 2, 3, 4], [5, 6, 7, 8]); | ++----------------------------------------------------+ +| [1, 2, 3, 4, 5, 6, 7, 8] | ++----------------------------------------------------+ +``` + +--- + +#### Aliases + +- list_union + +### `array_except` + +Returns an array of the elements that appear in the first array but not in the second. + +``` +array_except(array1, array2) +``` + +#### Arguments + +- **array1**: Array expression. + Can be a constant, column, or function, and any combination of array operators. +- **array2**: Array expression. + Can be a constant, column, or function, and any combination of array operators. + +#### Example + +``` +❯ select array_except([1, 2, 3, 4], [5, 6, 3, 4]); ++----------------------------------------------------+ +| array_except([1, 2, 3, 4], [5, 6, 3, 4]); | ++----------------------------------------------------+ +| [1, 2] | ++----------------------------------------------------+ +❯ select array_except([1, 2, 3, 4], [3, 4, 5, 6]); ++----------------------------------------------------+ +| array_except([1, 2, 3, 4], [3, 4, 5, 6]); | ++----------------------------------------------------+ +| [1, 2] | ++----------------------------------------------------+ +``` + +--- + +#### Aliases + +- list_except + ### `cardinality` Returns the total number of elements in the array. @@ -2246,6 +2478,10 @@ empty(array) _Alias of [array_append](#array_append)._ +### `list_sort` + +_Alias of [array_sort](#array_sort)._ + ### `list_cat` _Alias of [array_concat](#array_concat)._ @@ -2409,6 +2645,20 @@ trim_array(array, n) Can be a constant, column, or function, and any combination of array operators. - **n**: Element to trim the array. +### `range` + +Returns an Arrow array between start and stop with step. `SELECT range(2, 10, 3) -> [2, 5, 8]` + +The range start..end contains all values with start <= x < end. It is empty if start >= end. + +Step can not be 0 (then the range will be nonsense.). + +#### Arguments + +- **start**: start of the range +- **end**: end of the range (not included) +- **step**: increase by step (can not be 0) + ## Struct Functions - [struct](#struct) diff --git a/docs/source/user-guide/sql/write_options.md b/docs/source/user-guide/sql/write_options.md index c98a39f24b92..470591afafff 100644 --- a/docs/source/user-guide/sql/write_options.md +++ b/docs/source/user-guide/sql/write_options.md @@ -42,12 +42,11 @@ WITH HEADER ROW DELIMITER ';' LOCATION '/test/location/my_csv_table/' OPTIONS( -CREATE_LOCAL_PATH 'true', NULL_VALUE 'NAN' ); ``` -When running `INSERT INTO my_table ...`, the options from the `CREATE TABLE` will be respected (gzip compression, special delimiter, and header row included). Note that compression, header, and delimiter settings can also be specified within the `OPTIONS` tuple list. Dedicated syntax within the SQL statement always takes precedence over arbitrary option tuples, so if both are specified the `OPTIONS` setting will be ignored. CREATE_LOCAL_PATH is a special option that indicates if DataFusion should create local file paths when writing new files if they do not already exist. This option is useful if you wish to create an external table from scratch, using only DataFusion SQL statements. Finally, NULL_VALUE is a CSV format specific option that determines how null values should be encoded within the CSV file. +When running `INSERT INTO my_table ...`, the options from the `CREATE TABLE` will be respected (gzip compression, special delimiter, and header row included). Note that compression, header, and delimiter settings can also be specified within the `OPTIONS` tuple list. Dedicated syntax within the SQL statement always takes precedence over arbitrary option tuples, so if both are specified the `OPTIONS` setting will be ignored. NULL_VALUE is a CSV format specific option that determines how null values should be encoded within the CSV file. Finally, options can be passed when running a `COPY` command. @@ -70,19 +69,9 @@ In this example, we write the entirety of `source_table` out to a folder of parq The following special options are specific to the `COPY` command. | Option | Description | Default Value | -| ------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------- | +| ------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------- | --- | | SINGLE_FILE_OUTPUT | If true, COPY query will write output to a single file. Otherwise, multiple files will be written to a directory in parallel. | true | -| FORMAT | Specifies the file format COPY query will write out. If single_file_output is false or the format cannot be inferred from the file extension, then FORMAT must be specified. | N/A | - -### CREATE EXTERNAL TABLE Specific Options - -The following special options are specific to creating an external table. - -| Option | Description | Default Value | -| ----------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ---------------------------------------------------------------------------- | -| SINGLE_FILE | If true, indicates that this external table is backed by a single file. INSERT INTO queries will append to this file. | false | -| CREATE_LOCAL_PATH | If true, the folder or file backing this table will be created on the local file system if it does not already exist when running INSERT INTO queries. | false | -| INSERT_MODE | Determines if INSERT INTO queries should append to existing files or append new files to an existing directory. Valid values are append_to_file, append_new_files, and error. Note that "error" will block inserting data into this table. | CSV and JSON default to append_to_file. Parquet defaults to append_new_files | +| FORMAT | Specifies the file format COPY query will write out. If single_file_output is false or the format cannot be inferred from the file extension, then FORMAT must be specified. | N/A | | ### JSON Format Specific Options @@ -92,7 +81,7 @@ The following options are available when writing JSON files. Note: If any unsupp | ----------- | ---------------------------------------------------------------------------------------------------------------------------------- | ------------- | | COMPRESSION | Sets the compression that should be applied to the entire JSON file. Supported values are GZIP, BZIP2, XZ, ZSTD, and UNCOMPRESSED. | UNCOMPRESSED | -### CSV Format Sepcific Options +### CSV Format Specific Options The following options are available when writing CSV files. Note: if any unsupported options is specified an error will be raised and the query will fail. diff --git a/datafusion/expr/src/struct_expressions.rs b/docs/src/lib.rs similarity index 65% rename from datafusion/expr/src/struct_expressions.rs rename to docs/src/lib.rs index bbfcac0e2396..f73132468ec9 100644 --- a/datafusion/expr/src/struct_expressions.rs +++ b/docs/src/lib.rs @@ -15,21 +15,5 @@ // specific language governing permissions and limitations // under the License. -use arrow::datatypes::DataType; - -/// Currently supported types by the struct function. -pub static SUPPORTED_STRUCT_TYPES: &[DataType] = &[ - DataType::Boolean, - DataType::UInt8, - DataType::UInt16, - DataType::UInt32, - DataType::UInt64, - DataType::Int8, - DataType::Int16, - DataType::Int32, - DataType::Int64, - DataType::Float32, - DataType::Float64, - DataType::Utf8, - DataType::LargeUtf8, -]; +#[cfg(test)] +mod library_logical_plan; diff --git a/docs/src/library_logical_plan.rs b/docs/src/library_logical_plan.rs new file mode 100644 index 000000000000..355003941570 --- /dev/null +++ b/docs/src/library_logical_plan.rs @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion::error::Result; +use datafusion::logical_expr::builder::LogicalTableSource; +use datafusion::logical_expr::{Filter, LogicalPlan, LogicalPlanBuilder, TableScan}; +use datafusion::prelude::*; +use std::sync::Arc; + +#[test] +fn plan_1() -> Result<()> { + // create a logical table source + let schema = Schema::new(vec![ + Field::new("id", DataType::Int32, true), + Field::new("name", DataType::Utf8, true), + ]); + let table_source = LogicalTableSource::new(SchemaRef::new(schema)); + + // create a TableScan plan + let projection = None; // optional projection + let filters = vec![]; // optional filters to push down + let fetch = None; // optional LIMIT + let table_scan = LogicalPlan::TableScan(TableScan::try_new( + "person", + Arc::new(table_source), + projection, + filters, + fetch, + )?); + + // create a Filter plan that evaluates `id > 500` and wraps the TableScan + let filter_expr = col("id").gt(lit(500)); + let plan = LogicalPlan::Filter(Filter::try_new(filter_expr, Arc::new(table_scan))?); + + // print the plan + println!("{}", plan.display_indent_schema()); + + Ok(()) +} + +#[test] +fn plan_builder_1() -> Result<()> { + // create a logical table source + let schema = Schema::new(vec![ + Field::new("id", DataType::Int32, true), + Field::new("name", DataType::Utf8, true), + ]); + let table_source = LogicalTableSource::new(SchemaRef::new(schema)); + + // optional projection + let projection = None; + + // create a LogicalPlanBuilder for a table scan + let builder = LogicalPlanBuilder::scan("person", Arc::new(table_source), projection)?; + + // perform a filter that evaluates `id > 500`, and build the plan + let plan = builder.filter(col("id").gt(lit(500)))?.build()?; + + // print the plan + println!("{}", plan.display_indent_schema()); + + Ok(()) +} diff --git a/test-utils/Cargo.toml b/test-utils/Cargo.toml index 5ab10e42cf68..b9c4db17c098 100644 --- a/test-utils/Cargo.toml +++ b/test-utils/Cargo.toml @@ -26,4 +26,4 @@ edition = { workspace = true } arrow = { workspace = true } datafusion-common = { path = "../datafusion/common" } env_logger = "0.10.0" -rand = "0.8" +rand = { workspace = true } diff --git a/test-utils/src/lib.rs b/test-utils/src/lib.rs index dfd878275181..0c3668d2f8c0 100644 --- a/test-utils/src/lib.rs +++ b/test-utils/src/lib.rs @@ -38,7 +38,7 @@ pub fn batches_to_vec(batches: &[RecordBatch]) -> Vec> { .collect() } -/// extract values from batches and sort them +/// extract i32 values from batches and sort them pub fn partitions_to_sorted_vec(partitions: &[Vec]) -> Vec> { let mut values: Vec<_> = partitions .iter() @@ -70,13 +70,23 @@ pub fn add_empty_batches( } /// "stagger" batches: split the batches into random sized batches +/// +/// For example, if the input batch has 1000 rows, [`stagger_batch`] might return +/// multiple batches +/// ```text +/// [ +/// RecordBatch(123 rows), +/// RecordBatch(234 rows), +/// RecordBatch(634 rows), +/// ] +/// ``` pub fn stagger_batch(batch: RecordBatch) -> Vec { let seed = 42; stagger_batch_with_seed(batch, seed) } -/// "stagger" batches: split the batches into random sized batches -/// using the specified value for a rng seed +/// "stagger" batches: split the batches into random sized batches using the +/// specified value for a rng seed. See [`stagger_batch`] for more detail. pub fn stagger_batch_with_seed(batch: RecordBatch, seed: u64) -> Vec { let mut batches = vec![]; diff --git a/testing b/testing index 37f29510ce97..98fceecd024d 160000 --- a/testing +++ b/testing @@ -1 +1 @@ -Subproject commit 37f29510ce97cd491b8e6ed75866c6533a5ea2a1 +Subproject commit 98fceecd024dccd2f8a00e32fc144975f218acf4